class R_NN(nn.Module): def __init__(self, in_dim, n_class): super(R_NN, self).__init__() self.logstic = nn.Linear(in_dim, n_class) self.predict=nn.Sigmoid() def forward(self, x): out = self.logstic(x) out = out.squeeze(1) #降维 return out
时间: 2023-06-08 13:05:27 浏览: 197
视频中工程代码 class18_TLV1544
这段代码是一个定制化的PyTorch神经网络模块。该模块名为R_NN,有两个输入参数:in_dim表示输入数据的维度,n_class表示分类的数量。在初始化函数中,调用了super()函数来获得父类nn.Module的属性和方法,然后定义了一个线性层(logstic)和一个Sigmoid层(predict)。在前向传播函数forward()中,对输入数据x进行线性变换,再使用squeeze()函数将一维的输出向量压缩为一个标量。
阅读全文