class LeNet5(nn.Module): def __init__(self): super(LeNet5, self).__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.conv3=nn.Conv2d(16,32,5) self.fc1 = nn.Linear(32 * 4 * 4, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = torch.relu(self.conv1(x)) x = torch.max_pool2d(x, 2) x = torch.relu(self.conv2(x)) x = torch.max_pool2d(x, 2) x = torch.relu(self.conv3(x)) x = torch.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x
时间: 2023-10-20 12:39:25 浏览: 328
LeNet5_LeNet5
5星 · 资源好评率100%
这是一个使用 PyTorch 实现的 LeNet5 模型。LeNet5 是一个经典的卷积神经网络模型,用于识别手写数字图像。这个模型包含了两个卷积层和三个全连接层。在 forward 函数中,输入数据先经过卷积层和池化层的处理,然后通过全连接层进行分类。最后返回分类结果。其中,Conv2d 表示卷积层,Linear 表示全连接层,ReLU 表示激活函数,max_pool2d 表示池化层,view 表示将张量变形。
阅读全文