GraphSAGE代码实现
时间: 2023-09-17 08:11:54 浏览: 45
GraphSAGE是一种基于图神经网络的节点嵌入方法,可以用于节点分类、链接预测等任务。以下是GraphSAGE的代码实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super(GraphSAGE, self).__init__()
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for i in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.convs.append(SAGEConv(hidden_channels, out_channels))
def forward(self, x, edge_index):
for conv in self.convs[:-1]:
x = conv(x, edge_index)
x = F.relu(x)
x = self.convs[-1](x, edge_index)
return x
```
上述代码中,我们使用了PyTorch Geometric库中的SAGEConv,它是GraphSAGE中的一层。GraphSAGE由多层SAGEConv组成,其中每一层都使用relu函数作为非线性激活函数。在forward函数中,我们遍历所有的层,并应用到输入x上,最后输出最终的节点嵌入。