pytorch 卷积神经网络模板
时间: 2023-09-28 14:12:18 浏览: 96
pytorch_pytorch_卷积神经网络_
5星 · 资源好评率100%
下面是一个基本的 PyTorch 卷积神经网络模板:
```python
import torch
import torch.nn as nn
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
# 定义卷积层
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
# 定义池化层
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 定义全连接层
self.fc1 = nn.Linear(128 * 8 * 8, 1024)
self.fc2 = nn.Linear(1024, 10)
# 定义激活函数
self.relu = nn.ReLU()
def forward(self, x):
# 前向传播
x = self.relu(self.conv1(x))
x = self.pool(x)
x = self.relu(self.conv2(x))
x = self.pool(x)
x = self.relu(self.conv3(x))
x = self.pool(x)
x = x.view(-1, 128 * 8 * 8)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
```
这个模板包含了卷积层、池化层、全连接层和激活函数。可以根据自己的需求修改模型结构。
阅读全文