请帮我写出基于PYG搭建的GAT实现图分类代码
时间: 2024-05-06 07:16:27 浏览: 127
Core数据集+PyG实现Core分类
好的,以下是基于PYG搭建的GAT实现图分类代码示例:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
def __init__(self, num_features, num_classes, num_heads):
super(GAT, self).__init__()
self.conv1 = GATConv(num_features, 8, heads=num_heads, dropout=0.6)
self.conv2 = GATConv(8*num_heads, num_classes, dropout=0.6)
def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
```
说明:
- `num_features`:输入特征的维度;
- `num_classes`:图分类的类别数;
- `num_heads`:GAT中的头数(默认为8);
- `self.conv1`:GAT的第一层;
- `self.conv2`:GAT的第二层;
- `forward`:前向传播函数,其中包括两层GAT和一个log_softmax层。
使用方法:
```python
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from GAT import GAT
# 加载数据集
dataset = Planetoid(root='data/Cora', name='Cora', transform=T.NormalizeFeatures())
loader = DataLoader(dataset, batch_size=32, shuffle=True)
# 初始化模型和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(num_features=dataset.num_features, num_classes=dataset.num_classes, num_heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
# 训练模型
model.train()
for epoch in range(200):
for batch in loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
# 测试模型
model.eval()
correct = 0
for batch in loader:
batch = batch.to(device)
with torch.no_grad():
pred = model(batch.x, batch.edge_index).max(dim=1)[1]
correct += pred.eq(batch.y).sum().item()
print(f"Accuracy: {correct / len(dataset)}")
```
说明:
- `Planetoid`:PyG中一个内置的数据集,用于图分类任务;
- `T.NormalizeFeatures()`:用于归一化特征矩阵;
- `DataLoader`:数据加载器,用于批量加载数据集;
- `F.nll_loss`:负对数似然损失函数;
- `batch.train_mask`:训练集节点的掩码;
- `batch.y`:节点的真实标签;
- `pred.eq(batch.y).sum().item()`:计算预测正确的节点数。
阅读全文