GAT pytorch代码
时间: 2023-10-23 22:10:40 浏览: 123
以下是使用 PyTorch 和 PyTorch Geometric（torch_geometric）实现 GAT（图注意力网络，Graph Attention Network）的代码示例：
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, remove_self_loops
from torch_geometric.utils import softmax as segment_softmax
class GATConv(MessagePassing):
    """Graph attention convolution layer (Velickovic et al., "Graph Attention Networks").

    For every edge j -> i and every attention head h the layer computes

        e_ij     = LeakyReLU(a_h^T [W x_i || W x_j])
        alpha_ij = softmax over all incoming edges of i of e_ij
        x_i'     = sum_j alpha_ij * W x_j

    Existing self-loops are removed and exactly one self-loop per node is
    added, so every node attends to itself once.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector *per head*.
        heads: Number of attention heads.
        concat: If True, head outputs are concatenated to shape
            (num_nodes, heads * out_channels); otherwise they are averaged to
            shape (num_nodes, out_channels).
        negative_slope: Negative slope of the LeakyReLU applied to the
            attention logits.
        dropout: Dropout probability applied to the normalised attention
            coefficients during training.
        bias: If True (default), add a learnable bias to the output.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True):
        # Transformed node features are laid out as (num_nodes, heads,
        # out_channels), so the node axis is 0 rather than MessagePassing's
        # default of -2 (which would point at the head axis).
        super().__init__(aggr='add', node_dim=0)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.weight = nn.Parameter(torch.empty(in_channels, heads * out_channels))
        # Per-head attention vector: the first out_channels entries score the
        # target node i, the last out_channels entries score the source node j.
        self.att = nn.Parameter(torch.empty(1, heads, 2 * out_channels))
        if bias:
            out_dim = heads * out_channels if concat else out_channels
            self.bias = nn.Parameter(torch.empty(out_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weights (Xavier uniform) and bias (zeros)."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.att)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        """Apply the attention convolution.

        Args:
            x: Node features of shape (num_nodes, in_channels).
            edge_index: Edge list of shape (2, num_edges); row 0 holds source
                nodes and row 1 holds target nodes.

        Returns:
            Tensor of shape (num_nodes, heads * out_channels) when
            ``concat`` is True, else (num_nodes, out_channels).
        """
        num_nodes = x.size(0)
        x = torch.matmul(x, self.weight).view(-1, self.heads, self.out_channels)
        # Self-loops must be added on the graph before message passing. Any
        # existing ones are stripped first, so each node attends to itself
        # exactly once.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        out = self.propagate(edge_index, x=x)  # (num_nodes, heads, out_channels)
        if self.concat:
            out = out.reshape(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
        if self.bias is not None:
            out = out + self.bias
        return out

    def message(self, x_i, x_j, index, ptr, size_i):
        """Weight each source feature x_j by its attention coefficient alpha_ij.

        ``x_i`` / ``x_j`` are the per-edge target / source features of shape
        (num_edges, heads, out_channels). ``index`` holds the target node of
        each edge, which is the grouping used by the softmax.
        """
        att_i = self.att[..., :self.out_channels]
        att_j = self.att[..., self.out_channels:]
        # a^T [x_i || x_j] == a_i^T x_i + a_j^T x_j, computed per head.
        alpha = (x_i * att_i).sum(dim=-1) + (x_j * att_j).sum(dim=-1)  # (E, heads)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalise over all edges that share the same target node.
        alpha = segment_softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def softmax(self, alpha, edge_index, num_nodes=None):
        """Softmax of per-edge scores, normalised over each target node's incoming edges.

        Kept for backward compatibility. ``alpha`` has one row per edge, and
        ``edge_index[1]`` gives the target node that each row is grouped by.
        """
        return segment_softmax(alpha, edge_index[1], num_nodes=num_nodes)
class GAT(nn.Module):
    """Multi-layer Graph Attention Network built from ``GATConv`` layers.

    Architecture:
      * an input layer with ``heads`` concatenated heads, followed by ELU;
      * ``num_layers - 2`` hidden layers with the same configuration;
      * an output layer with ``out_heads`` heads whose outputs are averaged.

    Dropout with probability ``dropout`` is applied to the input of every
    layer, including the output layer, and is also passed to each layer as
    attention dropout.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head size of the hidden representations.
        out_channels: Size of each output node vector.
        heads: Number of attention heads in the input and hidden layers.
        num_layers: Total number of attention layers (must be >= 2).
        dropout: Dropout probability for features and attention coefficients.
        out_heads: Number of attention heads in the output layer, which are
            averaged (default 1).

    Raises:
        ValueError: If ``num_layers`` is less than 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=1,
                 num_layers=2, dropout=0, out_heads=1):
        super().__init__()
        if num_layers < 2:
            raise ValueError(f"num_layers must be at least 2, got {num_layers}")
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.heads = heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.out_heads = out_heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads,
                             concat=True, dropout=dropout)
        self.convs = nn.ModuleList(
            GATConv(hidden_channels * heads, hidden_channels, heads=heads,
                    concat=True, dropout=dropout)
            for _ in range(num_layers - 2)
        )
        self.conv2 = GATConv(hidden_channels * heads, out_channels,
                             heads=out_heads, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """Run the network.

        Args:
            x: Node features of shape (num_nodes, in_channels).
            edge_index: Edge list of shape (2, num_edges).

        Returns:
            Output of the final ``GATConv`` layer (raw scores, no activation).
        """
        for conv in (self.conv1, *self.convs):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.elu(conv(x, edge_index))
        # The output layer gets input dropout too, like every hidden layer.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)
```
阅读全文