PyTorch 机器学习代码
时间: 2024-01-15 21:04:02 浏览: 94
PyTorch 是一个深度学习框架,可以用来开发各种机器学习模型。下面是一个简单的 PyTorch 机器学习代码示例,用于训练一个简单的线性回归模型。
```
import torch
import torch.nn as nn
import torch.optim as optim
# 定义数据集
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]])
# 定义模型
model = nn.Linear(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(1000):
optimizer.zero_grad()
y_pred = model(x_train)
loss = criterion(y_pred, y_train)
loss.backward()
optimizer.step()
# 预测新数据
x_new = torch.tensor([[5.0]])
y_new = model(x_new)
print(y_new)
```
这个例子中,我们首先定义了一个简单的数据集 x_train 和 y_train,然后使用 nn.Linear 定义了一个线性回归模型。
接着,我们定义了损失函数和优化器,并使用循环来训练模型。每次训练时,我们计算模型的预测值和真实值之间的均方误差,并使用反向传播更新模型参数。
最后,我们使用训练好的模型预测了新的数据 x_new 的输出 y_new。
阅读全文