图卷积网络(GCN)代码
时间: 2023-07-20 11:37:38 浏览: 48
以下是使用PyTorch实现的简单图卷积网络(GCN)代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_channels, out_channels)
def forward(self, x, adj):
# x: input features (batch_size, num_nodes, in_channels)
# adj: adjacency matrix (batch_size, num_nodes, num_nodes)
x = self.linear(x)
adj = adj + torch.eye(adj.size(1)).unsqueeze(0).to(adj.device) # add self-loop
degree = torch.sum(adj, dim=2, keepdim=True)
adj = adj / degree # normalize adjacency matrix
x = torch.bmm(adj, x)
return F.relu(x)
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GCN, self).__init__()
self.layer1 = GCNLayer(in_channels, hidden_channels)
self.layer2 = GCNLayer(hidden_channels, out_channels)
def forward(self, x, adj):
x = self.layer1(x, adj)
x = self.layer2(x, adj)
return x
```
以上代码实现了一个包含两个GCN层的GCN模型,可以通过调用`GCN(in_channels, hidden_channels, out_channels)`来实例化该模型。其中,`in_channels`是输入特征的维度,`hidden_channels`是隐藏层的维度,`out_channels`是输出特征的维度。在模型的`forward`函数中,首先通过`GCNLayer`层对输入特征进行线性变换和非线性激活,然后根据邻接矩阵进行卷积操作,最终输出特征。