pytorch基础代码
时间: 2023-11-03 18:57:17 浏览: 80
Pytorch-基础入门代码
当然,这里是一个使用 PyTorch 的基础示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.sigmoid(self.fc2(x))
return x
# 创建网络实例和损失函数
net = Net()
criterion = nn.MSELoss()
# 创建优化器
optimizer = optim.SGD(net.parameters(), lr=0.01)
# 准备训练数据
inputs = torch.randn(100, 10)
labels = torch.randn(100, 1)
# 进行训练
for epoch in range(100):
optimizer.zero_grad() # 梯度清零
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward() # 反向传播
optimizer.step() # 更新权重
if (epoch+1) % 10 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item()}")
# 使用训练好的模型进行预测
test_inputs = torch.randn(10, 10)
predicted = net(test_inputs)
print("Prediction results:")
print(predicted)
```
这段代码演示了一个简单的神经网络的创建、训练和预测过程。首先定义了一个简单的神经网络类 `Net`,然后创建网络实例、损失函数和优化器。接下来,使用随机生成的训练数据进行训练,并在每个 epoch 打印出当前的损失值。最后,使用训练好的模型对新的数据进行预测,并输出预测结果。
请注意,这只是一个简单的示例,实际使用过程中可能会涉及到更复杂的网络结构、更多的训练数据和更复杂的训练过程。
阅读全文