Skip to content

Commit 408a65a

Browse files
committed
fix: enable load_data to return pyg data
1 parent d7da26e commit 408a65a

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

graph_datasets/load_data.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import dgl
88
import torch
9+
from torch_geometric.data import Data
10+
from torch_geometric.utils import from_dgl
911

1012
from .data_info import COLA_DATASETS
1113
from .data_info import CRITICAL_DATASETS
@@ -30,9 +32,10 @@ def load_data(
3032
directory: str = DEFAULT_DATA_DIR,
3133
verbosity: int = 0,
3234
source: str = "pyg",
35+
return_type: str = "dgl",
3336
rm_self_loop: bool = True,
3437
to_simple: bool = True,
35-
) -> Tuple[dgl.DGLGraph, torch.Tensor, int]:
38+
) -> Tuple[dgl.DGLGraph, torch.Tensor, int] or Data:
3639
"""Load graphs.
3740
3841
Args:
@@ -42,6 +45,8 @@ def load_data(
4245
verbosity (int, optional): Output debug information. \
4346
The greater, the more detailed. Defaults to 0.
4447
source (str, optional): Source for data loading. Defaults to "pyg".
48+
return_type (str, optional): Return type of the graphs within ["dgl", "pyg"]. \
49+
Defaults to "dgl".
4550
rm_self_loop (str, optional): Remove self loops. Defaults to True.
4651
to_simple (str, optional): Convert to a simple graph with no duplicate undirected edges.
4752
@@ -55,12 +60,27 @@ def load_data(
5560
.. code-block:: python
5661
5762
from graph_datasets import load_data
63+
# dgl graph
5864
graph, label, n_clusters = load_data(
5965
dataset_name='cora',
6066
directory="./data",
67+
return_type="dgl",
6168
source='pyg',
6269
verbosity=3,
70+
rm_self_loop=True,
71+
to_simple=True,
6372
)
73+
# pyG data
74+
data = load_data(
75+
dataset_name='cora',
76+
directory="./data",
77+
return_type="pyg",
78+
source='pyg',
79+
verbosity=3,
80+
rm_self_loop=True,
81+
to_simple=True,
82+
)
83+
6484
"""
6585
dataset_name = (
6686
dataset_name.lower() if dataset_name not in [
@@ -132,8 +152,8 @@ def load_data(
132152
# make label from 0
133153
uni = label.unique()
134154
old2new = dict(zip(uni.numpy().tolist(), list(range(len(uni)))))
135-
newlabel = torch.tensor(list(map(lambda x: old2new[x.item()], label)))
136-
graph.ndata["label"] = newlabel
155+
new_label = torch.tensor(list(map(lambda x: old2new[x.item()], label)))
156+
graph.ndata["label"] = new_label
137157

138158
if verbosity:
139159
print_dataset_info(
@@ -144,7 +164,19 @@ def load_data(
144164
n_clusters=n_clusters,
145165
)
146166

147-
return graph, newlabel, n_clusters
167+
if return_type == "dgl":
168+
return graph, new_label, n_clusters
169+
170+
data = from_dgl(graph)
171+
data.name = dataset_name
172+
data.num_classes = n_clusters
173+
data.x = data.feat
174+
data.y = data.label
175+
data.num_nodes = graph.num_nodes()
176+
data.num_edges = graph.num_edges()
177+
data.edge_index = torch.stack(graph.edges(), dim=0)
178+
179+
return data
148180

149181

150182
if __name__ == "__main__":

0 commit comments

Comments
 (0)