class Hypergraph_Infomax(nn.Module): def __init__(self): super(Hypergraph_Infomax, self).__init__() self.Hypergraph = Hypergraph() self.readout = AvgReadout() self.sigm = nn.Sigmoid() self.disc = Discriminator() def forward(self, eb_pos, eb_neg): h_pos = self.Hypergraph(eb_pos) c = self.readout(h_pos) score = self.sigm(c) h_neg = self.Hypergraph(eb_neg) ret = self.disc(score, h_pos, h_neg) return h_pos, ret
时间: 2023-11-29 09:03:43 浏览: 125
使用Hypergraph学习以提取高光谱图像特征
这段代码实现了一个名为Hypergraph_Infomax的神经网络模型,用于对超图进行信息最大化学习。该模型由三个子模块组成:
- Hypergraph:超图编码器,用于将输入的边集(eb_pos和eb_neg)编码为超图。
- AvgReadout:超图池化操作,用于将超图中的节点信息汇总成一个固定长度的向量。
- Discriminator:判别器,用于判断输入的超图是否为真实超图(eb_pos)。
在forward函数中,首先用Hypergraph将eb_pos和eb_neg分别编码为超图h_pos和h_neg,然后通过AvgReadout将h_pos池化为一个向量c,并使用Sigmoid函数将c映射到[0, 1]之间得到score。最后,将score、h_pos和h_neg输入到Discriminator中,得到ret作为模型的输出。
阅读全文