ANN代码pytorch
时间: 2025-01-07 13:06:24 浏览: 7
### 如何用 PyTorch 实现人工神经网络 (ANN)
#### 创建一个简单的多层感知机
为了创建一个多层感知机(MLP),可以定义一个继承自 `nn.Module` 的类。这个类会初始化模型的各个层次,并定义前向传播的过程。
```python
import torch
from torch import nn, optim
import torch.nn.functional as F
class SimpleMLP(nn.Module):
def __init__(self, input_size=784, hidden_sizes=[512, 256], output_size=10):
super(SimpleMLP, self).__init__()
# 定义隐藏层和输出层
self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_sizes[0])])
layer_sizes = zip(hidden_sizes[:-1], hidden_sizes[1:])
self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
self.output_layer = nn.Linear(hidden_sizes[-1], output_size)
def forward(self, x):
# 将输入展平成一维张量
x = x.view(x.shape[0], -1)
# 前向传递通过每一层并应用ReLU激活函数
for linear in self.hidden_layers:
x = F.relu(linear(x))
# 应用于最后一层线性变换后的log softmax作为输出
x = F.log_softmax(self.output_layer(x), dim=1)
return x
```
这段代码展示了如何利用 PyTorch 构建一个基本的 MLP 结构,该结构由两个全连接层组成,每层之后都跟随着 ReLU 激活函数[^3]。
#### 训练过程概述
一旦定义好了模型架构,下一步就是编写训练循环来优化参数。这通常涉及到以下几个方面:
- **损失函数**: 对于分类任务来说,交叉熵是一个常见的选择。
- **优化器**: Adam 是一种广泛使用的梯度下降算法变体。
- **数据加载器**: 使用 DataLoader 来批量处理数据集中的样本。
下面是一些辅助性的 Python 函数片段用来完成上述工作:
```python
def train(model, device, train_loader, optimizer, epoch):
model.train()
running_loss = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad() # 清除之前的梯度
output = model(data) # 获取预测结果
loss = F.nll_loss(output, target) # 计算负对数似然损失
loss.backward() # 反向传播求导
optimizer.step() # 更新权重
running_loss += loss.item()
if batch_idx % log_interval == 0 and batch_idx != 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {running_loss/log_interval:.6f}')
running_loss = 0.0
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
```
这些函数实现了标准的小批次随机梯度下降法来进行监督学习的任务,在每次迭代过程中调整模型参数以最小化给定目标函数下的误差[^1]。
阅读全文