MindSpore加载图数据集
加载图数据集
- MindSpore提供的
mindspore.dataset
模块可以帮助用户构建数据集对象,分批次地读取文本数据。
图的概念
-
通常一个图(graph)
G
是由一系列的节点(vertices)V
以及边(eage)E
组成的,每条边都连接着图中的两个节点,用公式可表述为:G = F(V, E)
,简单的图如下所示。
-
图中包含节点V = {a, b, c, d},和边E = {(a, b), (b, c), (c, d), (d, b)},针对图中的连接关系通常需借助数学的方式进行描述,如常用的基于邻接矩阵的方式,用于描述上述图连接关系的矩阵C如下,其中a、 b、c、d对应为第1、2、 3、4个节点。
数据集下载和转换
(1) 数据集介绍
-
常用的图数据集包含Cora、Citeseer、PubMed等
-
原始数据集可以从ucsc网站进行下载,
-
github提供的预处理后的数据集,GCN等公开使用
-
Cora数据集主体部分(
cora.content
)-
2708条样本(节点),每条样本描述1篇科学论文的信息,论文都属于7个类别中的一个。每条样本数据包含三部分,依次为论文编号、论文的词向量(一个1433位的二进制)、论文的类别;
-
引用数据集部分(
cora.cites
)包含5429行(边),每行包含两个论文编号,表示第二篇论文对第一篇论文进行了引用。
-
数据集下载:下载预处理后的cora数据集目录如下:
.
└── cora
├── ind.cora.allx
├── ind.cora.ally
├── ind.cora.graph
├── ind.cora.test.index
├── ind.cora.tx
├── ind.cora.ty
├── ind.cora.x
├── ind.cora.y
├── trans.cora.graph
├── trans.cora.tx
├── trans.cora.ty
├── trans.cora.x
└── trans.cora.y
(2)数据集下载
以下示例代码将cora数据集下载并解压到指定位置。
!mkdir -p ./cora
!git clone https://github.com/kimiyoung/planetoid
!cp planetoid/data/*.cora.* ./cora
!rm -rf planetoid
(3)数据集格式转换
- 数据集格式转换:将数据集转换为MindRecord格式,可借助models仓库提供的转换脚本进行转换,生成的MindRecord文件在
./cora_mindrecord
路径下。
!git clone https://gitee.com/mindspore/models.git
SRC_PATH = "./cora"
MINDRECORD_PATH = "./cora_mindrecord"
!rm -rf $MINDRECORD_PATH
!mkdir $MINDRECORD_PATH
!python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "$MINDRECORD_PATH/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "$SRC_PATH"
- 报错,但命令行可以
- 改: 环境切换 没得搞定啊
!source activate py37_ms16
!python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "$MINDRECORD_PATH/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "$SRC_PATH"
- 乖乖命令行试试。看来默认环境没有ms不行?
source activate py37_ms16
python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "./cora_mindrecord/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "./cora"
加载数据集
-
MindSpore目前支持加载文本领域常用的经典数据集和多种数据存储格式下的数据集,用户也可以通过构建自定义数据集类实现自定义方式的数据加载。
-
下面演示使用
MindSpore.dataset
模块中的MindDataset
类加载上述已转换成mindrecord格式的cora数据集。
(1)配置数据集目录,创建数据集对象。
import mindspore.dataset as ds
import numpy as np
data_file = "./cora_mindrecord/cora_mr"
dataset = ds.GraphData(data_file)
(2)访问对应的接口,获取图信息及特性、标签内容。
# 查看图中结构信息
graph = dataset.graph_info()
print("graph info:", graph)
# 获取所有的节点信息
nodes = dataset.get_all_nodes(0)
nodes_list = nodes.tolist()
print("node shape:", len(nodes_list))
# 获取特征和标签信息,总共2708条数据
# 每条数据中特征信息是用于描述论文i,长度为1433的二进制表示,标签信息指的是论文所属的种类
raw_tensor = dataset.get_node_feature(nodes_list, [1, 2])
features, labels = raw_tensor[0], raw_tensor[1]
print("features shape:", features.shape)
print("labels shape:", labels.shape)
print("labels:", labels)
数据处理
- MindSpore目前支持的数据处理算子及其详细使用方法。下面构建pipeline,对节点进行采样等操作。
(1)获取节点的邻居节点,构造邻接矩阵。
neighbor = dataset.get_all_neighbors(nodes_list, 0)
# neighbor的第一列是node_id,第二列到最后一列存储的是第一列的邻居节点,如果不存在这么多,则用-1补齐。
print("neighbor:\n", neighbor)
(2)依据节点的邻居节点信息,构造邻接矩阵。
nodes_num = labels.shape[0]
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
for index, value in np.ndenumerate(neighbor):
# neighbor的第一列是node_id,第二列到最后一列存储的是第一列的邻居节点,如果不存在这么多,则用-1补齐。
if value >= 0 and index[1] > 0:
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
print("adj:\n", adj)
(3)节点采样,支持常见的多次跳跃采样与随机游走采样方法等。
- 多跳邻接点采样如(a)图所示,当次采样的节点将作为下次采样的起始点;随机游走方式如(b)图所示,随机选择一条路径依次遍历相邻的节点,对应图中则选择了从Vi到Vj的游走路径。
# 基于多次跳跃进行节点采样
neighbor = dataset.get_sampled_neighbors(nodes_list[0:21], [2], [0])
print("neighbor:\n", neighbor)
# 基于随机游走进行节点采样
meta_path = [0]
walks = dataset.random_walk(nodes_list[0:21], meta_path)
print("walks:\n", walks)
(4)通过节点获取边/通过边获取节点。
# 通过边获取节点
part_edges = dataset.get_all_edges(0)[:10]
nodes = dataset.get_nodes_from_edges(part_edges)
print("part edges:", part_edges)
print("nodes:", nodes)
# 通过节点获取边
# nodes_pair_list = [(0, 1), (1, 2), (1, 3), (1, 4)]
# edges = dataset.get_edges_from_nodes(nodes_pair_list)
# print("edges:", edges)
- 点赞
- 收藏
- 关注作者
评论(0)