import os
import os.path as osp
import pickle
import itertools
import urllib.request
from collections import namedtuple

import numpy as np
import scipy.sparse as sp
import torch

Data = namedtuple('Data', ['x', 'y', 'adjacency', 'train_mask', 'val_mask', 'test_mask'])
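As a quick illustration, the `Data` container can be filled with a hypothetical 3-node toy graph (made-up values, not the real 2708-node Cora arrays), just to show the fields it carries:

```python
from collections import namedtuple

import numpy as np
import scipy.sparse as sp

# Same container as above (repeated so the snippet is self-contained)
Data = namedtuple('Data', ['x', 'y', 'adjacency', 'train_mask', 'val_mask', 'test_mask'])

# Hypothetical 3-node graph standing in for Cora
toy = Data(
    x=np.random.rand(3, 4).astype("float32"),            # node features
    y=np.array([0, 1, 0]),                               # node labels
    adjacency=sp.coo_matrix(np.eye(3, dtype="float32")), # sparse adjacency
    train_mask=np.array([True, False, False]),
    val_mask=np.array([False, True, False]),
    test_mask=np.array([False, False, True]),
)
print(toy.x.shape, toy.adjacency.shape)  # (3, 4) (3, 3)
```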
def tensor_from_numpy(x, device):
    return torch.from_numpy(x).to(device)


class CoraData(object):
    download_url = "https://raw.githubusercontent.com/kimiyoung/planetoid/master/data"
    filenames = ["ind.cora.{}".format(name) for name in
                 ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]

    def __init__(self, data_root="cora", rebuild=False):
        """Cora dataset, covering download, processing, and loading.

        When a cached file exists it is used directly; otherwise the raw
        data is downloaded, processed, and cached to disk.

        The processed data is available through the .data property, which
        returns a Data object with the following fields:
            * x: node features, shape 2708 * 1433, np.ndarray
            * y: node labels over 7 classes, np.ndarray
            * adjacency: adjacency matrix, shape 2708 * 2708, scipy.sparse.coo_matrix
            * train_mask: training mask of length 2708; True where the node is in the training set, else False
            * val_mask: validation mask of length 2708; True where the node is in the validation set, else False
            * test_mask: test mask of length 2708; True where the node is in the test set, else False

        Args:
        -------
            data_root: string, optional
                Directory holding the data.
                Raw data path: {data_root}/raw
                Cached data path: {data_root}/processed_cora.pkl
            rebuild: boolean, optional
                Whether to rebuild the dataset; when True, the data is
                rebuilt even if a cache file exists.
        """
        self.data_root = data_root
        save_file = osp.join(self.data_root, "processed_cora.pkl")
        if osp.exists(save_file) and not rebuild:
            print("Using Cached file: {}".format(save_file))
            self._data = pickle.load(open(save_file, "rb"))
        else:
            self.maybe_download()
            self._data = self.process_data()
            with open(save_file, "wb") as f:
                pickle.dump(self.data, f)
            print("Cached file: {}".format(save_file))

    @property
    def data(self):
        """Return the Data object holding x, y, adjacency, train_mask, val_mask, test_mask."""
        return self._data

    def process_data(self):
        """Process the raw files into node features, labels, the adjacency
        matrix, and the train/validation/test splits.

        Adapted from: https://github.com/rusty1s/pytorch_geometric
        """
        print("Process data ...")
        _, tx, allx, y, ty, ally, graph, test_index = [
            self.read_data(osp.join(self.data_root, "raw", name))
            for name in self.filenames]
        train_index = np.arange(y.shape[0])
        val_index = np.arange(y.shape[0], y.shape[0] + 500)
        sorted_test_index = sorted(test_index)

        x = np.concatenate((allx, tx), axis=0)
        y = np.concatenate((ally, ty), axis=0).argmax(axis=1)
        # Reorder the test rows so that they sit at their node indices
        x[test_index] = x[sorted_test_index]
        y[test_index] = y[sorted_test_index]

        num_nodes = x.shape[0]
        train_mask = np.zeros(num_nodes, dtype=bool)
        val_mask = np.zeros(num_nodes, dtype=bool)
        test_mask = np.zeros(num_nodes, dtype=bool)
        train_mask[train_index] = True
        val_mask[val_index] = True
        test_mask[test_index] = True
        adjacency = self.build_adjacency(graph)
        print("Node's feature shape: ", x.shape)
        print("Node's label shape: ", y.shape)
        print("Adjacency's shape: ", adjacency.shape)
        print("Number of training nodes: ", train_mask.sum())
        print("Number of validation nodes: ", val_mask.sum())
        print("Number of test nodes: ", test_mask.sum())

        return Data(x=x, y=y, adjacency=adjacency,
                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    def maybe_download(self):
        save_path = os.path.join(self.data_root, "raw")
        for name in self.filenames:
            if not osp.exists(osp.join(save_path, name)):
                self.download_data(
                    "{}/{}".format(self.download_url, name), save_path)

    @staticmethod
    def build_adjacency(adj_dict):
        """Build the adjacency matrix from the downloaded adjacency lists."""
        edge_index = []
        num_nodes = len(adj_dict)
        for src, dst in adj_dict.items():
            edge_index.extend([src, v] for v in dst)
            edge_index.extend([v, src] for v in dst)
        # Drop duplicate edges
        edge_index = list(k for k, _ in itertools.groupby(sorted(edge_index)))
        edge_index = np.asarray(edge_index)
        adjacency = sp.coo_matrix((np.ones(len(edge_index)),
                                   (edge_index[:, 0], edge_index[:, 1])),
                                  shape=(num_nodes, num_nodes), dtype="float32")
        return adjacency

    @staticmethod
    def read_data(path):
        """Read each raw file in its appropriate format for further processing."""
        name = osp.basename(path)
        if name == "ind.cora.test.index":
            out = np.genfromtxt(path, dtype="int64")
            return out
        else:
            out = pickle.load(open(path, "rb"), encoding="latin1")
            out = out.toarray() if hasattr(out, "toarray") else out
            return out

    @staticmethod
    def download_data(url, save_path):
        """Download helper, used when the raw data is not present."""
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        data = urllib.request.urlopen(url)
        filename = os.path.split(url)[-1]
        with open(os.path.join(save_path, filename), 'wb') as f:
            f.write(data.read())
        return True

    @staticmethod
    def normalization(adjacency):
        """Compute L = D^-0.5 * (A + I) * D^-0.5."""
        adjacency += sp.eye(adjacency.shape[0])  # add self-loops
        degree = np.array(adjacency.sum(1))
        d_hat = sp.diags(np.power(degree, -0.5).flatten())
        return d_hat.dot(adjacency).dot(d_hat).tocoo()
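The deduplication step in `build_adjacency` relies on `itertools.groupby` collapsing equal consecutive items of a sorted list. A minimal stdlib-only sketch, using a hypothetical 3-node adjacency dict in the same format as the downloaded `graph`:

```python
import itertools

# Hypothetical adjacency dict (node -> list of neighbours), not the Cora graph
adj_dict = {0: [1, 2], 1: [0], 2: [0]}

edge_index = []
for src, dst in adj_dict.items():
    edge_index.extend([src, v] for v in dst)  # forward edges
    edge_index.extend([v, src] for v in dst)  # reverse edges (symmetrize)

# Sorting puts duplicates next to each other; groupby keeps one per group
edge_index = [k for k, _ in itertools.groupby(sorted(edge_index))]
print(edge_index)  # [[0, 1], [0, 2], [1, 0], [2, 0]]
```

Because every edge is inserted in both directions before deduplication, the resulting adjacency matrix is guaranteed to be symmetric even if the downloaded adjacency lists are not.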
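As a sanity check of the `normalization` formula, the same computation can be run on a hypothetical 3-node path graph (edges 0-1 and 1-2, not Cora data); the result should be symmetric, with diagonal entries 1/d for each self-loop-augmented degree d:

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical 3-node path graph: edges 0-1 and 1-2
A = sp.coo_matrix(np.array([[0., 1., 0.],
                            [1., 0., 1.],
                            [0., 1., 0.]], dtype="float32"))

A_hat = A + sp.eye(A.shape[0])                      # A + I: add self-loops
degree = np.array(A_hat.sum(1))                     # degrees after self-loops: [2, 3, 2]
d_hat = sp.diags(np.power(degree, -0.5).flatten())  # D^-0.5
L = d_hat.dot(A_hat).dot(d_hat).tocoo()             # D^-0.5 (A + I) D^-0.5

dense = L.toarray()
print(np.allclose(dense, dense.T))   # symmetric normalization -> True
print(round(float(dense[0, 0]), 3))  # 1 / sqrt(2 * 2) = 0.5
```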