from typing import List, Optional

import torch
from torch import Tensor

# BN1dForward, BN1dBackward, and srmNeuronFunc are helpers defined
# elsewhere in this codebase.

class srmLinearFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs: Tensor, weight: Tensor, bn_weight: Tensor,
                bn_bias: Tensor, eps: Tensor, v_th: Tensor,
                taum: float, taus: float, e_taug: float,
                epsw: Tensor, epst: Tensor) -> Tensor:
        if bn_weight is not None:
            # Batch-normalize the transposed weight matrix.
            x, normx, varx = BN1dForward(weight.t(), bn_weight, bn_bias, eps)
        else:
            # No batch norm: reuse `eps` as a placeholder so that
            # save_for_backward only ever receives tensors.
            x = weight.t()
            normx = varx = bn_weight = bn_bias = eps
        x = inputs.matmul(x)
        # SRM neuron dynamics convert the membrane input into binary spikes.
        spikes, delta_ut, delta_u = srmNeuronFunc.forward(x, taum, taus, e_taug, v_th)
        ctx.save_for_backward(inputs, weight, bn_weight, bn_bias, normx, varx,
                              eps, spikes, delta_ut, delta_u, epsw, epst)
        return spikes

    @staticmethod
    def backward(ctx, grad_out: Tensor) -> List[Optional[Tensor]]:
        (inputs, weight, bn_weight, bn_bias, normx, varx, eps,
         spikes, delta_ut, delta_u, epsw, epst) = ctx.saved_tensors
        # Surrogate gradients through the SRM neuron.
        grad_w, grad_t = srmNeuronFunc.backward(grad_out, delta_ut, delta_u,
                                                spikes, epsw, epst)
        # grad_w: (b, t, d_out), inputs: (b, t, d_in)
        # -> weight gradient of shape (d_out, d_in), summed over the batch.
        grad_w = grad_w.transpose(1, 2).matmul(inputs).sum(dim=0)
        # Batch norm was used iff bn_weight differs from the `eps` placeholder
        # stored in forward(); .any() makes the elementwise comparison safe
        # for non-scalar tensors.
        if eps.shape != bn_weight.shape or bool((eps != bn_weight).any()):
            grad_w, grad_bnw, grad_bnb = BN1dBackward(grad_w.t(), normx, varx,
                                                      eps, bn_weight)
            grad_w = grad_w.t()
            x = (normx * bn_weight + bn_bias).t()
        else:
            grad_bnw = None
            grad_bnb = None
            x = weight
        # Propagate the time gradient back to the layer inputs, scaled by 0.85.
        grad_t = torch.matmul(grad_t, x)
        return (grad_t * 0.85, grad_w, grad_bnw, grad_bnb,
                None, None, None, None, None, None, None)
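A Function subclass like this is used through its .apply method rather than instantiated. The call below is a hypothetical sketch: the shapes and hyperparameter values are illustrative assumptions, and it requires the module that actually defines srmLinearFunc, srmNeuronFunc, and the BN1d helpers to be importable.

import torch

# All shapes and hyperparameters here are illustrative placeholders.
batch, time_steps, d_in, d_out = 8, 16, 100, 10
inputs = torch.rand(batch, time_steps, d_in)            # b x t x d_in spike inputs
weight = torch.randn(d_out, d_in, requires_grad=True)   # matches weight.t() in forward

spikes = srmLinearFunc.apply(
    inputs, weight,
    None, None,                 # bn_weight, bn_bias: batch norm disabled
    torch.tensor(1e-5),         # eps (also reused as the BN placeholder)
    torch.tensor(1.0),          # v_th: firing threshold
    5.0, 3.0, 0.8,              # taum, taus, e_taug: placeholder time constants
    torch.rand(time_steps),     # epsw: placeholder surrogate kernel
    torch.rand(time_steps),     # epst: placeholder surrogate kernel
)
spikes.sum().backward()         # routes through the custom backward above
print(weight.grad.shape)        # expected: torch.Size([10, 100])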
This is a spiking linear layer implemented as a custom PyTorch autograd.Function. The forward pass multiplies the inputs by the (optionally batch-normalized) weight matrix and passes the result through SRM (Spike Response Model) neuron dynamics, converting the membrane potentials into binary spike trains. The backward pass uses the saved spikes and membrane-potential deltas (delta_ut, delta_u) to compute surrogate gradients for the inputs, the weights, and, when batch normalization was applied, the BN scale and bias. A self-contained sketch of this forward-spike / surrogate-backward pattern follows.
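Since the surrogate-gradient mechanics are hidden inside srmNeuronFunc, here is a minimal, runnable sketch of the same pattern in isolation: a hard threshold in the forward pass and a rectangular surrogate derivative in the backward pass. This illustrates the general technique, not the author's SRM implementation; the threshold and window width are arbitrary placeholder values.

import torch
from torch import Tensor

class SpikeFunc(torch.autograd.Function):
    # Heaviside spike in forward, smooth surrogate gradient in backward.

    @staticmethod
    def forward(ctx, u: Tensor, v_th: float) -> Tensor:
        ctx.save_for_backward(u)
        ctx.v_th = v_th
        return (u >= v_th).float()          # binary spikes

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (u,) = ctx.saved_tensors
        # Rectangular surrogate: pass gradient only near the threshold.
        surrogate = (torch.abs(u - ctx.v_th) < 0.5).float()
        return grad_out * surrogate, None   # no gradient for v_th

u = torch.randn(4, 10, requires_grad=True)
spikes = SpikeFunc.apply(u, 1.0)
spikes.sum().backward()                     # u.grad is nonzero near the threshold
print(u.grad)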