Xavier or He Initialization Code
Date: 2023-08-06 11:06:48
Xavier initialization and He initialization are commonly used weight-initialization methods, each suited to different network architectures and activation functions. Below are implementations of both (using PyTorch as an example):
Xavier initialization:
```python
import math
import torch.nn as nn

def weights_init_xavier(m):
    # Xavier (Glorot) normal init: std = sqrt(2 / (fan_in + fan_out)).
    # Note: the original snippet used sqrt(2 / fan_out), which is actually
    # the He formula; corrected here to the standard Xavier form.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        std = math.sqrt(2. / (fan_in + fan_out))
        m.weight.data.normal_(0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        std = math.sqrt(2. / (m.in_features + m.out_features))
        m.weight.data.normal_(0, std)
        if m.bias is not None:
            m.bias.data.zero_()

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 8 * 8, 1024),  # assumes 32x32 input images
    nn.ReLU(inplace=True),
    nn.Linear(1024, 10)
)
model.apply(weights_init_xavier)
```
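As a quick sanity check (not part of the original post), we can verify that the sampled standard deviation of a Xavier-initialized conv layer matches the formula std = sqrt(2 / (fan_in + fan_out)). The layer shape below mirrors the first conv layer of the model; `nn.init.xavier_normal_` is PyTorch's built-in implementation of the same formula:

```python
import math
import torch
import torch.nn as nn

# Same shape as the model's first conv layer.
conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
nn.init.xavier_normal_(conv.weight)

# For a conv layer: fan_in = k*k*in_channels, fan_out = k*k*out_channels.
fan_in = 3 * 3 * 3
fan_out = 3 * 3 * 64
expected_std = math.sqrt(2.0 / (fan_in + fan_out))

# The empirical std over the 1728 sampled weights should be close.
empirical_std = conv.weight.std().item()
```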
He initialization:
```python
import math
import torch.nn as nn

def weights_init_he(m):
    # He (Kaiming) normal init: std = sqrt(2 / fan), designed for ReLU.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # fan_out variant: n = k*k*out_channels
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        std = math.sqrt(2. / n)
        m.weight.data.normal_(0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        # fan_in variant: n = in_features
        n = m.in_features
        std = math.sqrt(2. / n)
        m.weight.data.normal_(0, std)
        if m.bias is not None:
            m.bias.data.zero_()

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 8 * 8, 1024),  # assumes 32x32 input images
    nn.ReLU(inplace=True),
    nn.Linear(1024, 10)
)
model.apply(weights_init_he)
```
In the code above, `weights_init_xavier` and `weights_init_he` implement Xavier and He initialization, respectively. Both sample the weight tensor from a zero-mean normal distribution whose standard deviation is derived from the layer's fan-in and fan-out: Xavier uses std = sqrt(2 / (fan_in + fan_out)) to keep activation variance stable in both the forward and backward passes, while He uses std = sqrt(2 / fan), which accounts for ReLU zeroing out half of its inputs. For a convolutional layer the fan counts are kernel_height × kernel_width × channels; for a fully connected layer they are the input/output feature counts. Finally, `model.apply` recursively applies the initialization function to every submodule of the network.
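In practice, PyTorch ships both schemes in `torch.nn.init`, so the hand-computed formulas above can be replaced by built-in calls. The sketch below (layer sizes are illustrative, not from the original post) shows the same `model.apply` pattern using `nn.init.xavier_normal_` and `nn.init.kaiming_normal_`:

```python
import math
import torch
import torch.nn as nn

def init_weights(m, mode="xavier"):
    # Initialize Conv2d/Linear weights with the chosen scheme; zero the biases.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        if mode == "xavier":
            # std = sqrt(2 / (fan_in + fan_out))
            nn.init.xavier_normal_(m.weight)
        else:
            # std = sqrt(2 / fan_in), suited to ReLU networks
            nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(1000, 1000), nn.ReLU(), nn.Linear(1000, 10))
model.apply(lambda m: init_weights(m, mode="xavier"))
```

Using `isinstance` checks is more robust than matching class names with `find('Conv')`, which would also match unrelated classes whose names contain the substring.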