def message_func1(self, edges): msg = torch.empty((edges.src['h'].shape[0], self.out_feats), device=edges.src['h'].device) for etype in range(self.num_rels): loc = edges.data['type'] == etype if loc.sum() == 0: continue src = edges.src['h'][loc] dst = edges.dst['h'][loc] sub_msg = self.rel_ME[etype](dst, src) msg[loc] = sub_msg return {'m': msg}
时间: 2024-02-14 17:30:30 浏览: 26
这段代码是 GNNLayer 中的 `message_func1` 方法的具体实现。
`message_func1` 方法用于定义消息传递函数,它接收一个表示边的对象 `edges` 作为输入,并返回一个字典,其中包含消息张量 `m`。
首先,根据源节点的特征维度和输出特征维度,创建一个空的消息张量 `msg`,其形状为 `(edges.src['h'].shape[0], self.out_feats)`,设备与源节点特征张量 `edges.src['h']` 的设备一致。
然后,对于每个关系类型 `etype`,通过判断边的类型 `edges.data['type']` 是否等于当前关系类型 `etype`,得到一个布尔索引数组 `loc`。如果某个关系类型没有对应的边,则 `loc.sum()` 为 0,表示没有需要传递的消息,可以跳过该关系类型。
接下来,根据 `loc` 数组选择对应的源节点特征和目标节点特征,分别存储在变量 `src` 和 `dst` 中。
然后,通过调用记忆编码模块 `self.rel_ME[etype]` 对目标节点特征 `dst` 和源节点特征 `src` 进行记忆编码,并得到子消息张量 `sub_msg`。
最后,将子消息张量 `sub_msg` 根据布尔索引数组 `loc` 更新到消息张量 `msg` 中,只更新那些对应关系类型的位置。
最终,将包含消息张量 `msg` 的字典返回,字典的键为 `'m'`。这样,消息传递阶段就完成了,每个边都会根据其关系类型生成相应的消息,并将其存储在字典中返回。
相关问题
def wraper_func(test_func): @wraps(test_func) # @wraps:避免被装饰函数自身的信息丢失 def inner_func(self): failures = str([fail[0] for fail in self._outcome.result.failures]) if test_func.__name__ in failures: if self._id() == self._testMethodName: test = False else: test = True else: test = True return test(self) if inner_func(self) == True: return True else: return False 这个代码有错吗,怎么更改
这段代码有一些问题。首先,在 `inner_func` 函数中,`test` 变量被赋值为一个布尔值,但在返回时却被当作了一个函数来调用。其次,在 `wraper_func` 函数中,`inner_func` 函数应该返回一个函数对象,而不是一个布尔值。
以下是更改后的代码:
```
from functools import wraps
def wraper_func(test_func):
@wraps(test_func)
def inner_func(self):
failures = str([fail[0] for fail in self._outcome.result.failures])
if test_func.__name__ in failures:
if self._id() == self._testMethodName:
test = lambda: False
else:
test = lambda: True
else:
test = lambda: True
return test
return inner_func(self)
# 测试
class Test:
def test_func(self):
return True
test = Test()
result = wraper_func(test.test_func)
print(result()) # 输出 True
```
在更改后的代码中,`inner_func` 函数返回一个匿名函数,该函数根据条件返回 `True` 或 `False`。在 `wraper_func` 函数中,我们调用 `inner_func` 并传入 `self` 参数,该参数在测试类的实例化对象中提供。然后,我们将返回的函数对象存储在 `result` 变量中,并在需要时调用它来执行测试。
def wraper_func(test_func): @wraps(test_func) # @wraps:避免被装饰函数自身的信息丢失 def inner_func(self): failures = str([fail[0] for fail in self._outcome.result.failures]) flag = (test_func in failures) if failures.find(test_func) != -1: test = unittest.skipIf(flag, "{} failed".format(test_func))(test_func) else: test = test_func return test(self) return inner_func 写的装饰器判断函数为这个,为什么上条用例跳过了,下条还是会跳过,而不是失败时才会跳过用例
根据你提供的代码,我可以看出这是一个装饰器函数,用于判断某个测试函数是否在测试结果的失败列表中。如果在失败列表中,则使用unittest.skipIf将该测试函数跳过。如果不在失败列表中,则直接执行该测试函数。
如果你发现上一条用例跳过了,而下一条用例仍然跳过,这可能是因为你的测试函数在多个测试用例中都被调用了,而且在其中一个测试用例中出现了失败。因此,该测试函数将被跳过,无论它是在哪个测试用例中调用。
解决这个问题的方法是在测试函数中添加一个条件,仅当测试在当前用例中运行时才跳过该测试。你可以通过检查unittest.TestCase实例的_id()是否与当前用例的id相同来实现这一点。例如:
```
import unittest
from functools import wraps
def skip_if_failed(test_func):
@wraps(test_func)
def inner_func(self):
failures = [fail[0] for fail in self._outcome.result.failures]
if test_func.__name__ in failures:
if self._id() == self._testMethodName:
test = unittest.skip("{} failed".format(test_func.__name__))(test_func)
else:
test = test_func
else:
test = test_func
return test(self)
return inner_func
```
在这个修改后的版本中,我们首先获取了所有失败的测试函数名称,并将其存储在failures列表中。然后,我们检查当前测试函数是否在失败列表中,以及当前测试函数的id是否与当前测试用例的id相同。只有在这两个条件都满足时,我们才会将测试函数跳过,否则我们将正常地执行它。