torchinfo.summary参数
时间: 2023-12-02 17:37:34 浏览: 37
torchinfo.summary是一个用于打印PyTorch模型详细信息的工具函数。其主要参数如下:
- model:传入要打印信息的PyTorch模型。
- input_size:模型接收的输入大小,可以是单个数字,也可以是元组。
- input_data:如果没有提供input_size,则需要提供一个输入张量作为输入数据。
- device:指定模型运行的设备,默认为"cuda"(如果可用)或"cpu"。
示例代码:
```python
import torch
import torch.nn as nn
import torchvision.models as models
from torchinfo import summary
model = models.resnet18()
summary(model, input_size=(1, 3, 224, 224))
```
输出结果:
```
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet -- --
├─Conv2d: 1-1 [1, 64, 112, 112] 9,408
├─BatchNorm2d: 1-2 [1, 64, 112, 112] 128
├─ReLU: 1-3 [1, 64, 112, 112] --
...
```