CNN平均池化层代码
时间: 2024-04-07 14:25:58 浏览: 69
6. 全连接层与全局平均池化1
CNN(卷积神经网络)中的平均池化层用于减少特征图的尺寸,并提取出主要特征。以下是一个简单的CNN平均池化层的代码示例:
```python
import torch
import torch.nn as nn
# 定义一个简单的CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = self.pool(x)
return x
# 创建一个CNN模型实例
model = CNN()
# 输入数据
input_data = torch.randn(1, 3, 32, 32) # 假设输入数据为3通道的32x32图像
# 前向传播
output = model(input_data)
print(output.size()) # 打印输出特征图的尺寸
```
在上述代码中,我们首先定义了一个简单的CNN模型,其中包含一个卷积层(`conv1`)和一个平均池化层(`pool`)。在模型的前向传播过程中,输入数据经过卷积层后,再经过平均池化层进行特征提取和尺寸缩减。最后,我们打印出输出特征图的尺寸。
阅读全文