Skip to content

Commit 4b7276a

Browse files
authored
Code Init
0 parents  commit 4b7276a

28 files changed

+1922
-0
lines changed

README.md

+391
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
<div align="center">
2+
3+
# GraphAny: A Foundation Model for Node Classification on Any Graph #
4+
5+
[![pytorch](https://img.shields.io/badge/PyTorch_2.1+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
6+
[![lightning](https://img.shields.io/badge/-Lightning_2.2+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
7+
[![pyg](https://img.shields.io/badge/PyG_2.4+-3C2179?logo=pyg&logoColor=#3C2179)](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
8+
[![arxiv](http://img.shields.io/badge/arxiv-2405.20445-blue.svg)](http://arxiv.org/abs/2405.20445)
9+
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
10+
![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)
11+
12+
</div>
13+
14+
Original PyTorch implementation of [GraphAny].
15+
16+
Authored by [Jianan Zhao], [Hesham Mostafa], [Michael Galkin], [Michael Bronstein],
17+
[Zhaocheng Zhu], and [Jian Tang].
18+
19+
[Jianan Zhao]: https://andyjzhao.github.io/
20+
[Hesham Mostafa]: https://www.linkedin.com/in/hesham-mostafa-79ba93237
21+
[Zhaocheng Zhu]: https://kiddozhu.github.io
22+
[Michael Galkin]: https://migalkin.github.io/
23+
[Michael Bronstein]: https://www.cs.ox.ac.uk/people/michael.bronstein/
24+
[Jian Tang]: https://jian-tang.com/
25+
[GraphAny]: https://github.com/AndyJZhao/GraphAny
26+
27+
Links to be updated later.
28+
29+
## Overview ##
30+
31+
![Foundation Model on Node Classification](assets/fm_on_node_classification.png)
32+
33+
GraphAny is a foundation model for node classification. A single pre-trained GraphAny
34+
model performs node classification tasks on any graph with any feature and label
35+
spaces. Performance-wise, averaged on 30+ graphs, a single pre-trained GraphAny model
36+
is better **_in inference mode_** than many supervised models (e.g., MLP, GCN, GAT)
37+
trained specifically for each graph. Following the pretrain-inference paradigm of
38+
foundation models, you can perform training from scratch and inference on 30 datasets
39+
as shown in [Training from scratch](#training-from-scratch).
40+
41+
This repository is based on PyTorch 2.1, Pytorch-Lightning 2.2, PyG 2.4, DGL 2.1, and Hydra 1.3.
42+
43+
## Environment Setup ##
44+
45+
Our experiments are designed to run on both GPU and CPU platforms. A GPU with 16 GB
46+
of memory is sufficient to handle all 31 datasets, and we have also tested the setup
47+
on a single CPU (specifically, an M1 MacBook).
48+
49+
To configure your environment, use the following commands based on your setup:
50+
51+
```bash
52+
# For setups with a GPU (requires CUDA 11.8):
53+
conda env create -f environment.yaml
54+
# For setups using a CPU (tested on macOS with M1 chip):
55+
conda env create -f environment_cpu.yaml
56+
```
57+
58+
## File Structure ##
59+
60+
```
61+
├── README.md
62+
├── checkpoints
63+
├── configs
64+
│ ├── data.yaml
65+
│ ├── main.yaml
66+
│ └── model.yaml
67+
├── environment.yaml
68+
├── environment_cpu.yaml
69+
└── graphany
70+
├── __init__.py
71+
├── data.py
72+
├── model.py
73+
├── run.py
74+
└── utils
75+
```
76+
77+
## Reproduce Our Results ##
78+
79+
### Training Foundation Models from Scratch ###
80+
81+
This section would detail how users can train GraphAny on one dataset (Cora,
82+
Wisconsin, Arxiv, or Product) and evaluate on all 31 datasets. You can reproduce
83+
our results via the commands below. The checkpoints of these commands are saved in
84+
the `checkpoints/` folder.
85+
86+
```bash
87+
cd path/to/this/repo
88+
# Reproduce GraphAny-Cora: test_acc= 66.98 for seed 0
89+
python graphany/run.py dataset=CoraXAll total_steps=500 n_hidden=64 n_mlp_layer=1 entropy=2 n_per_label_examples=5
90+
# Reproduce GraphAny-Wisconsin: test_acc= 67.36 for seed 0
91+
python graphany/run.py dataset=WisXAll total_steps=1000 n_hidden=32 n_mlp_layer=2 entropy=1 n_per_label_examples=5
92+
# Reproduce GraphAny-Arxiv: test_acc=67.58 for seed 0
93+
python graphany/run.py dataset=ArxivXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
94+
# Reproduce GraphAny-Product: test_acc=67.77 for seed 0
95+
python graphany/run.py dataset=ProdXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
96+
```
97+
98+
### Inference Using Pre-trained Checkpoints ###
99+
100+
Once trained, GraphAny enjoys the ability to perform inference on any graph. You
101+
can use our trained checkpoint to run inference on your graph easily. Here, we
102+
showcase an example of loading a GraphAny model trained on Arxiv and perform
103+
inference on Cora and Citeseer.
104+
105+
**Step 1**: Define your custom combined dataset config in the `configs/data.yaml` :
106+
107+
```yaml
108+
# configs/data.yaml
109+
_dataset_lookup:
110+
# Train on Arxiv, inference on Cora and Citeseer
111+
CoraCiteInference:
112+
train: [ Arxiv ]
113+
eval: [ Cora, Citeseer ]
114+
```
115+
116+
**Step 2** _(optional)_: Define your dataset processing logic in graph_any/data.py.
117+
This step is necessary only if you are not using our pre-processed data. If you
118+
choose to use our provided datasets, you can skip this step and proceed directly to
119+
Step 3.
120+
121+
**Step 3**: Inference using pre-trained model using command:
122+
123+
```bash
124+
python graphany/run.py prev_ckpt=checkpoints/graph_any_arxiv.pt total_steps=0 dataset=CoraCiteInference
125+
# ind/cora_test_acc 79.4 ind/cite_test_acc 68.4
126+
```
127+
128+
129+
<details>
130+
<summary>Example Output Log</summary>
131+
<pre><code># Training Logs
132+
CRITICAL {
133+
'ind/cora_val_acc': 75.4,
134+
'ind/cite_val_acc': 70.4,
135+
'val_acc': 72.9,
136+
'trans_val_acc': nan, # Not applicable as Arxiv is not included in the evaluation set
137+
'ind_val_acc': 72.9,
138+
'heldout_val_acc': 70.4,
139+
'ind/cora_test_acc': 79.4,
140+
'ind/cite_test_acc': 68.4,
141+
'test_acc': 73.9,
142+
'trans_test_acc': nan,
143+
'ind_test_acc': 73.9,
144+
'heldout_test_acc': 68.4
145+
}
146+
INFO Finished main at 06-01 05:07:49, running time = 2.52s.
147+
</code></pre>
148+
149+
Note: The `trans_test_acc` field is not applicable since Arxiv is not specified in
150+
the evaluation datasets. Additionally, the heldout accuracies are calculated by
151+
excluding datasets specified as transductive in `configs/data.yaml` (default
152+
settings: `_trans_datasets: [Arxiv, Product, Cora, Wisconsin]`). To utilize the heldout
153+
metrics correctly, please adjust these transductive datasets in your configuration
154+
to reflect your specific dataset inductive split settings.
155+
</details>
156+
157+
## Configuration Details ##
158+
We use [Hydra](https://hydra.cc/docs/intro/) to manage the configuration. The
159+
configs are organized in three files under the `configs/` directory:
160+
161+
### `main.yaml` ###
162+
Settings for experiments, including random seed, wandb, path,
163+
hydra, and logging configs.
164+
165+
### `data.yaml` ###
166+
This file contains settings for datasets, including preprocessing specifications,
167+
metadata, and lookup configurations. Here’s an overview of the key elements:
168+
169+
<details>
170+
171+
#### Dataset Preprocessing Options ####
172+
- `preprocess_device: gpu` — Specifies the device for computing propagated features $\boldsymbol{F}$. Set to cpu if your GPU memory is below 32GB.
173+
- `add_self_loop: false` — Specifies whether to add self-loops to the nodes in the
174+
graph.
175+
- `to_bidirected: true` — If set to true, edges are made bidirectional.
176+
- `n_hops: 2` — Defines the maximum number of hops of message passing. In our
177+
experiments, besides Linear, we use LinearSGC1, LinearSGC1, LinearHGC1,
178+
LinearHGC2, which predicts information within 2 hops of message passing.
179+
180+
#### Train and Evaluation Dataset Lookup ####
181+
- The datasets for training and evaluation are dynamically selected based on the
182+
command-line arguments by looking up from the `_dataset_lookup` configuration
183+
- Example: Using `dataset=CoraXAll` sets `train_datasets` to `[Cora]` and
184+
`eval_datasets` to all datasets (31 in total).
185+
186+
```yaml
187+
train_datasets: ${oc.select:_dataset_lookup.${dataset}.train,${dataset}}
188+
eval_datasets: ${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}
189+
_dataset_lookup:
190+
- CoraXAll:
191+
- train: [Cora]
192+
- eval: ${_all_datasets}
193+
```
194+
195+
Please define your own dataset combinations in `_dataset_lookup` if desired.
196+
197+
#### Detailed Dataset Configurations ####
198+
The dataset meta-data stores the meta information including the interfaces [DGL],
199+
[PyG], [OGB], [Heterophilous] and their aliases (e.g. `Planetoid.Cora`) to load the
200+
dataset. The statistics are provided in the comment with a format of 'n_nodes,
201+
n_edges, n_feat_dim, n_labels'. For example:
202+
203+
[DGL]: https://docs.dgl.ai/en/2.0.x/api/python/dgl.data.html#node-prediction-datasets
204+
[PyG]: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html
205+
[OGB]: https://ogb.stanford.edu/docs/nodeprop/
206+
[Heterophilous]: https://arxiv.org/abs/2302.11640
207+
208+
```yaml
209+
_ds_meta_data:
210+
Arxiv: ogb, ogbn-arxiv # 168,343 1,166,243 100 40
211+
Cora: pyg, Planetoid.Cora # 2,708 10,556 1,433 7
212+
```
213+
</details>
214+
215+
### `model.yaml` ###
216+
This file contains the settings for models and training.
217+
218+
<details>
219+
220+
GraphAny leverages **_interactions between predictions_** as input features for an
221+
MLP to calculate inductive attention scores. These inputs are termed "**_feature
222+
channels_**" and are defined in the configuration file as `feat_chn`. Subsequently,
223+
the outputs from LinearGNNs, referred to as "**_prediction channels_**", are
224+
combined using inductive attention scores and are defined as `pred_chn` in the
225+
configuration file. The default settings are:
226+
227+
```yaml
228+
feat_chn: X+L1+L2+H1+H2 # X=Linear, L1=LinearSGC1, L2=LinearSGC2, H1=LinearHGC1, H2=LinearHGC2
229+
pred_chn: X+L1+L2 # H1 and H2 channels are masked to enhance convergence speed.
230+
```
231+
232+
It is important to note that the feature channels and prediction channels do not
233+
need to be identical. Empirical observations indicate that masking LinearHGC1 and
234+
LinearHGC2 leads to faster convergence and marginally improved results (results in
235+
Table 2, Figure 1, and Figure 5). Furthermore, for the attention visualizations in
236+
Figure 6, all five channels (`pred_chn=X+L1+L2+H1+H2`) are employed. This
237+
demonstrates GraphAny's capability to learn inductive attention that effectively
238+
identifies critical channels for unseen graphs.
239+
240+
Other model parameters and default values:
241+
```yaml
242+
# The entropy to normalize the distance features (conditional gaussian distribution). The standard deviation of conditional gaussian distribution is dynamically determined via binary search, default to 1
243+
entropy: 1
244+
attn_temp: 5 # The temperature for attention normalization
245+
n_hidden: 128 # The hidden dimension of MLP
246+
n_mlp_layer: 2
247+
```
248+
</details>
249+
250+
251+
## Bring Your Own Dataset ##
252+
253+
<details>
254+
<summary>
255+
We support three major sources of graph dataset interfaces:
256+
<a href="https://docs.dgl.ai/en/2.0.x/api/python/dgl.data.html#node-prediction-datasets">DGL</a>,
257+
<a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html">PyG</a>, and
258+
<a href="https://ogb.stanford.edu/docs/nodeprop/">OGB</a>.
259+
If you are interested in adding your own dataset, here's how we integrated the cleaned
260+
Texas dataset processed by <a href="https://arxiv.org/abs/2302.11640">this paper</a>.
261+
<i>The original Texas dataset contains 5 classes, with a class with only one node,
262+
which makes using this class for training and evaluation meaningless.</i>
263+
</summary>
264+
265+
In the example below, we demonstrate how to add a dataset called "Texas" with 4
266+
classes from a new data source termed `heterophilous`.
267+
268+
**Step 1**: Update `configs/data.yaml`:
269+
270+
First, define your dataset's metadata.
271+
272+
```yaml
273+
# configs/data.yaml
274+
_ds_meta_data: # key: dataset name, value: data_source, alias
275+
Texas: heterophilous, texas_4_classes
276+
```
277+
278+
The `data_source` is set as 'heterophilous', which is handled differently from other
279+
sources ('pyg', 'dgl', 'ogb').
280+
281+
Additionally, update the `_dataset_lookup` with a new setting:
282+
283+
```yaml
284+
# configs/data.yaml
285+
_dataset_lookup:
286+
Debug:
287+
train: [ Wisconsin ]
288+
eval: [ Texas ]
289+
```
290+
291+
**Step 2**: Implement the dataset interface:
292+
293+
Implement `load_heterophilous_dataset` in `data.py` to download and process the dataset.
294+
295+
```python
296+
import numpy as np
297+
import torch
298+
from graphany.data import download_url
299+
import dgl
300+
301+
def load_heterophilous_dataset(url, raw_dir):
302+
# Converts Heterophilous dataset to DGL Graph format
303+
download_path = download_url(url, raw_dir)
304+
data = np.load(download_path)
305+
node_features = torch.tensor(data['node_features'])
306+
labels = torch.tensor(data['node_labels'])
307+
edges = torch.tensor(data['edges'])
308+
309+
graph = dgl.graph((edges[:, 0], edges[:, 1]),
310+
num_nodes=len(node_features), idtype=torch.int32)
311+
num_classes = len(labels.unique())
312+
train_mask, val_mask, test_mask = torch.tensor(data['train_mask']), torch.tensor(data['val_mask']), torch.tensor(
313+
data['test_mask'])
314+
315+
return graph, labels, num_classes, node_features, train_mask, val_mask, test_mask
316+
```
317+
318+
**Step 3**: Update `GraphDataset` class in `data.py`:
319+
320+
Modify the initialization and dataset loading functions:
321+
322+
```python
323+
# In GraphDataset.__init__():
324+
if self.data_source in ['dgl', 'pyg', 'ogb']:
325+
pass # Code for other data sources omitted for brevity
326+
elif self.data_source == 'heterophilous':
327+
target = '.data.load_heterophilous_dataset'
328+
url = f'https://example.com/data/{ds_alias}.npz'
329+
ds_init_args = {
330+
"_target_": target, 'raw_dir': f'{cfg.dirs.data_storage}{self.data_source}/', 'url': url
331+
}
332+
else:
333+
raise NotImplementedError(f'Unsupported data source: {self.data_source}')
334+
335+
# In GraphDataset.load_dataset():
336+
from hydra.utils import instantiate
337+
def load_dataset(self, data_init_args):
338+
dataset = instantiate(data_init_args)
339+
if self.data_source in ['dgl', 'pyg', 'ogb']:
340+
pass # Code for other data sources omitted for brevity
341+
elif self.data_source == 'heterophilous':
342+
g, label, num_class, feat, train_mask, val_mask, test_mask = dataset
343+
# Rest of the code omitted for brevity
344+
```
345+
346+
You can now run the code using the following commands:
347+
348+
```bash
349+
# Training from scratch
350+
python graphany/run.py dataset=Debug total_steps=500
351+
# Inference using existing checkpoint
352+
python graphany/run.py prev_ckpt=checkpoints/graph_any_wisconsin.pt dataset=Debug total_steps=0
353+
```
354+
</details>
355+
356+
## Using Wandb for Enhanced Visualization ##
357+
358+
We recommend using [Weights & Biases](https://wandb.ai/) (wandb) for advanced
359+
visualization capabilities. As an example, consider the visualizations for the
360+
GraphAny-Arxiv project shown below, which illustrate the validation accuracy across
361+
different data set categories:
362+
- **Transductive**: Training dataset (i.e. Arxiv)
363+
- **Heldout**: 27 datasets (except Cora, Wisconsin, Arxiv, Product)
364+
- **Inductive**: 30 datasets (except arxiv)
365+
- **Overall**: 31 datasets (all datasets)
366+
367+
![wandb_training_curve](assets/wandb_training_curve.png)
368+
369+
By default, wandb integration is disabled. To enable and configure wandb for your
370+
project, use the following command, substituting `YourOwnWandbEntity` with your
371+
actual Weights & Biases entity name:
372+
373+
```bash
374+
use_wandb=true wandb_proj=GraphAny wandb_entity=YourOwnWandbEntity
375+
```
376+
377+
This setup will allow you to track and visualize metrics dynamically.
378+
379+
## Citation ##
380+
If you find this codebase useful in your research, please cite the paper.
381+
382+
```bibtex
383+
@article{zhao2024graphany,
384+
title={GraphAny: A Foundation Model for Node Classification on Any Graph},
385+
author={Jianan Zhao and Hesham Mostafa and Michael Galkin and Michael Bronstein and Zhaocheng Zhu and Jian Tang},
386+
year={2024},
387+
eprint={2405.20445},
388+
archivePrefix={arXiv},
389+
primaryClass={cs.LG}
390+
}
391+
```

assets/fm_on_node_classification.png

130 KB
Loading

assets/graph_any_model.png

58 KB
Loading

assets/wandb_training_curve.png

109 KB
Loading

checkpoints/graph_any_arxiv.pt

89.1 KB
Binary file not shown.

checkpoints/graph_any_cora.pt

36.9 KB
Binary file not shown.

checkpoints/graph_any_product.pt

86.5 KB
Binary file not shown.

checkpoints/graph_any_wisconsin.pt

58.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)