pytorch cnn
时间: 2023-09-12 17:03:26 浏览: 88
pytorch-cnn
PyTorch中的卷积神经网络(Convolutional Neural Networks,CNN)通常使用torch.nn.Conv2d模块实现。Conv2d模块需要输入四个参数:输入通道数(in_channels),输出通道数(out_channels),卷积核大小(kernel_size)和步长(stride)。以下是一个简单的CNN模型的例子:
```python
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
self.fc1 = nn.Linear(64 * 6 * 6, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = self.pool(x)
x = nn.functional.relu(self.conv2(x))
x = self.pool(x)
x = x.view(-1, 64 * 6 * 6)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
```
上面的代码定义了一个具有两个卷积层和两个全连接层的CNN模型。在forward函数中,我们首先使用ReLU激活函数对卷积层的输出进行激活,然后使用最大池化层对特征图进行下采样。最后,我们将特征图展平为向量,并通过两个全连接层进行分类。
阅读全文