用dgl.scatter_add和torch_scatter import scatter 运行同一组数据
时间: 2023-09-28 18:12:30 浏览: 111
dgl.scatter_add和torch_scatter的作用都是在张量的某个维度上对指定的索引进行聚合操作,但是它们的实现方式略有不同。
dgl.scatter_add是Deep Graph Library(DGL)中提供的函数,用于在图上对节点特征进行聚合。它的具体使用方式如下:
```
import dgl
import torch
# 构造一个包含3个节点的图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
# 构造一个3x4的特征矩阵
feat = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 定义聚合的方式为相加
def reducer(nodes):
return {'sum': torch.sum(nodes.mailbox['m'], dim=1)}
# 对节点的特征进行聚合,聚合方式为相加
g.update_all(dgl.function.copy_src('feat', 'm'), dgl.function.sum('m', 'feat'))
feat = g.ndata['feat']
# 使用scatter_add对特征进行聚合
idx = torch.tensor([0, 1, 2, 0])
val = torch.tensor([1, 2, 3, 4])
feat = torch.zeros(3, 5)
feat.scatter_add_(1, idx.unsqueeze(1).repeat(1, 4), val.unsqueeze(1).repeat(1, 4))
```
其中,scatter_add_函数的第一个参数表示要聚合的维度,这里是第1维,即按行进行聚合。第二个参数是一个形状为(N, M)的张量,表示要聚合的索引,这里是idx.unsqueeze(1).repeat(1, 4),即将idx扩展为形状为(N, M),其中M为feat的第1维大小。第三个参数是一个形状为(N, M)的张量,表示要聚合的值,这里是val.unsqueeze(1).repeat(1, 4),即将val扩展为形状为(N, M)。
torch_scatter中的scatter函数使用方式也类似,具体使用方式如下:
```
import torch
from torch_scatter import scatter
idx = torch.tensor([0, 1, 2, 0])
val = torch.tensor([1, 2, 3, 4])
feat = torch.zeros(3, 5)
feat = scatter(val.unsqueeze(1).repeat(1, 4), idx.unsqueeze(1).repeat(1, 4), dim=1, out=feat, reduce='add')
```
其中,scatter函数的第一个参数表示要聚合的值,第二个参数表示要聚合的索引,第三个参数表示要聚合的维度,这里是第1维,即按行进行聚合。第四个参数表示输出的张量,这里是feat,最后一个参数表示聚合方式,这里是相加。
阅读全文