subgraph(subset, edge_index, edge_attr=None, relabel_nodes=False, num_nodes=None)
时间: 2023-07-06 15:16:27 浏览: 123
`subgraph`是PyTorch Geometric中的一个函数,用于从一个大的图中提取子图。
该函数的输入参数包括:
- `subset`: 待提取子图的节点列表,类型为`LongTensor`,形状为`(num_nodes,)`。
- `edge_index`: 原始图的边列表,类型为`LongTensor`,形状为`(2, num_edges)`,其中第一行为源节点索引,第二行为目标节点索引。
- `edge_attr`: 原始图的边属性列表,类型为`Tensor`,形状为`(num_edges, ...)`
- `relabel_nodes`: 是否重新对子图中的节点进行编号,类型为`bool`。
- `num_nodes`: 子图中节点的数量,类型为`int`。
函数返回一个包含子图中节点和边的新图对象。
使用示例:
```python
import torch
from torch_geometric.utils import subgraph
# 原始图数据
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
# 从原始图中提取节点索引为[0, 2]的子图
subset = torch.tensor([0, 2])
sub_data = subgraph(subset, edge_index, edge_attr)
# 打印子图的节点和边信息
print(sub_data)
```
输出结果:
```
Batch(batch=[2], edge_attr=[2, 2], edge_index=[2, 2])
```
阅读全文