class MultiHeadGraphAttention(torch.nn.Module): def __init__(self, num_heads, dim_in, dim_k, dim_v): super(MultiHeadGraphAttention, self).__init__() #"dim_k and dim_v must be multiple of num_heads" assert dim_k % num_heads == 0 and dim_v % num_heads == 0 self.num_heads = num_heads self.dim_in = dim_in self.dim_k = dim_k self.dim_v = dim_v self.linear_q = torch.nn.Linear(dim_in, dim_k, bias=False) self.linear_k = torch.nn.Linear(dim_in, dim_k, bias=False) self.linear_v = torch.nn.Linear(dim_in, dim_v, bias=False) self.leaky_relu = torch.nn.LeakyReLU(negative_slope=0.2) self._nor_fact = 1 / sqrt(dim_k // num_heads)
时间: 2023-12-29 11:02:59 浏览: 126
lbcnn.torch-master.zip_CNN_LBC_LBP CNN_becomeg53_torch
这是一个实现多头图注意力机制的 PyTorch 模块。该模块将输入的节点特征矩阵作为 Q(查询)、K(键)和 V(值)三个线性变换的输入,并将其分别映射为 dim_k、dim_k 和 dim_v 维的特征矩阵。然后,将这些特征矩阵按照 num_heads 头进行切分,每个头的维度为 dim_k/num_heads 和 dim_v/num_heads,然后进行注意力计算。最后将每个头的结果拼接在一起,经过一次线性变换输出。其中,_nor_fact 是一个归一化因子,用于控制注意力的大小。
需要注意的是,这个模块只处理了节点之间的注意力计算,如果要考虑边上的权重信息,还需要在输入特征矩阵中加入边的特征信息,并在计算注意力时将其考虑进去。
阅读全文