def BN1dBackward(grad_out, normx, varx, eps, w): grad_bias = grad_out.sum(dim=0) grad_weight = (grad_out * normx).sum(dim=0) grad_normx = grad_out * w grad_x = normx[:, 0].numel() * grad_normx - grad_normx.sum(dim=0) \ - (grad_normx * normx).sum(dim=0) * normx grad_x = grad_x / (normx[:, 0].numel() * torch.sqrt(varx + eps)) return grad_x, grad_weight, grad_bias
时间: 2023-11-26 17:02:59 浏览: 25
这段代码实现的是一维批量归一化(Batch Normalization)的反向传播过程。其中,grad_out 是当前层的梯度,normx 是当前层归一化后的输出,varx 是当前层输出的方差,eps 是一个微小的常数,w 是 BN 层的权重。函数的返回值 grad_x 是传递给上一层的梯度,grad_weight 和 grad_bias 是 BN 层的梯度。具体的实现细节请见下面的解释:
grad_bias:对于 BN 层的每个通道,grad_bias 是 grad_out 对这个通道的求和,即相当于对 BN 层的偏置参数进行求导。
grad_weight:对于 BN 层的每个通道,grad_weight 是 grad_out 与当前层归一化后的输出 normx 的乘积对这个通道的求和,即相当于对 BN 层的权重参数进行求导。
grad_normx:grad_normx 是 grad_out 与 BN 层的权重 w 的乘积,即相当于对 BN 层的输入进行求导。
grad_x:grad_x 是传递给上一层的梯度,需要利用 grad_normx 计算得到。其中,grad_normx 需要对每个通道的值进行求和,然后再减去所有通道的加权和。其中的加权和是 grad_normx 与当前层归一化后的输出 normx 的乘积,对每个通道的值进行求和,然后再乘以 normx。最后,grad_x 还需要除以一个标准差,即对 BN 层的输出方差进行求导。
相关问题
class srmLinearFunc(torch.autograd.Function): @staticmethod def forward(ctx, inputs: Tensor, weight: Tensor, bn_weight: Tensor, bn_bias: Tensor, eps: Tensor, v_th: Tensor, taum: float, taus: float, e_taug: float, epsw: Tensor, epst: Tensor, ) -> Tensor: if bn_weight is not None: x, normx, varx = BN1dForward(weight.t(), bn_weight, bn_bias, eps) else: x = weight.t() normx = varx = bn_weight = bn_bias = eps x = inputs.matmul(x) spikes, delta_ut, delta_u = srmNeuronFunc.forward(x, taum, taus, e_taug, v_th) ctx.save_for_backward( inputs, weight, bn_weight, bn_bias, normx, varx, eps, spikes, delta_ut, delta_u, epsw, epst, ) # print('linear', spikes.sum()) return spikes @staticmethod def backward(ctx, grad_out: Tensor) -> List[Optional[Tensor]]: inputs, weight, bn_weight, bn_bias, normx, varx, eps, spikes, delta_ut, delta_u, epsw, epst, = ctx.saved_tensors grad_w, grad_t = srmNeuronFunc.backward(grad_out, delta_ut, delta_u, spikes, epsw, epst) # grad_w: b t dout, weight: dout din, inputs: b t din grad_w = grad_w.transpose(1, 2).matmul(inputs).sum(dim=0) if eps.shape != bn_weight.shape or eps != bn_weight: grad_w, grad_bnw, grad_bnb = BN1dBackward(grad_w.t(), normx, varx, eps, bn_weight) grad_w = grad_w.t() x = (normx * bn_weight + bn_bias).t() else: grad_bnw = None grad_bnb = None x = weight grad_t = torch.matmul(grad_t, x) return grad_t * 0.85, grad_w, grad_bnw, grad_bnb, None, None, None, None, None, None, None
这是一个使用 PyTorch 实现的线性函数,其中包含了一些参数和激活函数。在前向传播过程中,它通过输入和权重计算出输出,然后使用 SRM 神经元激活函数将输出转换为脉冲信号。在反向传播过程中,它根据输出脉冲和一些其他参数计算出相应的梯度。该函数还包括对 Batch Normalization 的支持。
def BN1dForward(inputs, weight, bias, eps=1e-5): meanx = inputs.mean(dim=0) varx = inputs.var(dim=0) normx = (inputs - meanx) / torch.sqrt(varx + eps) return normx * weight + bias, normx, varx
这是一个用于一维批量归一化的前向传播函数,其中:
- `inputs` 是输入数据,形状为 `(batch_size, num_features)`;
- `weight` 和 `bias` 分别是归一化的缩放因子和偏置项,形状均为 `(num_features,)`;
- `eps` 是一个小常数,用于避免除以零的情况;
- `meanx` 和 `varx` 分别是输入数据在当前 batch 的均值和方差,形状均为 `(num_features,)`;
- `normx` 是归一化后的数据,形状与 `inputs` 相同,计算公式为:$$\text{normx}=\frac{\text{inputs}-\text{meanx}}{\sqrt{\text{varx}+\text{eps}}}$$
- 返回值为归一化后的数据、当前 batch 的均值和方差。