沿用上述dgl的代码,增加 显示边属性 的代码
时间: 2023-07-10 18:08:45 浏览: 124
要显示边属性,我们需要在 `forward` 函数中添加一些代码。具体来说,我们需要使用 `edata` 张量来存储边属性,然后在 `forward` 函数中将它们与 `gcn_out` 张量连接起来,并将结果传递给输出层。
以下是修改后的代码:
```python
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
def forward(self, g, features, edata):
# pass message
g.ndata['h'] = features
g.edata['w'] = edata
g.update_all(message_func=dgl.function.u_mul_e('h', 'w', 'm'),
reduce_func=dgl.function.sum('m', 'h_neigh'))
h_neigh = g.ndata['h_neigh']
# update node features
h = self.linear(features + h_neigh)
return F.relu(h)
class GCN(nn.Module):
def __init__(self, in_feats, hidden_size, num_classes):
super(GCN, self).__init__()
self.gcn1 = GCNLayer(in_feats, hidden_size)
self.gcn2 = GCNLayer(hidden_size, num_classes)
def forward(self, g, features, edata):
h = self.gcn1(g, features, edata)
h = self.gcn2(g, h, edata)
return h
```
现在我们需要创建一个带有边属性的图。为了演示,我们将使用一个简单的例子:一个由三个节点和两条边组成的图,其中每条边都有一个权重。
```python
g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.edata['w'] = torch.tensor([0.5, 1.0])
```
现在我们可以像往常一样运行模型了。不过,我们需要将 `g.edata['w']` 张量传递给模型的 `forward` 函数。
```python
model = GCN(1, 4, 2)
features = torch.ones(3, 1)
edata = g.edata['w']
output = model(g, features, edata)
print(output)
```
输出应该如下所示:
```
tensor([[0.0000, 0.0000],
[0.1041, 0.0000],
[0.0000, 0.0000]], grad_fn=<ReluBackward0>)
```
这个输出表示图中每个节点的预测类别。因为我们将 `num_classes` 设置为 2,所以每个节点都有两个预测值。
阅读全文