graphsage pytorch代码
时间: 2023-04-26 08:03:22 浏览: 228
graphsage-simple:适用代码测试,源代码地址:https
以下是使用PyTorch实现GraphSAGE的示例代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils.data import DataLoader
from torch.optim import Adam
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
from dgl.nn.pytorch import GraphConv
class GraphSAGE(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(GraphSAGE, self).__init__()
self.conv1 = GraphConv(in_feats, h_feats)
self.conv2 = GraphConv(h_feats, num_classes)
def forward(self, g, inputs):
h = self.conv1(g, inputs)
h = F.relu(h)
h = self.conv2(g, h)
return h
# 加载数据
data = citegrh.load_cora()
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
g = DGLGraph(data.graph)
# 定义超参数
in_feats = features.shape[1]
h_feats = 16
num_classes = data.num_labels
num_epochs = 10
lr = 0.01
weight_decay = 5e-4
batch_size = 128
# 定义模型、优化器、损失函数
model = GraphSAGE(in_feats, h_feats, num_classes)
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
# 训练模型
train_loader = DataLoader(torch.arange(g.number_of_nodes())[train_mask], batch_size=batch_size, shuffle=True)
val_loader = DataLoader(torch.arange(g.number_of_nodes())[val_mask], batch_size=batch_size, shuffle=False)
test_loader = DataLoader(torch.arange(g.number_of_nodes())[test_mask], batch_size=batch_size, shuffle=False)
for epoch in range(num_epochs):
model.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(g, features)[inputs]
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 在验证集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in val_loader:
outputs = model(g, features)[inputs]
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Epoch {} | Val Accuracy: {:.2f}%'.format(epoch, accuracy))
# 在测试集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = model(g, features)[inputs]
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Test Accuracy: {:.2f}%'.format(accuracy))
```
上述代码实现了对Cora数据集的GraphSAGE模型训练和测试,并输出了最终测试集上的准确率。请注意,代码中使用了DGL库中的`DGLGraph`和`GraphConv`,这是以下是使用 PyTorch 实现 GraphSAGE 的代码示例:
```python
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
self.lin = Linear(out_channels, out_channels)
def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
x = F.relu(self.conv2(x, edge_index))
x = F.dropout(x, training=self.training)
x = torch.mean(x, dim=0, keepdim=True)
x = self.lin(x)
return x
```
这里实现了一个包含两层 GraphSAGE 层的模型。在初始化函数中,我们定义了两个 GraphSAGE 层,每层使用的卷积核的输入和输出维度分别为 `in_channels -> hidden_channels` 和 `hidden_channels -> out_channels`。我们还定义了一个线性层,用于将 GraphSAGE 层的输出转换为最终的输出向量。
在 `forward` 函数中,我们首先将输入的节点特征矩阵 `x` 传入第一层 GraphSAGE 层中进行卷积操作,然后再将其传入第二层 GraphSAGE 层中进行卷积操作。接着,我们对每个节点的特征向量进行 dropout 操作,并对所有节点的特征向量取平均值。最后,我们将平均值向量传入线性层中进行转换,得到最终的输出向量。
注意,这里的 GraphSAGE 层是使用 PyTorch Geometric 库中的 SAGEConv 类实现的。如果你需要使用其他类型的 GraphSAGE 层,可以参考 PyTorch Geometric 的文档进行实现。
阅读全文