写一段使用pytorch的cnn代码
时间: 2024-05-01 12:20:37 浏览: 92
pytorch-cnn
以下是一个使用PyTorch实现卷积神经网络的示例代码:
```python
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.fc1 = nn.Linear(in_features=7*7*32, out_features=128)
self.fc2 = nn.Linear(in_features=128, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool1(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool2(x)
x = x.view(-1, 7*7*32) # flatten
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
```
在这个示例中,我们定义了一个简单的卷积神经网络,它包含两个卷积层和两个全连接层。在定义网络时,我们使用了PyTorch的`nn.Module`基类,并在构造函数中定义了每一层的结构。在`forward`函数中,我们定义了数据在网络中的传递过程,包括卷积、池化、激活和全连接层。最终,我们返回网络输出。
阅读全文