net初始化线性回归模型
时间: 2023-12-20 16:07:00 浏览: 29
下面是使用PyTorch框架初始化一个简单的线性回归模型的示例代码:
```python
import torch.nn as nn
class LinearRegression(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
out = self.linear(x)
return out
```
在这个示例中,我们定义了一个名为`LinearRegression`的类,它继承自`nn.Module`类。这个类的构造函数中定义了一个`nn.Linear`模块,它将输入维度`input_dim`映射到输出维度`output_dim`。`forward`函数实现了模型的前向传播,其中`x`是输入的数据。
使用这个模型时,我们可以将输入数据`x`传递给模型进行预测:
```python
# 初始化模型
model = LinearRegression(input_dim=1, output_dim=1)
# 输入数据
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
# 预测
y_pred = model(x)
print(y_pred)
```
输出:
```
tensor([[-0.8445],
[-1.7377],
[-2.6309],
[-3.5242]], grad_fn=<AddmmBackward>)
```
这个示例中,我们使用一个一维的输入数据`x`,并将它传递给模型进行预测。模型输出的结果也是一个一维的张量,表示模型对输入数据的预测。