pytorch打印网络结构
时间: 2023-09-14 14:03:10 浏览: 56
要打印出PyTorch中的网络结构,可以使用以下代码:
```python
import torch
import torch.nn as nn
# 定义一个简单的网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.fc = nn.Linear(32 * 16 * 16, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建网络实例
net = Net()
# 打印网络结构
print(net)
```
这段代码定义了一个简单的卷积神经网络,并通过打印`net`来显示网络结构。注意,在此示例中,我们定义了一个包含卷积层、ReLU激活函数、池化层和全连接层的简单网络。