可视化pytorch网络结构
时间: 2023-11-03 12:05:05 浏览: 96
Python Unet网络结构pytorch简单实现+torchsummary可视化(可以直接运行)
5星 · 资源好评率100%
可以使用PyTorch提供的`torchsummary`库来可视化PyTorch网络结构。以下是一个示例代码:
```python
import torch
import torch.nn as nn
from torchsummary import summary
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(64*8*8, 128)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.pool2(x)
x = x.view(-1, 64*8*8)
x = self.fc1(x)
x = self.relu3(x)
x = self.fc2(x)
return x
model = MyModel()
summary(model, (3, 32, 32))
```
在这个例子中,我们定义了`MyModel`类作为我们的模型,并使用`torchsummary`的`summary`函数来打印出模型的结构和参数数量。在`summary`函数中,我们需要传入模型实例和输入数据的形状。
运行上述代码,可以得到以下输出:
```
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 32, 32] 896
BatchNorm2d-2 [-1, 32, 32, 32] 64
ReLU-3 [-1, 32, 32, 32] 0
MaxPool2d-4 [-1, 32, 16, 16] 0
Conv2d-5 [-1, 64, 16, 16] 18,496
BatchNorm2d-6 [-1, 64, 16, 16] 128
ReLU-7 [-1, 64, 16, 16] 0
MaxPool2d-8 [-1, 64, 8, 8] 0
Linear-9 [-1, 128] 524,416
ReLU-10 [-1, 128] 0
Linear-11 [-1, 10] 1,290
================================================================
Total params: 545,290
Trainable params: 545,290
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.50
Params size (MB): 2.08
Estimated Total Size (MB): 3.59
----------------------------------------------------------------
```
可以看到,我们成功地输出了模型的结构和参数数量。
阅读全文