def edge_attention(self, edges): # an edge UDF to compute unnormalized attention values from src and dst if self.l0 == 0: m = self.leaky_relu(edges.src['a1'] + edges.dst['a2']) else: tmp = edges.src['a1'] + edges.dst['a2'] logits = tmp + self.bias_l0 if self.training: m = l0_train(logits, 0, 1) else: m = l0_test(logits, 0, 1) self.loss = get_loss2(logits[:,0,:]).sum() return {'a': m}
时间: 2024-02-14 16:30:47 浏览: 57
edges-master.rar_7FN_Edge Boxes_wnn_检测仿真_边缘检测
在这段代码中,`edge_attention`函数是一个图神经网络中的边自定义函数。它接收一个包含边信息的edges对象作为输入,并计算未归一化的注意力值。
首先,如果`self.l0`等于0,那么将执行以下操作:
- 计算`edges.src['a1']`和`edges.dst['a2']`的和。
- 使用`self.leaky_relu`函数对和进行激活函数处理,得到注意力值`m`。
接下来,如果`self.l0`不等于0,那么将执行以下操作:
- 计算`edges.src['a1']`和`edges.dst['a2']`的和,存储在变量`tmp`中。
- 将`tmp`与`self.bias_l0`相加,得到logits(对数概率)。
- 如果处于训练模式,则调用`l0_train`函数,传入logits、0和1作为参数,得到一个掩码(mask)`m`。同时,计算损失函数`self.loss`,使用`get_loss2`函数计算logits的损失,并对第一个维度求和。
- 如果处于测试模式,则调用`l0_test`函数,传入logits、0和1作为参数,得到一个掩码(mask)`m`。
最后,函数返回一个包含注意力值的字典,键为'a',值为掩码(mask)`m`。
阅读全文