给一个pytorch神经网络基础模板
时间: 2024-10-01 08:04:24 浏览: 4
在PyTorch中创建一个基本的神经网络模型通常涉及以下几个步骤:
```python
import torch
import torch.nn as nn
# 定义网络层
class SimpleNeuralNetwork(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNeuralNetwork, self).__init__()
# 输入层到隐藏层
self.fc1 = nn.Linear(input_size, hidden_size)
# 隐藏层到隐藏层(如果有多个隐藏层)
self.fc2 = nn.Linear(hidden_size, hidden_size)
# 最后一层,隐藏层到输出层
self.fc_out = nn.Linear(hidden_size, output_size)
def forward(self, x):
# 前向传播过程
x = torch.relu(self.fc1(x)) # 使用ReLU激活函数
x = torch.relu(self.fc2(x))
x = self.fc_out(x) # 输出前不需要激活,直接输出
return x
# 使用示例
input_size = 784 # 假设输入图片大小为28x28,所以是784
hidden_size = 128
output_size = 10 # 假设是10分类任务(如MNIST)
model = SimpleNeuralNetwork(input_size, hidden_size, output_size)
```
这个模板定义了一个包含两层全连接(fc)的简单神经网络,每个全连接层后面都跟着ReLU激活函数。`forward()`方法描述了网络的计算流程。