线性模型pytorch
时间: 2023-11-07 17:05:31 浏览: 29
线性模型是一种基本的机器学习模型,可以用来建立输入特征和输出标签之间的线性关系。在Pytorch中,可以使用torch.nn.Linear模块来实现线性模型。它接受输入特征,并输出预测结果。
下面是一个使用Pytorch实现线性模型的示例代码:
```python
import torch
import torch.nn as nn
# 定义线性模型
class LinearModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearModel, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
# 创建模型实例
input_dim = 1 # 输入特征维度
output_dim = 1 # 输出标签维度
model = LinearModel(input_dim, output_dim)
# 定义损失函数和优化器
criterion = nn.MSELoss() # 均方损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器
# 准备输入数据
x_data = torch.tensor([[1.0], [2.0], [3.0]]) # 输入特征
y_data = torch.tensor([[2.0], [4.0], [6.0]]) # 输出标签
# 训练模型
epochs = 100 # 迭代次数
for epoch in range(epochs):
# 前向传播
y_pred = model(x_data)
# 计算损失
loss = criterion(y_pred, y_data)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
if (epoch+1) % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
# 使用模型进行预测
x_test = torch.tensor([[4.0]])
y_pred = model(x_test)
print('预测结果:', y_pred.item())
```