class GraphSAGE(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats, num_layers, activation):
        super(GraphSAGE, self).__init__()
        self.num_layers = num_layers
        self.conv1 = SAGEConv(in_feats, hidden_feats, aggregator_type='mean')
        self.convs = nn.ModuleList()
        for i in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_feats, hidden_feats, aggregator_type='mean'))
        self.conv_last = SAGEConv(hidden_feats, out_feats, aggregator_type='mean')
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for i, block in enumerate(blocks):
            h_dst = h[:block.number_of_dst_nodes()]
            h = self.convs[i](block, (h, h_dst))
            if i != self.num_layers - 2:
                h = self.activation(h)
        h = self.conv_last(blocks[-1], (h, h_dst))
        return h

改写一下,让它适用于异质图
时间: 2023-06-25 13:02:53 浏览: 354
class GraphSAGE(nn.Module):
    """Multi-layer GraphSAGE (mean aggregator) for heterogeneous graphs.

    The model consumes a list of sampled blocks, one per layer. In every
    layer a ``SAGEConv`` is applied separately to each relation (canonical
    edge type) of the block, and the per-relation results are summed for
    each destination node type.

    Args:
        in_feats: Input feature size; every node type must share it.
        hidden_feats: Feature size of the intermediate layers.
        out_feats: Feature size produced by the last layer.
        num_layers: Total number of layers (at least 2); ``forward``
            expects exactly this many blocks.
        activation: Callable applied to every node type's features between
            layers (not after the last layer).
        rel_names: Optional iterable of edge-type names. When given, each
            relation name gets its own ``SAGEConv`` weights per layer (names
            must be valid ``nn.ModuleDict`` keys, e.g. contain no ``.``).
            When ``None`` (the default), one ``SAGEConv`` per layer is
            shared by all relations.

    Raises:
        ValueError: If ``num_layers`` is smaller than 2.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, num_layers,
                 activation, rel_names=None):
        super(GraphSAGE, self).__init__()
        # conv1 and conv_last are always separate layers, so a model with
        # fewer than two layers cannot be represented.
        if num_layers < 2:
            raise ValueError(
                f"num_layers must be at least 2, got {num_layers}")
        self.num_layers = num_layers
        # dict.fromkeys de-duplicates while keeping the caller's order.
        self.rel_names = (None if rel_names is None
                          else list(dict.fromkeys(rel_names)))
        self.conv1 = self._make_layer(in_feats, hidden_feats)
        self.convs = nn.ModuleList([
            self._make_layer(hidden_feats, hidden_feats)
            for _ in range(num_layers - 2)
        ])
        self.conv_last = self._make_layer(hidden_feats, out_feats)
        self.activation = activation

    def _make_layer(self, in_dim, out_dim):
        """Build one layer: a shared SAGEConv, or one SAGEConv per relation."""
        if self.rel_names is None:
            return SAGEConv(in_dim, out_dim, aggregator_type='mean')
        return nn.ModuleDict({
            rel: SAGEConv(in_dim, out_dim, aggregator_type='mean')
            for rel in self.rel_names
        })

    def forward(self, blocks, x_dict):
        """Run the model over a list of sampled blocks.

        Args:
            blocks: Sequence of ``num_layers`` DGL blocks, input side first.
                Assumes (DGL ``to_block`` default) that the destination
                nodes of each type are the first ``num_dst_nodes(ntype)``
                source nodes of that type — TODO confirm for custom samplers.
            x_dict: Mapping from node type to the input features of the
                source nodes of ``blocks[0]``.

        Returns:
            Dict mapping each destination node type of ``blocks[-1]`` that
            receives at least one relation to its output features.

        Raises:
            ValueError: If ``len(blocks)`` differs from the number of layers.
        """
        layers = [self.conv1, *self.convs, self.conv_last]
        if len(blocks) != len(layers):
            raise ValueError(
                f"expected {len(layers)} blocks (one per layer), "
                f"got {len(blocks)}")
        h = dict(x_dict)
        last = len(layers) - 1
        for i, (layer, block) in enumerate(zip(layers, blocks)):
            h = self._hetero_layer(layer, block, h)
            # No activation after the output layer.
            if i != last:
                h = {ntype: self.activation(feat) for ntype, feat in h.items()}
        return h

    @staticmethod
    def _hetero_layer(layer, block, h):
        """Apply ``layer`` to every relation of ``block``; sum per dst type."""
        out = {}
        for stype, etype, dtype in block.canonical_etypes:
            # A relation can only be computed when both endpoint types have
            # features from the previous layer.
            if stype not in h or dtype not in h:
                continue
            conv = layer[etype] if isinstance(layer, nn.ModuleDict) else layer
            # Destination nodes are the leading source nodes of their type.
            h_dst = h[dtype][:block.num_dst_nodes(dtype)]
            rst = conv(block[stype, etype, dtype], (h[stype], h_dst))
            out[dtype] = rst if dtype not in out else out[dtype] + rst
        return out
阅读全文