pytorch 线性回归代码
时间: 2024-09-15 10:07:28 浏览: 45
PyTorch线性回归和逻辑回归实战示例
PyTorch是一个强大的深度学习框架,线性回归是机器学习中最基础的模型之一。下面是一个简单的例子,展示如何在PyTorch中实现线性回归:
```python
import torch
from torch import nn
# 定义模型 (只包含一个线性层)
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim):
super(LinearRegressionModel, self).__init__()
self.linear = nn.Linear(input_dim, 1) # 输入维度到单个输出的线性层
def forward(self, x):
return self.linear(x)
# 初始化模型和优化器
input_dim = 10 # 假设输入特征有10维
model = LinearRegressionModel(input_dim)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 假设我们有一个训练数据集 X_train 和对应的标签 y_train
X_train = torch.randn(100, input_dim) # 随机生成100个样本
y_train = torch.randn(100, 1) # 目标值也是随机的,假设是一维的
# 定义损失函数 (均方误差)
criterion = nn.MSELoss()
for epoch in range(500): # 迭代次数
# 前向传播
predictions = model(X_train)
# 计算损失
loss = criterion(predictions, y_train)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 100 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
阅读全文