class srmConvFunc(torch.autograd.Function): @staticmethod def forward( ctx, inputs: Tensor, weight: Tensor, taum: float, taus: float, e_taug: float, v_th: float, epsw: Tensor, epst: Tensor, stride: Tuple[int] = (1, 1), padding: Tuple[int] = (0, 0), dilation: Tuple[int] = (1, 1), groups: int = 1 ) -> Tensor: out = torch.nn.functional.conv2d( inputs.view(-1, *inputs.shape[2:]), weight, None, stride, padding, dilation, groups ) spikes, delta_ut, delta_u = srmNeuronFunc.forward( out.view(*inputs.shape[:2], *out.shape[1:]), taum, taus, e_taug, v_th ) ctx.save_for_backward( inputs, weight, epsw, epst, delta_ut, delta_u, spikes, torch.tensor(stride, dtype=torch.int), torch.tensor(padding, dtype=torch.int), torch.tensor(dilation, dtype=torch.int), torch.tensor(groups, dtype=torch.int) ) return spikes @staticmethod def backward(ctx, grad_out: Tensor) -> List[Optional[Tensor]]: inputs, weight, epsw, epst, delta_ut, delta_u, spikes, stride, padding, dilation, groups = ctx.saved_tensors stride = tuple(stride) padding = tuple(padding) dilation = tuple(dilation) groups = int(groups) grad_w, grad_t = srmNeuronFunc.backward(grad_out, delta_ut, delta_u, spikes, epsw, epst) grad_inputs = conv_wrapper.cudnn_convolution_backward_input( inputs.view(-1, *inputs.shape[2:]).shape, grad_t.view(-1, *grad_t.shape[2:]), weight, padding, stride, dilation, groups, cudnn.benchmark, cudnn.deterministic, cudnn.allow_tf32 ) grad_inputs = grad_inputs.view(*inputs.shape) * inputs grad_weight = conv_wrapper.cudnn_convolution_backward_weight( weight.shape, grad_w.view(-1, *grad_w.shape[2:]), inputs.view(-1, *inputs.shape[2:]), padding, stride, dilation, groups, cudnn.benchmark, cudnn.deterministic, cudnn.allow_tf32 ) return grad_inputs * 0.85, grad_weight, None, None, None, None, None, None, None, None, None, None
时间: 2023-11-18 10:02:57 浏览: 134
利用torch.autograd.Function自定义层的forward和backward
这段代码实现了一个自定义的卷积函数 srmConvFunc,其中包括前向传播和反向传播两个方法。这个卷积函数使用了 srmNeuronFunc 中定义的神经元函数来处理卷积的输出,具体实现是先调用 PyTorch 内置的 conv2d 函数计算出卷积的输出,再将输出作为参数传入神经元函数中,得到最终的输出 spikes。
在前向传播方法中,除了调用 conv2d 和神经元函数之外,还需要保存一些参数和中间结果,以便在反向传播时使用。这些参数和中间结果包括输入张量 inputs、卷积核 weight、epsilon 参数 epsw 和 epst、神经元函数的一些中间结果 delta_ut、delta_u 和 spikes,以及卷积的一些参数 stride、padding、dilation 和 groups。
在反向传播方法中,首先通过神经元函数的反向传播方法得到关于输出梯度 grad_out 的梯度 grad_w 和 delta_t,然后调用 PyTorch 提供的 cudnn_convolution_backward_input 和 cudnn_convolution_backward_weight 函数分别计算关于输入和权重的梯度 grad_inputs 和 grad_weight。最后,将 grad_inputs 乘以一个因子 0.85 并返回梯度值。
阅读全文