python dgl 构建包含100个图的数据集 并保存到本地磁盘中
时间: 2024-02-20 07:00:52 浏览: 154
图的数据集
要构建包含100个图的数据集,你需要先确定每个图的节点数量、边数量、节点特征以及边特征等信息。假设你已经有了这些信息,那么可以按如下步骤构建数据集:
1. 导入必要的库:
```python
import dgl
import numpy as np
import os
```
2. 定义生成图的函数:
```python
def generate_graph(num_nodes, num_edges, num_node_features, num_edge_features):
g = dgl.DGLGraph()
g.add_nodes(num_nodes)
src, dst = np.random.randint(0, num_nodes, size=(2, num_edges))
g.add_edges(src, dst)
node_features = np.random.rand(num_nodes, num_node_features)
edge_features = np.random.rand(num_edges, num_edge_features)
g.ndata['feat'] = node_features
g.edata['feat'] = edge_features
return g
```
这个函数会随机生成一个包含指定节点数量、边数量、节点特征和边特征的图,并返回这个图对象。
3. 循环生成100个图,并保存到本地磁盘中:
```python
# 创建保存数据集的目录
if not os.path.exists('datasets'):
os.mkdir('datasets')
for i in range(100):
g = generate_graph(num_nodes=100, num_edges=300, num_node_features=10, num_edge_features=5)
dgl.save_graphs(f'datasets/graph_{i}.bin', [g])
```
这个代码会生成100个包含100个节点、300条边,每个节点有10个特征、每条边有5个特征的图,并将这些图保存到名为 `datasets` 的目录中,文件名分别为 `graph_0.bin` 到 `graph_99.bin`。
注意:由于每个图的节点数量和边数量都是随机生成的,所以你得到的数据集可能与上面描述的有所不同。如果需要固定每个图的节点数量和边数量,可以将 `generate_graph` 函数中的随机生成代码替换为固定值。
阅读全文