用pytorch实现多层感知机
时间: 2024-05-02 07:14:45 浏览: 31
使用PyTorch实现多层感知机(Multilayer Perceptron, MLP)是一种常见的神经网络模型。下面是一个简单的示例代码:
```python
import torch
import torch.nn as nn
# 定义多层感知机模型
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size) # 第一层全连接层
self.relu = nn.ReLU() # 激活函数
self.fc2 = nn.Linear(hidden_size, output_size) # 第二层全连接层
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
# 创建模型实例
input_size = 784 # 输入大小,MNIST数据集的图片大小为28x28=784
hidden_size = 256 # 隐藏层大小
output_size = 10 # 输出大小,MNIST数据集有10个类别
model = MLP(input_size, hidden_size, output_size)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器
# 训练模型
num_epochs = 10 # 训练轮数
for epoch in range(num_epochs):
for images, labels in train_loader: # 遍历训练数据集
images = images.reshape(-1, 28*28) # 将图片展平为向量
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传播
optimizer.step() # 更新参数
# 使用模型进行预测
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader: # 遍历测试数据集
images = images.reshape(-1, 28*28)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1) # 获取预测结果
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
print('Accuracy: {:.2f}%'.format(accuracy * 100))
```
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)