class MultiHeadAttentionGraph(nn.Module): def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): super().__init__() self.n_head = n_head self.d_model = d_model self.d_k = d_k self.d_v = d_v self.W_Q = nn.Linear(d_model, n_head*d_k) # account for the fact that the relational edge information has double # the length self.W_K = nn.Linear(d_model*2, n_head*d_k) self.W_V = nn.Linear(d_model*2, n_head*d_v) self.W_O = nn.Linear(n_head*d_v, d_model) self.softmax = nn.Softmax(dim=-1) self.layer_norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, nodes, edges): n_batch, n_nodes, n_neighbors = edges.shape[:3] Q = self.W_Q(nodes).view([n_batch, n_nodes, 1, self.n_head, 1, self.d_k]) K = self.W_K(edges).view([n_batch, n_nodes, n_neighbors, self.n_head, self.d_k, 1]) attention = torch.matmul(Q, K).view([n_batch, n_nodes, n_neighbors, self.n_head]).transpose(-2,-1) attention = attention /np.sqrt(self.d_k) attention = self.softmax(attention) V = self.W_V(edges).view([n_batch, n_nodes, n_neighbors, self.n_head, self.d_v]).transpose(2,3) attention = attention.unsqueeze(-2) output = torch.matmul(attention, V).view([n_batch, n_nodes, self.d_v*self.n_head]) output = self.W_O(output) output = self.dropout(output) output = self.layer_norm(output + nodes) attention = attention.squeeze(-2).transpose(-2,-1) return output, attention
时间: 2024-02-14 17:07:12 浏览: 124
Python RuntimeError: thread.__init__() not called解决方法
这是一个PyTorch模型的类定义,名为MultiHeadAttentionGraph。它继承自nn.Module类,表示这是一个可训练的模型。该模型实现了多头注意力机制,用于图神经网络中节点特征的更新。
该类的初始化方法中,有5个参数:n_head表示注意力头数,d_model表示输入节点特征的维度,d_k和d_v分别表示每个注意力头的查询、键、值向量的维度,dropout表示Dropout层的丢弃率。
该类的forward方法中,有2个参数:nodes表示输入的节点特征,edges表示节点之间的关系信息。其中,nodes的形状为(batch_size, num_nodes, d_model),edges的形状为(batch_size, num_nodes, num_neighbors, 2*d_model),其中2*d_model表示每条关系信息包含起点和终点的节点特征。
在forward方法中,首先通过全连接层W_Q、W_K、W_V将输入特征映射到查询、键、值向量,并对维度进行调整,得到Q、K、V三个张量。然后使用torch.matmul函数计算Q、K的点积,再进行softmax得到每个节点与邻居节点的注意力权重。接着使用torch.matmul函数计算注意力权重与V的加权和,并将结果维度调整为(batch_size, num_nodes, n_head*d_v),通过全连接层W_O得到输出特征。最后使用Dropout层和LayerNorm层对输出特征进行处理,得到最终的节点特征输出。注意力权重也作为函数的输出返回。
阅读全文