GraphSAGE代码示例
时间: 2023-07-02 14:07:02 浏览: 48
以下是使用DGL库实现的GraphSAGE代码示例:
```python
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.nn import SAGEConv
# 构建图
g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]))
g.ndata["feat"] = torch.randn(6, 10) # 节点特征
# 定义模型
class GraphSAGE(nn.Module):
def __init__(self, in_feats, hidden_feats, num_layers):
super(GraphSAGE, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(SAGEConv(in_feats, hidden_feats, "mean"))
for i in range(num_layers - 1):
self.layers.append(SAGEConv(hidden_feats, hidden_feats, "mean"))
def forward(self, g, h):
for layer in self.layers:
h = layer(g, h)
return h
# 训练模型
model = GraphSAGE(10, 16, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
logits = model(g, g.ndata["feat"])
labels = torch.tensor([0, 1, 2, 3, 4, 5])
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch {:03d} | Loss {:.4f}".format(epoch, loss.item()))
```
这个示例代码定义了一个包含两个GraphSAGE层的模型,并使用Cross-Entropy作为损失函数,Adam作为优化器进行训练。在训练过程中,使用节点特征和边信息构建了一个图,每个epoch中将图数据输入模型进行训练。