pytorch之线性回归答案
时间: 2024-09-19 11:07:34 浏览: 51
PyTorch中的线性回归是一个简单的机器学习模型,用于预测连续变量。它通常作为深度学习入门的例子,因为它的计算流程直观且易于理解。以下是使用PyTorch实现线性回归的基本步骤:
1. 导入必要的库:
```python
import torch
import torch.nn as nn
```
2. 定义数据结构和超参数:
```python
# 假设我们有输入特征x和对应的标签y
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
learning_rate = 0.01
num_epochs = 1000
```
3. 创建线性模型(nn.Linear):
```python
model = nn.Linear(1, 1) # 输入维度1,输出维度1(假设是单变量线性回归)
```
4. 定义损失函数(通常使用均方误差MSE)和优化器(如SGD):
```python
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```
5. 训练循环:
```python
for epoch in range(num_epochs):
# 前向传播
y_pred = model(X)
# 求损失
loss = criterion(y_pred, Y)
# 反向传播和优化
optimizer.zero_grad() # 清零梯度
loss.backward() # 计算梯度
optimizer.step() # 更新模型参数
# 打印训练信息(例如每100次迭代打印一次)
if (epoch + 1) % 100 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```
训练完成后,`model`的权重将包含最佳拟合的参数,可以用来对新的输入数据进行预测。
阅读全文