pytorch初始化模型参数
时间: 2023-07-04 11:16:46 浏览: 122
在 PyTorch 中,我们可以使用模型类的 `__init__` 方法来初始化模型参数。以下是一个简单的例子:
```python
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 初始化模型参数
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 2)
def forward(self, x):
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
```
在上面的例子中,我们定义了一个名为 `MyModel` 的模型类,并在其 `__init__` 方法中初始化了两个全连接层 `fc1` 和 `fc2` 的参数。这里我们使用了 PyTorch 中的 `nn.Linear` 类来创建全连接层,并将其作为模型类的成员变量。在模型的前向传播方法 `forward` 中,我们通过 `self.fc1` 和 `self.fc2` 来调用这两个全连接层,从而实现了模型的前向计算。
当我们使用这个模型时,可以直接通过创建一个 `MyModel` 的实例来获得一个已经初始化好的模型。
```python
model = MyModel()
# 模型已经初始化,可以进行前向计算
x = torch.randn(1, 10)
y = model(x)
print(y)
```
阅读全文