oython batch normalization
时间: 2024-12-31 07:37:06 浏览: 7
### Python 中 Batch Normalization 的实现与使用
Batch Normalization 是一种用于加速神经网络训练的技术,通过规范化每一层输入来稳定学习过程并减少内部协变量偏移。这有助于提高模型性能和收敛速度。
#### PyTorch 实现方式
PyTorch 提供了一个内置模块 `torch.nn.BatchNorm1d`、`torch.nn.BatchNorm2d` 和 `torch.nn.BatchNorm3d` 来处理不同维度的数据。下面是一个简单的例子展示如何在一个全连接层之后应用批量归一化:
```python
import torch
from torch import nn
class Net(nn.Module):
def __init__(self, input_size=784, hidden_size=500, output_size=10):
super(Net, self).__init__()
# 定义线性变换
self.fc1 = nn.Linear(input_size, hidden_size)
# 应用批标准化到隐藏层
self.bn1 = nn.BatchNorm1d(hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
out = self.fc1(x)
out = self.bn1(out) # 批量归一化操作
out = self.relu(out)
out = self.fc2(out)
return out
```
对于卷积神经网络 (CNN),可以采用类似的模式,在 Convolutional Layer 后面加入相应的二维或三维批次正则化层 (`nn.BatchNorm2d`)。
#### 自定义实现
如果想要更深入理解其工作原理,则可以从零开始构建一个简易版本的 batch normalization 函数如下所示:
```python
def batch_norm(X, gamma, beta, is_training, moving_mean=None, moving_var=None, eps=1e-5, momentum=0.9):
if not is_training:
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4), "仅支持FC层和Conv层"
mu = X.mean(dim=(0,) if len(X.shape)==2 else (0,2,3))
var = ((X-mu)**2).mean(dim=(0,) if len(X.shape)==2 else (0,2,3))
X_hat = (X - mu.reshape((1,-1)+(1,)*len(X.shape[2:]))) \
/ torch.sqrt(var+eps).reshape((1,-1)+(1,)*len(X.shape[2:]))
if moving_mean is None or moving_var is None:
moving_mean = mu.detach().clone()
moving_var = var.detach().clone()
else:
moving_mean.data[:] = momentum * moving_mean + (1.-momentum) * mu
moving_var.data[:] = momentum * moving_var + (1.-momention) * var
Y = gamma * X_hat + beta
return Y, moving_mean, moving_var
```
此函数接受输入张量 `X`, 缩放参数 `gamma`, 偏置项 `beta`, 训练标志位 `is_training`. 如果是在测试阶段(`is_training=False`), 则会利用之前保存下来的整体均值和方差来进行计算;而在训练过程中(`is_training=True`), 将实时更新这些统计信息.
阅读全文