torch 权重初始化
时间: 2023-09-05 22:08:40 浏览: 181
在 PyTorch 中,权重初始化可以通过 nn.Module 的构造函数中的参数进行设置。常用的权重初始化方法有以下几种:
1. 随机初始化:使用 PyTorch 自带的 nn.init 模块中的函数进行随机初始化,如nn.init.xavier_uniform_()、nn.init.xavier_normal_()、nn.init.kaiming_uniform_()、nn.init.kaiming_normal_()等。
2. 预训练初始化:如果已经有预训练的模型,可以使用预训练模型的参数进行初始化。
3. 固定初始化:可以手动设置固定的值来进行初始化,如全零初始化(nn.init.zeros_())或全一初始化(nn.init.ones_())。
4. 自定义初始化:可以根据特定的需求,自定义初始化方法。
示例代码:
```python
import torch.nn as nn
class Net(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Net, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.zeros_(self.fc1.bias)
self.fc2 = nn.Linear(hidden_size, output_size)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.zeros_(self.fc2.bias)
def forward(self, x):
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
```
在上述代码中,我们使用了 nn.init.xavier_uniform_() 函数对权重进行随机初始化,使用 nn.init.zeros_() 函数对偏差进行初始化。
阅读全文