pytorch实现lenet-5
时间: 2023-12-13 07:04:16 浏览: 33
你可以使用PyTorch来实现LeNet-5模型,以下是一个简单的示例代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
self.fc1 = nn.Linear(16*4*4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 16*4*4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建LeNet-5模型实例
model = LeNet5()
# 打印模型结构
print(model)
```
这段代码定义了一个LeNet5类,继承自`nn.Module`,并在构造函数`__init__`中定义了LeNet-5的网络结构。`forward`函数定义了数据在网络中的前向传播路径。
你可以根据自己的数据和任务需求,进一步扩展和训练该模型。