|
<div align="center">

# GraphAny: A Foundation Model for Node Classification on Any Graph #

[PyTorch](https://pytorch.org/get-started/locally/)
[PyTorch Lightning](https://pytorchlightning.ai/)
[PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
[arXiv](http://arxiv.org/abs/2405.20445)
[Hydra](https://hydra.cc/)

</div>

Original PyTorch implementation of [GraphAny].

Authored by [Jianan Zhao], [Hesham Mostafa], [Michael Galkin], [Michael Bronstein],
[Zhaocheng Zhu], and [Jian Tang].

[Jianan Zhao]: https://andyjzhao.github.io/
[Hesham Mostafa]: https://www.linkedin.com/in/hesham-mostafa-79ba93237
[Zhaocheng Zhu]: https://kiddozhu.github.io
[Michael Galkin]: https://migalkin.github.io/
[Michael Bronstein]: https://www.cs.ox.ac.uk/people/michael.bronstein/
[Jian Tang]: https://jian-tang.com/
[GraphAny]: https://github.com/AndyJZhao/GraphAny

Links to be updated later.

## Overview ##

GraphAny is a foundation model for node classification. A single pre-trained GraphAny
model performs node classification on any graph with any feature and label space.
Averaged over 30+ graphs, a single pre-trained GraphAny model **_in inference mode_**
outperforms many supervised models (e.g., MLP, GCN, GAT) trained specifically for each
graph. Following the pretrain-inference paradigm of foundation models, you can train
from scratch and run inference on all 31 datasets as shown in
[Training Foundation Models from Scratch](#training-foundation-models-from-scratch).

This repository is based on PyTorch 2.1, PyTorch Lightning 2.2, PyG 2.4, DGL 2.1, and Hydra 1.3.

## Environment Setup ##

Our experiments are designed to run on both GPU and CPU platforms. A GPU with 16 GB
of memory is sufficient to handle all 31 datasets, and we have also tested the setup
on a single CPU (specifically, an M1 MacBook).

To configure your environment, use the following commands based on your setup:

```bash
# For setups with a GPU (requires CUDA 11.8):
conda env create -f environment.yaml
# For setups using a CPU (tested on macOS with M1 chip):
conda env create -f environment_cpu.yaml
```
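
If you want to double-check that the environment resolved correctly, a minimal sanity check like the one below can be run after activating the environment. It simply prints the versions of the core dependencies listed above (PyTorch, DGL, PyG, PyTorch Lightning, Hydra); adjust it if your environment differs.

```python
# Optional sanity check: print the installed versions of the core dependencies.
import torch
import dgl
import torch_geometric
import pytorch_lightning
import hydra

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("DGL:", dgl.__version__)
print("PyG:", torch_geometric.__version__)
print("PyTorch Lightning:", pytorch_lightning.__version__)
print("Hydra:", hydra.__version__)
```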

## File Structure ##

```
├── README.md
├── checkpoints
├── configs
│   ├── data.yaml
│   ├── main.yaml
│   └── model.yaml
├── environment.yaml
├── environment_cpu.yaml
└── graphany
    ├── __init__.py
    ├── data.py
    ├── model.py
    ├── run.py
    └── utils
```

## Reproduce Our Results ##

### Training Foundation Models from Scratch ###

This section details how to train GraphAny on one dataset (Cora, Wisconsin, Arxiv,
or Product) and evaluate it on all 31 datasets. You can reproduce our results with
the commands below. The checkpoints corresponding to these commands are saved in the
`checkpoints/` folder.

```bash
cd path/to/this/repo
# Reproduce GraphAny-Cora: test_acc=66.98 for seed 0
python graphany/run.py dataset=CoraXAll total_steps=500 n_hidden=64 n_mlp_layer=1 entropy=2 n_per_label_examples=5
# Reproduce GraphAny-Wisconsin: test_acc=67.36 for seed 0
python graphany/run.py dataset=WisXAll total_steps=1000 n_hidden=32 n_mlp_layer=2 entropy=1 n_per_label_examples=5
# Reproduce GraphAny-Arxiv: test_acc=67.58 for seed 0
python graphany/run.py dataset=ArxivXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
# Reproduce GraphAny-Product: test_acc=67.77 for seed 0
python graphany/run.py dataset=ProdXAll total_steps=1000 n_hidden=128 n_mlp_layer=2 entropy=1 n_per_label_examples=3
```

### Inference Using Pre-trained Checkpoints ###

Once trained, GraphAny can perform inference on any graph. You can easily use our
trained checkpoints to run inference on your own graph. Here, we show an example of
loading a GraphAny model trained on Arxiv and performing inference on Cora and
Citeseer.

**Step 1**: Define your custom combined dataset config in `configs/data.yaml`:

```yaml
# configs/data.yaml
_dataset_lookup:
  # Train on Arxiv, inference on Cora and Citeseer
  CoraCiteInference:
    train: [ Arxiv ]
    eval: [ Cora, Citeseer ]
```

**Step 2** _(optional)_: Define your dataset processing logic in `graphany/data.py`.
This step is necessary only if you are not using our pre-processed data. If you
choose to use our provided datasets, you can skip this step and proceed directly to
Step 3.

**Step 3**: Run inference with the pre-trained model using the following command:

```bash
python graphany/run.py prev_ckpt=checkpoints/graph_any_arxiv.pt total_steps=0 dataset=CoraCiteInference
# ind/cora_test_acc 79.4 ind/cite_test_acc 68.4
```
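
If you want to peek inside a pre-trained checkpoint before running inference, plain PyTorch is enough. The snippet below is only a convenience sketch; the exact structure of the saved object depends on how the repository serializes checkpoints.

```python
# Inspect a pre-trained checkpoint with plain PyTorch (structure may vary).
import torch

state = torch.load("checkpoints/graph_any_arxiv.pt", map_location="cpu")
print(type(state))
if isinstance(state, dict):
    for key in list(state)[:10]:  # print only the first few entries
        print(key, getattr(state[key], "shape", ""))
```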

<details>
<summary>Example Output Log</summary>
<pre><code># Training Logs
CRITICAL {
'ind/cora_val_acc': 75.4,
'ind/cite_val_acc': 70.4,
'val_acc': 72.9,
'trans_val_acc': nan, # Not applicable as Arxiv is not included in the evaluation set
'ind_val_acc': 72.9,
'heldout_val_acc': 70.4,
'ind/cora_test_acc': 79.4,
'ind/cite_test_acc': 68.4,
'test_acc': 73.9,
'trans_test_acc': nan,
'ind_test_acc': 73.9,
'heldout_test_acc': 68.4
}
INFO Finished main at 06-01 05:07:49, running time = 2.52s.
</code></pre>

Note: The `trans_test_acc` field is not applicable since Arxiv is not specified in
the evaluation datasets. Additionally, the heldout accuracies are calculated by
excluding the datasets specified as transductive in `configs/data.yaml` (default
setting: `_trans_datasets: [Arxiv, Product, Cora, Wisconsin]`). To use the heldout
metrics correctly, adjust the transductive datasets in your configuration to match
your own inductive split.
</details>

## Configuration Details ##

We use [Hydra](https://hydra.cc/docs/intro/) to manage the configuration. The
configs are organized in three files under the `configs/` directory:

### `main.yaml` ###

Settings for experiments, including random seed, wandb, path, hydra, and logging
configs.

### `data.yaml` ###

This file contains settings for datasets, including preprocessing specifications,
metadata, and lookup configurations. Here's an overview of the key elements:

<details>

#### Dataset Preprocessing Options ####
- `preprocess_device: gpu` — Specifies the device for computing the propagated features $\boldsymbol{F}$. Set to `cpu` if your GPU memory is below 32 GB.
- `add_self_loop: false` — Specifies whether to add self-loops to the nodes in the
  graph.
- `to_bidirected: true` — If set to true, edges are made bidirectional.
- `n_hops: 2` — Defines the maximum number of hops of message passing. In our
  experiments, besides Linear, we use LinearSGC1, LinearSGC2, LinearHGC1, and
  LinearHGC2, which make predictions using up to 2 hops of message passing (see the
  sketch after this list).
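
As a rough illustration of what these channels compute, the sketch below propagates node features with a normalized adjacency matrix: the low-pass channels (LinearSGC1/2) smooth features over 1 and 2 hops, while the high-pass channels (LinearHGC1/2) apply the complementary filter. This is a minimal dense-matrix sketch for intuition only, not the repository's preprocessing code; `adj_norm` and `propagate_channels` are illustrative names.

```python
# Minimal sketch of the five feature channels, assuming `adj_norm` is a
# symmetrically normalized (dense) adjacency matrix and `x` the node features.
import torch

def propagate_channels(adj_norm: torch.Tensor, x: torch.Tensor) -> dict:
    eye = torch.eye(adj_norm.shape[0])
    low1 = adj_norm @ x               # 1-hop low-pass (SGC-style smoothing)
    low2 = adj_norm @ low1            # 2-hop low-pass
    high1 = (eye - adj_norm) @ x      # 1-hop high-pass
    high2 = (eye - adj_norm) @ high1  # 2-hop high-pass
    return {"X": x, "L1": low1, "L2": low2, "H1": high1, "H2": high2}

# Toy example: 4 nodes with 3-dimensional features.
channels = propagate_channels(torch.rand(4, 4), torch.rand(4, 3))
print({name: feat.shape for name, feat in channels.items()})
```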

#### Train and Evaluation Dataset Lookup ####
- The datasets for training and evaluation are dynamically selected based on the
  command-line arguments by looking them up in the `_dataset_lookup` configuration.
- Example: Using `dataset=CoraXAll` sets `train_datasets` to `[Cora]` and
  `eval_datasets` to all datasets (31 in total).

```yaml
train_datasets: ${oc.select:_dataset_lookup.${dataset}.train,${dataset}}
eval_datasets: ${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}
_dataset_lookup:
  CoraXAll:
    train: [ Cora ]
    eval: ${_all_datasets}
```

Please define your own dataset combinations in `_dataset_lookup` if desired.
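
For intuition on how the `oc.select` lookup resolves, here is a small self-contained OmegaConf example (the dataset lists are shortened placeholders): when the `dataset` key is present in `_dataset_lookup`, the corresponding entry is used; otherwise both fields fall back to the dataset name itself.

```python
# Minimal OmegaConf sketch of the dataset lookup with a fallback default.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "dataset": "CoraXAll",
    "_all_datasets": ["Cora", "Citeseer", "Arxiv"],  # shortened for illustration
    "_dataset_lookup": {"CoraXAll": {"train": ["Cora"], "eval": "${_all_datasets}"}},
    "train_datasets": "${oc.select:_dataset_lookup.${dataset}.train,${dataset}}",
    "eval_datasets": "${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}",
})
print(cfg.train_datasets)  # ['Cora']
print(cfg.eval_datasets)   # ['Cora', 'Citeseer', 'Arxiv']

cfg.dataset = "Cora"       # not in the lookup: both fields fall back to "Cora"
print(cfg.train_datasets, cfg.eval_datasets)
```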

#### Detailed Dataset Configurations ####
The dataset metadata stores, for each dataset, the loading interface ([DGL], [PyG],
[OGB], or [Heterophilous]) and its alias (e.g., `Planetoid.Cora`). The dataset
statistics are provided in a comment in the format 'n_nodes, n_edges, n_feat_dim,
n_labels'. For example:

[DGL]: https://docs.dgl.ai/en/2.0.x/api/python/dgl.data.html#node-prediction-datasets
[PyG]: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html
[OGB]: https://ogb.stanford.edu/docs/nodeprop/
[Heterophilous]: https://arxiv.org/abs/2302.11640

```yaml
_ds_meta_data:
  Arxiv: ogb, ogbn-arxiv # 168,343 1,166,243 100 40
  Cora: pyg, Planetoid.Cora # 2,708 10,556 1,433 7
```
</details>

### `model.yaml` ###
This file contains the settings for models and training.

<details>

GraphAny leverages **_interactions between predictions_** as input features for an
MLP to calculate inductive attention scores. These inputs are termed "**_feature
channels_**" and are defined in the configuration file as `feat_chn`. The outputs
from the LinearGNNs, referred to as "**_prediction channels_**", are combined using
the inductive attention scores and are defined as `pred_chn` in the configuration
file. The default settings are:

```yaml
feat_chn: X+L1+L2+H1+H2 # X=Linear, L1=LinearSGC1, L2=LinearSGC2, H1=LinearHGC1, H2=LinearHGC2
pred_chn: X+L1+L2 # H1 and H2 channels are masked to enhance convergence speed.
```

Note that the feature channels and prediction channels do not need to be identical.
Empirically, masking LinearHGC1 and LinearHGC2 leads to faster convergence and
marginally better results (reported in Table 2, Figure 1, and Figure 5). Furthermore,
for the attention visualizations in Figure 6, all five channels
(`pred_chn=X+L1+L2+H1+H2`) are employed. This demonstrates GraphAny's capability to
learn inductive attention that effectively identifies critical channels for unseen
graphs.
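
To make the wiring concrete, here is a schematic sketch of how an MLP over the distance-based feature channels could produce attention scores that fuse the prediction channels. Variable names and shapes are illustrative assumptions, not the repository's actual module.

```python
# Schematic sketch: fuse per-channel LinearGNN predictions with inductive attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, n_feat, n_pred_chn, n_hidden=128, attn_temp=5.0):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_feat, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_pred_chn)
        )
        self.attn_temp = attn_temp

    def forward(self, dist_feat, preds):
        # dist_feat: (n_nodes, n_feat) entropy-normalized distance features
        # preds:     (n_nodes, n_pred_chn, n_classes) per-channel predictions
        attn = F.softmax(self.mlp(dist_feat) / self.attn_temp, dim=-1)
        return (attn.unsqueeze(-1) * preds).sum(dim=1)  # (n_nodes, n_classes)

# Hypothetical shapes: 20 distance features, 3 prediction channels (X, L1, L2), 7 classes.
model = ChannelAttention(n_feat=20, n_pred_chn=3)
fused = model(torch.rand(100, 20), torch.rand(100, 3, 7))
print(fused.shape)  # torch.Size([100, 7])
```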

Other model parameters and default values:
```yaml
# The target entropy used to normalize the distance features (conditional Gaussian
# distribution). The standard deviation of the conditional Gaussian is determined
# dynamically via binary search.
entropy: 1
attn_temp: 5 # The temperature for attention normalization
n_hidden: 128 # The hidden dimension of the MLP
n_mlp_layer: 2 # The number of MLP layers
```
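
As a rough illustration of the `entropy` option, the sketch below shows how a bandwidth for the conditional Gaussian could be found by binary search so that the resulting distribution over distance features reaches a target entropy. This is an assumption-laden sketch for intuition, not the repository's implementation; `calibrate_bandwidth` is an illustrative name.

```python
# Sketch: binary-search a bandwidth so that softmax(-dists / sigma) has a target entropy.
import torch

def calibrate_bandwidth(dists, target_entropy=1.0, n_iter=50):
    lo, hi = 1e-4, 1e4
    for _ in range(n_iter):
        sigma = (lo + hi) / 2
        probs = torch.softmax(-dists / sigma, dim=-1)
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum()
        # Larger sigma flattens the distribution and increases entropy.
        if entropy > target_entropy:
            hi = sigma
        else:
            lo = sigma
    return (lo + hi) / 2

# Hypothetical input: 20 pairwise distances between channel predictions for one node.
sigma = calibrate_bandwidth(torch.rand(20), target_entropy=1.0)
print(sigma)
```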
</details>


## Bring Your Own Dataset ##

<details>
<summary>
We support three major sources of graph dataset interfaces:
<a href="https://docs.dgl.ai/en/2.0.x/api/python/dgl.data.html#node-prediction-datasets">DGL</a>,
<a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html">PyG</a>, and
<a href="https://ogb.stanford.edu/docs/nodeprop/">OGB</a>.
If you are interested in adding your own dataset, here's how we integrated the cleaned
Texas dataset processed by <a href="https://arxiv.org/abs/2302.11640">this paper</a>.
<i>The original Texas dataset contains 5 classes, one of which has only a single node,
which makes that class meaningless for training and evaluation.</i>
</summary>

In the example below, we demonstrate how to add a dataset called "Texas" with 4
classes from a new data source termed `heterophilous`.

**Step 1**: Update `configs/data.yaml`:

First, define your dataset's metadata.

```yaml
# configs/data.yaml
_ds_meta_data: # key: dataset name, value: data_source, alias
  Texas: heterophilous, texas_4_classes
```

The `data_source` is set to 'heterophilous', which is handled differently from the
other sources ('pyg', 'dgl', 'ogb').

Additionally, update `_dataset_lookup` with a new entry:

```yaml
# configs/data.yaml
_dataset_lookup:
  Debug:
    train: [ Wisconsin ]
    eval: [ Texas ]
```

**Step 2**: Implement the dataset interface:

Implement `load_heterophilous_dataset` in `graphany/data.py` to download and process
the dataset.

```python
import dgl
import numpy as np
import torch

from graphany.data import download_url


def load_heterophilous_dataset(url, raw_dir):
    # Convert a Heterophilous dataset (.npz) to DGL graph format
    download_path = download_url(url, raw_dir)
    data = np.load(download_path)
    node_features = torch.tensor(data['node_features'])
    labels = torch.tensor(data['node_labels'])
    edges = torch.tensor(data['edges'])

    graph = dgl.graph((edges[:, 0], edges[:, 1]),
                      num_nodes=len(node_features), idtype=torch.int32)
    num_classes = len(labels.unique())
    train_mask = torch.tensor(data['train_mask'])
    val_mask = torch.tensor(data['val_mask'])
    test_mask = torch.tensor(data['test_mask'])

    return graph, labels, num_classes, node_features, train_mask, val_mask, test_mask
```
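
For reference, a hypothetical call would look like the following. The URL is a placeholder mirroring the one used in Step 3 below; the real download location depends on where the processed `.npz` file is hosted.

```python
# Hypothetical usage; the URL below is a placeholder.
graph, labels, num_classes, feat, train_mask, val_mask, test_mask = load_heterophilous_dataset(
    url="https://example.com/data/texas_4_classes.npz",
    raw_dir="./data/heterophilous/",
)
```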

**Step 3**: Update the `GraphDataset` class in `data.py`:

Modify the initialization and dataset-loading functions:

```python
# In GraphDataset.__init__():
if self.data_source in ['dgl', 'pyg', 'ogb']:
    pass  # Code for other data sources omitted for brevity
elif self.data_source == 'heterophilous':
    target = '.data.load_heterophilous_dataset'
    url = f'https://example.com/data/{ds_alias}.npz'
    ds_init_args = {
        "_target_": target, 'raw_dir': f'{cfg.dirs.data_storage}{self.data_source}/', 'url': url
    }
else:
    raise NotImplementedError(f'Unsupported data source: {self.data_source}')


# In GraphDataset.load_dataset():
from hydra.utils import instantiate

def load_dataset(self, data_init_args):
    dataset = instantiate(data_init_args)
    if self.data_source in ['dgl', 'pyg', 'ogb']:
        pass  # Code for other data sources omitted for brevity
    elif self.data_source == 'heterophilous':
        g, label, num_class, feat, train_mask, val_mask, test_mask = dataset
        # Rest of the code omitted for brevity
```

You can now run the code using the following commands:

```bash
# Training from scratch
python graphany/run.py dataset=Debug total_steps=500
# Inference using existing checkpoint
python graphany/run.py prev_ckpt=checkpoints/graph_any_wisconsin.pt dataset=Debug total_steps=0
```
</details>

## Using Wandb for Enhanced Visualization ##

We recommend using [Weights & Biases](https://wandb.ai/) (wandb) for advanced
visualization capabilities. As an example, consider the visualizations for
GraphAny-Arxiv shown below, which illustrate the validation accuracy across
different dataset categories:
- **Transductive**: the training dataset (i.e., Arxiv)
- **Heldout**: 27 datasets (all except Cora, Wisconsin, Arxiv, and Product)
- **Inductive**: 30 datasets (all except Arxiv)
- **Overall**: all 31 datasets

By default, wandb integration is disabled. To enable and configure wandb for your
project, append the following Hydra overrides to your run command, substituting
`YourOwnWandbEntity` with your actual Weights & Biases entity name:

```bash
use_wandb=true wandb_proj=GraphAny wandb_entity=YourOwnWandbEntity
```

This setup will allow you to track and visualize metrics dynamically.

## Citation ##
If you find this codebase useful in your research, please cite our paper:

```bibtex
@article{zhao2024graphany,
  title={GraphAny: A Foundation Model for Node Classification on Any Graph},
  author={Jianan Zhao and Hesham Mostafa and Michael Galkin and Michael Bronstein and Zhaocheng Zhu and Jian Tang},
  year={2024},
  eprint={2405.20445},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```