如何进行He初始化 给出代码
时间: 2024-03-09 12:44:12 浏览: 21
He初始化是一种常用的模型参数初始化方法,可以有效地提高神经网络模型的训练效果。具体实现方法如下:
对于某一层的权重矩阵 $W$,假设该层的输入维度为 $n$,输出维度为 $m$,则可以采用如下方式进行He初始化:
```python
import numpy as np
# 初始化权重矩阵W
W = np.random.randn(n, m) * np.sqrt(2.0 / n)
```
其中,`np.random.randn(n, m)` 表示从标准正态分布中采样 $n \times m$ 个样本,`np.sqrt(2.0 / n)` 表示对采样结果进行缩放,使其满足He初始化的要求。
完整的He初始化代码示例如下:
```python
import numpy as np
# 定义神经网络模型
class Net:
def __init__(self, input_size, hidden_size, output_size):
# 初始化输入层到隐藏层的权重矩阵
self.W1 = np.random.randn(input_size, hidden_size) * np.sqrt(2.0 / input_size)
# 初始化隐藏层到输出层的权重矩阵
self.W2 = np.random.randn(hidden_size, output_size) * np.sqrt(2.0 / hidden_size)
# 正向传播
def forward(self, X):
# 计算隐藏层的输出
self.Z1 = np.dot(X, self.W1)
self.A1 = np.tanh(self.Z1)
# 计算输出层的输出
self.Z2 = np.dot(self.A1, self.W2)
self.A2 = np.softmax(self.Z2, axis=1)
return self.A2
# 反向传播
def backward(self, X, y, y_pred, lr):
# 计算输出层的误差
delta2 = y_pred - y
# 计算隐藏层的误差
delta1 = np.dot(delta2, self.W2.T) * (1 - np.power(self.A1, 2))
# 更新权重矩阵
self.W2 -= lr * np.dot(self.A1.T, delta2)
self.W1 -= lr * np.dot(X.T, delta1)
```
在上述代码中,初始化权重矩阵时使用了He初始化方法,可以有效提高模型的训练效果。
相关推荐
![docx](https://img-home.csdnimg.cn/images/20210720083331.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)