def message_func1(self, edges): msg = torch.empty((edges.src['h'].shape[0], self.out_feats), device=edges.src['h'].device) for etype in range(self.num_rels): loc = edges.data['type'] == etype if loc.sum() == 0: continue src = edges.src['h'][loc] dst = edges.dst['h'][loc] sub_msg = self.rel_ME[etype](dst, src) msg[loc] = sub_msg return {'m': msg}
时间: 2024-02-14 09:30:30 浏览: 115
TypeError: _queue_reduction(): incompatible function arguments.
这段代码是 GNNLayer 中的 `message_func1` 方法的具体实现。
`message_func1` 方法用于定义消息传递函数,它接收一个表示边的对象 `edges` 作为输入,并返回一个字典,其中包含消息张量 `m`。
首先,根据源节点的特征维度和输出特征维度,创建一个空的消息张量 `msg`,其形状为 `(edges.src['h'].shape[0], self.out_feats)`,设备与源节点特征张量 `edges.src['h']` 的设备一致。
然后,对于每个关系类型 `etype`,通过判断边的类型 `edges.data['type']` 是否等于当前关系类型 `etype`,得到一个布尔索引数组 `loc`。如果某个关系类型没有对应的边,则 `loc.sum()` 为 0,表示没有需要传递的消息,可以跳过该关系类型。
接下来,根据 `loc` 数组选择对应的源节点特征和目标节点特征,分别存储在变量 `src` 和 `dst` 中。
然后,通过调用记忆编码模块 `self.rel_ME[etype]` 对目标节点特征 `dst` 和源节点特征 `src` 进行记忆编码,并得到子消息张量 `sub_msg`。
最后,将子消息张量 `sub_msg` 根据布尔索引数组 `loc` 更新到消息张量 `msg` 中,只更新那些对应关系类型的位置。
最终,将包含消息张量 `msg` 的字典返回,字典的键为 `'m'`。这样,消息传递阶段就完成了,每个边都会根据其关系类型生成相应的消息,并将其存储在字典中返回。
阅读全文