GAT pytorch实现
时间: 2023-07-11 11:49:03 浏览: 114
GAT (Graph Attention Network) 是一种基于注意力机制的图神经网络模型,适用于节点分类、图分类等任务。
以下是使用 PyTorch 实现 GAT 模型的代码示例:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GATLayer(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GATLayer, self).__init__(aggr='add') # "Add" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)
self.att = torch.nn.Linear(2*out_channels, 1)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute attention coefficients.
edge_src, edge_dst = edge_index
x_i = x[edge_src] # [E, out_channels]
x_j = x[edge_dst] # [E, out_channels]
alpha = self.att(torch.cat([x_i, x_j], dim=-1)) # [E, 1]
alpha = F.leaky_relu(alpha, negative_slope=0.2)
alpha = torch.softmax(alpha, dim=0) # [E, 1]
# Step 4: Message passing.
return self.propagate(edge_index, x=x, alpha=alpha)
def message(self, x_j, alpha):
# x_j has shape [E, out_channels]
# alpha has shape [E, 1]
return alpha * x_j
def update(self, aggr_out):
# aggr_out has shape [N, out_channels]
return aggr_out
class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super(GAT, self).__init__()
self.layers = torch.nn.ModuleList()
self.layers.append(GATLayer(in_channels, hidden_channels))
for i in range(num_layers - 2):
self.layers.append(GATLayer(hidden_channels, hidden_channels))
self.layers.append(GATLayer(hidden_channels, out_channels))
def forward(self, x, edge_index):
for layer in self.layers:
x = F.elu(layer(x, edge_index))
return x
```
在上述代码中,`GATLayer` 类表示 GAT 网络中的一层,`GAT` 类表示整个 GAT 网络。`GATLayer` 类继承自 `MessagePassing` 类,表示使用消息传递机制进行计算,`GAT` 类继承自 `torch.nn.Module` 类。在 `GATLayer` 类中,`forward` 方法表示前向传播过程,其中包括添加自环、线性变换、计算注意力系数、消息传递等操作;`message` 方法表示消息传递过程;`update` 方法表示节点更新过程。在 `GAT` 类中,`__init__` 方法中定义了多个 GAT 层,`forward` 方法中通过多次调用 GAT 层实现整个网络的前向传播过程。
阅读全文