pytorch 训练代码
时间: 2024-09-22 09:06:27 浏览: 39
pytorch训练代码
PyTorch 是一个开源的深度学习库,它提供了一个动态计算图模型,适合快速实验和原型设计。下面是一个简单的 PyTorch 训练神经网络的代码示例,假设我们有一个线性回归任务:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义网络结构(线性模型)
class LinearRegression(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
# 创建模型、输入和目标
input_dim = 10
output_dim = 1
model = LinearRegression(input_dim, output_dim)
x_data = torch.randn(100, input_dim)
y_data = torch.randn(100, output_dim) * 2 # 假设数据有一定的噪声
# 定义损失函数和优化器
criterion = nn.MSELoss() # 用均方误差作为损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01) # 使用随机梯度下降优化器
# 训练过程
for epoch in range(100): # 进行多次迭代(一轮训练)
# 前向传播
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
# 反向传播和优化
optimizer.zero_grad() # 清零梯度
loss.backward()
optimizer.step()
# 打印训练信息
if (epoch+1) % 10 == 0:
print(f"Epoch [{epoch+1}/{100}], Loss: {loss.item():.4f}")
#
阅读全文