1
+ import os .path as osp
2
+ import torch
3
+ import numpy as np
4
+ from torch_geometric .data import InMemoryDataset , Data
5
+ from torch_geometric .io import read_txt_array
6
+ import torch .nn .functional as F
7
+
8
+ import scipy
9
+ import pickle as pkl
10
+ import csv
11
+ import json
12
+
13
+ import warnings
14
+ warnings .filterwarnings ('ignore' , category = DeprecationWarning )
15
+
16
+
17
class CitationDataset(InMemoryDataset):
    """Citation-network dataset read from plain-text raw files.

    Expects three files in ``raw_dir``:

    * ``{name}_edgelist.txt`` — one ``src,dst`` integer pair per line;
    * ``{name}_docs.txt``     — one comma-separated float feature row per node;
    * ``{name}_labels.txt``   — one integer class label per node.

    ``process`` builds a single graph ``Data`` object and attaches random
    80/10/10 train/val/test boolean node masks.
    """

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CitationDataset, self).__init__(root, transform, pre_transform,
                                              pre_filter)
        # Load the collated tensors produced by `process`.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["docs.txt", "edgelist.txt", "labels.txt"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Raw files must be placed in `raw_dir` manually; nothing to fetch.
        pass

    def process(self):
        """Parse the raw text files into one `Data` graph and save it."""
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        # Node features: one comma-separated row per node.
        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        rows = []
        # fix: context manager so the handle is closed (was a bare open()).
        with open(docs_path, 'rb') as f:
            for line in f:
                rows.append(str(line, encoding="utf-8").split(","))
        x = torch.from_numpy(np.array(rows, dtype=float)).to(torch.float)

        # Node labels: one integer per line; strip CR/LF before parsing.
        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        labels = []
        with open(label_path, 'rb') as f:
            for line in f:
                text = str(line, encoding="utf-8")
                labels.append(text.replace("\r", "").replace("\n", ""))
        y = torch.from_numpy(np.array(labels, dtype=int)).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        self._attach_split_masks(data, y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def _attach_split_masks(self, data, num_nodes, train_ratio=0.8,
                            val_ratio=0.1):
        """Attach random train/val/test boolean node masks to `data` in place.

        Ratios default to the original hard-coded 80/10/10 split; the test
        split takes whatever remains after train and val.
        """
        perm = np.random.permutation(num_nodes)
        n_train = int(num_nodes * train_ratio)
        n_val = int(num_nodes * val_ratio)
        splits = {
            'train_mask': perm[:n_train],
            'val_mask': perm[n_train:n_train + n_val],
            'test_mask': perm[n_train + n_val:],
        }
        for attr, idx in splits.items():
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[idx] = True
            setattr(data, attr, mask)
95
class EllipticDataset(InMemoryDataset):
    """Graph dataset loaded from a pickled ``(A, label, features)`` triple
    stored at ``raw_dir/{name}.pkl``.

    ``A`` is a sparse adjacency matrix, ``label`` an integer array, and
    ``features`` a dense per-node feature array. ``process`` builds a single
    graph ``Data`` object with random 80/10/10 train/val/test node masks.
    """

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(EllipticDataset, self).__init__(root, transform, pre_transform,
                                              pre_filter)
        # Load the collated tensors produced by `process`.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [".pkl"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # The pickle must be placed in `raw_dir` manually; nothing to fetch.
        pass

    def process(self):
        """Unpickle the raw triple into one `Data` graph and save it."""
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        # NOTE(review): pickle is only safe on trusted files — this assumes
        # the raw data was produced locally.
        # fix: context manager so the handle is closed (was pkl.load(open(...))).
        with open(path, 'rb') as f:
            A, label, features = pkl.load(f)
        # Shift labels up by one — presumably to make them non-negative
        # class indices; confirm against the raw label encoding.
        label = label + 1

        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        x = torch.from_numpy(np.array(features)).to(torch.float)
        y = torch.tensor(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        self._attach_split_masks(data, y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def _attach_split_masks(self, data, num_nodes, train_ratio=0.8,
                            val_ratio=0.1):
        """Attach random train/val/test boolean node masks to `data` in place.

        Ratios default to the original hard-coded 80/10/10 split; the test
        split takes whatever remains after train and val.
        """
        perm = np.random.permutation(num_nodes)
        n_train = int(num_nodes * train_ratio)
        n_val = int(num_nodes * val_ratio)
        splits = {
            'train_mask': perm[:n_train],
            'val_mask': perm[n_train:n_train + n_val],
            'test_mask': perm[n_train + n_val:],
        }
        for attr, idx in splits.items():
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[idx] = True
            setattr(data, attr, mask)
160
class TwitchDataset(InMemoryDataset):
    """Twitch social-network dataset (musae) for one language code.

    Raw files expected in ``raw_dir``: ``musae_{lang}_target.csv``,
    ``musae_{lang}_edges.csv`` and ``musae_{lang}_features.json``, where
    ``lang`` is one of ``'DE'``, ``'EN'``, ``'FR'``.

    ``process`` builds a single graph ``Data`` object with random 80/10/10
    train/val/test node masks.
    """

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(TwitchDataset, self).__init__(root, transform, pre_transform,
                                            pre_filter)
        # Load the collated tensors produced by `process`.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # fix: was a single comma-joined string ["edges.csv, features.json,
        # target.csv"]; return three separate file names.
        return ["edges.csv", "features.json", "target.csv"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Raw files must be placed in `raw_dir` manually; nothing to fetch.
        pass

    def load_twitch(self, lang):
        """Read the raw musae Twitch files for `lang`.

        Returns:
            A: ``scipy.sparse.csr_matrix`` adjacency of shape ``(n, n)`` with
                a 1 for every edge row in the CSV.
            label: int array of length ``n``, reordered by node id.
            features: dense ``(n, 3170)`` 0/1 array of one-hot features.
        """
        assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
        filepath = self.raw_dir
        label = []
        node_ids = []
        src = []
        targ = []
        uniq_ids = set()

        # Targets: row[5] is the node id and row[2] a "True"/"False" class
        # column — presumably the mature flag; verify against the raw CSV.
        with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows
                if node_id not in uniq_ids:
                    uniq_ids.add(node_id)
                    label.append(int(row[2] == "True"))
                    node_ids.append(node_id)

        node_ids = np.array(node_ids, dtype=np.int32)

        with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))

        with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
            j = json.load(f)

        src = np.array(src)
        targ = np.array(targ)
        label = np.array(label)
        n = label.shape[0]

        # Reorder labels from CSV-row order to node-id order: inv_node_ids
        # maps node id -> CSV row, so label[inv_node_ids[i]] is node i's
        # label. Assumes node ids are exactly 0..n-1 — TODO confirm.
        inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
        reorder_node_ids = np.zeros_like(node_ids)
        for i in range(n):
            reorder_node_ids[i] = inv_node_ids[i]
        label = label[reorder_node_ids]

        A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)),
                                    shape=(n, n))

        # One-hot feature matrix; 3170 is presumably the musae feature
        # vocabulary size — confirm against the raw JSON.
        features = np.zeros((n, 3170))
        for node, feats in j.items():
            if int(node) >= n:
                continue  # node has no target row; drop its features
            features[int(node), np.array(feats, dtype=int)] = 1

        return A, label, features

    def process(self):
        """Assemble the raw musae files into one `Data` graph and save it."""
        A, label, features = self.load_twitch(self.name)
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        x = torch.from_numpy(np.array(features)).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        self._attach_split_masks(data, y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def _attach_split_masks(self, data, num_nodes, train_ratio=0.8,
                            val_ratio=0.1):
        """Attach random train/val/test boolean node masks to `data` in place.

        Ratios default to the original hard-coded 80/10/10 split; the test
        split takes whatever remains after train and val.
        """
        perm = np.random.permutation(num_nodes)
        n_train = int(num_nodes * train_ratio)
        n_val = int(num_nodes * val_ratio)
        splits = {
            'train_mask': perm[:n_train],
            'val_mask': perm[n_train:n_train + n_val],
            'test_mask': perm[n_train + n_val:],
        }
        for attr, idx in splits.items():
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[idx] = True
            setattr(data, attr, mask)
0 commit comments