请用torch编写一个卷积神经网络模型,用于图像8分类问题,并用summary函数输出每一层的形状。已知输入网络形状大小为(10,3,256,256)
时间: 2024-10-25 07:12:22 浏览: 35
在PyTorch中,我们可以使用`torch.nn`模块构建一个卷积神经网络(CNN)模型。这里是一个简单的例子,创建一个用于8分类图像识别的卷积神经网络结构,假设我们采用的是LeNet-5的架构作为基础:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义卷积神经网络模型
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
# 第一层卷积层
self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 下采样
# 第二层卷积层
self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 再次下采样
# 全连接层
self.fc1 = nn.Linear(16 * 4 * 4, 120) # 将池化后的尺寸计算出来
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 8) # 输出层,对应8分类任务
def forward(self, x):
x = F.relu(self.conv1(x)) # 激活函数
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
# 展平并连接全连接层
x = x.view(-1, self.num_flat_features(x)) # 把张量展平成一维向量
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # 排除batch_size维度
num_features = 1
for s in size:
num_features *= s
return num_features
# 初始化模型并使用summary函数获取模型结构信息
model = ConvNet()
input_shape = (10, 3, 256, 256) # 输入数据形状
with torch.no_grad():
print(model(torch.randn(input_shape)).shape) # 验证形状
print(nn.utils.summary.summary(model, input_shape))
#
阅读全文