写一段GraphSAGE用于节点分类的代码
时间: 2023-06-14 11:06:00 浏览: 175
GraphSAGE是一种图神经网络,用于节点分类任务。以下是一段使用DGL库实现的GraphSAGE代码。
```python
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
class SAGE(nn.Module):
def __init__(self, in_feats, hidden_feats, num_classes):
super(SAGE, self).__init__()
self.conv1 = GraphConv(in_feats, hidden_feats)
self.conv2 = GraphConv(hidden_feats, num_classes)
def forward(self, g, features):
h = self.conv1(g, features)
h = F.relu(h)
h = self.conv2(g, h)
return h
# 创建一个简单的图
g = dgl.graph(([0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 0]))
features = torch.randn(7, 10) # 7个节点,每个节点有10个特征
labels = torch.tensor([0, 1, 1, 0, 1, 0, 0]) # 7个节点,每个节点有一个标签
model = SAGE(10, 16, 2) # 输入10个特征,隐藏层有16个神经元,输出2个类别
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
logits = model(g, features)
loss = F.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
```
这段代码首先定义了一个SAGE模型,包含两个GraphConv层,其中隐藏层有16个神经元,输出层有2个类别。然后创建了一个简单的图,包含7个节点和一些边。接着使用交叉熵损失训练模型。在每个epoch中,计算模型输出的logits和标签的交叉熵损失,然后使用Adam优化器更新模型参数。最后输出每个epoch的损失。
阅读全文