写一段pytorch代码
时间: 2023-03-13 13:46:26 浏览: 86
PyTorch常用代码段合集1
我们可以使用PyTorch来实现一个简单的神经网络:import torch
import torch.nn as nn# 定义输入和输出维度
input_size = 5
output_size = 1# 定义一个简单的神经网络
model = nn.Sequential(nn.Linear(input_size, 10),
nn.ReLU(),
nn.Linear(10, output_size))# 设定损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练模型
for epoch in range(1000):
# 获取输入数据
inputs = torch.randn(1, input_size)
targets = torch.randn(1, output_size) # 前向传播
outputs = model(inputs)
loss = criterion(outputs, targets) # 反向传播和参数更新
optimizer.zero_grad()
loss.backward()
optimizer.step()print('训练完成!')
阅读全文