class srmConv2d(nn.Conv2d): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, v_th=1.0, taum=5., taus=3., taug=2.5 ) -> None: super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) nn.init.orthogonal_(self.weight) self.taum = taum self.taus = taus self.taug = taug self.v_th = v_th self.epsw = None self.epst = None self.e_taum = 1. - 1. / taum self.e_taus = 1. - 1. / taus self.e_taug = 1. - 1. / taug self.conv_func = srmConvFunc.apply def batch_reset(self, inputs: Tensor) -> None: if self.epsw is None or self.epsw.shape[0] != inputs.shape[1]: coefficient = self.taum / (self.taum - self.taus) # for i in range(inputs.shape[1]): self.epst = torch.FloatTensor([-self.e_taug ** (1 + i) for i in range(inputs.shape[1])]).to(inputs) self.epsw = torch.FloatTensor( [coefficient * (self.e_taum ** (1 + i) - self.e_taus ** (1 + i)) for i in range(inputs.shape[1])] ).to(inputs) def forward(self, inputs): self.batch_reset(inputs) return self.conv_func( inputs, self.weight, self.taum, self.taus, self.e_taug, self.v_th, self.epsw, self.epst, self.stride, self.padding, self.dilation, self.groups )
时间: 2023-12-09 12:03:28 浏览: 171
这是一个自定义的卷积层类 srmConv2d,继承自 PyTorch 自带的 nn.Conv2d。它的初始化函数中有一些额外的参数,包括 v_th、taum、taus、taug,分别表示神经元的阈值、膜电位时间常数、突触前电位时间常数和突触后电位时间常数。该类还定义了一个 batch_reset 方法,用于根据输入数据重新计算权重的一些参数,如 epst 和 epsw。最后,类中的 forward 方法调用了自定义的卷积函数 srmConvFunc.apply,该函数实现了 SRM 卷积的计算过程。
阅读全文