Updates: We have added the script for model visualization (Figure 4 in our paper)!
This repository implements the Structure-Aware Transformer (SAT) in PyTorch Geometric, described in the following paper:
Dexiong Chen*, Leslie O'Bray*, and Karsten Borgwardt. Structure-Aware Transformer for Graph Representation Learning. ICML 2022.
*Equal contribution
TL;DR: A class of simple and flexible graph transformers built upon a new self-attention mechanism, which incorporates structural information into the original self-attention by extracting a subgraph representation rooted at each node before computing the attention. Our structure-aware framework can leverage any existing GNN to extract the subgraph representation and systematically improves performance relative to the base GNN.
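As a toy sketch of this idea (not the repository's implementation — the mean-aggregation "GNN" and all weight names here are illustrative): structure-aware node embeddings produced by a small GNN replace the raw node features as queries and keys of standard self-attention, so attention scores depend on local structure.

```python
import torch
import torch.nn.functional as F

def structure_aware_attention(x, adj, w_gnn, w_q, w_k, w_v):
    """Toy sketch: a one-layer mean-aggregation 'GNN' refines node features
    before they are used as queries/keys; values use the raw features.
    `x` is (n, d) node features, `adj` a dense (n, n) adjacency matrix with
    self-loops, and the weights are plain (d, d) tensors."""
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    h = (adj @ x) / deg                # mean over each node's neighborhood
    h = torch.relu(h @ w_gnn)          # structure-aware node embeddings
    q, k, v = h @ w_q, h @ w_k, x @ w_v
    scores = q @ k.T / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v

n, d = 5, 8
x = torch.randn(n, d)
# Path graph 0-1-2-3-4 with self-loops
adj = torch.eye(n) + torch.diag(torch.ones(n - 1), 1) + torch.diag(torch.ones(n - 1), -1)
w = [torch.randn(d, d) for _ in range(4)]
out = structure_aware_attention(x, adj, *w)
print(out.shape)  # torch.Size([5, 8])
```

In the actual model the structure extractor can be any GNN (see `gnn_type` below), and the attention is multi-headed.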
Please use the following to cite our work:
```bibtex
@InProceedings{Chen22a,
  author = {Dexiong Chen and Leslie O'Bray and Karsten Borgwardt},
  title = {Structure-Aware Transformer for Graph Representation Learning},
  year = {2022},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning~(ICML)},
  series = {Proceedings of Machine Learning Research}
}
```

The SAT architecture compared with the vanilla transformer architecture is shown above. We make the self-attention calculation in each transformer layer structure-aware by leveraging structure-aware node embeddings. We generate these embeddings by applying a structure extractor (for example, any GNN) to the subgraph rooted at each node before computing the attention.
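The subgraph rooted at each node is just its k-hop neighborhood, which can be sketched with a plain breadth-first search (a dependency-free illustration; the repository uses PyTorch Geometric utilities for this):

```python
from collections import deque

def k_hop_nodes(edges, root, k):
    """Return the set of nodes within k hops of `root` -- the node set a
    k-subgraph extractor (or the receptive field of a k-layer GNN in the
    k-subtree extractor) operates on. `edges` is an iterable of undirected
    (u, v) pairs."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    seen, frontier = {root}, deque([(root, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # do not expand beyond k hops
        for nb in adj.get(node, ()):
            if nb not in seen:
                seen.add(nb)
                frontier.append((nb, depth + 1))
    return seen

# Path graph 0-1-2-3-4: the 2-hop neighborhood of node 0 is {0, 1, 2}
print(sorted(k_hop_nodes([(0, 1), (1, 2), (2, 3), (3, 4)], 0, 2)))  # [0, 1, 2]
```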
The figure above shows the two example structure extractors used in our paper (the k-subtree and k-subgraph extractors).
Below you can find a quick-start example on the ZINC dataset; see `./experiments/train_zinc.py` for more details.
```python
import torch
from torch_geometric import datasets
from torch_geometric.loader import DataLoader

from sat.data import GraphDataset
from sat import GraphTransformer

# Load the ZINC dataset using our wrapper GraphDataset,
# which automatically creates the fully connected graph.
# For datasets with large graphs, we recommend setting
# return_complete_index=False, which leads to faster computation.
dset = datasets.ZINC('./datasets/ZINC', subset=True, split='train')
dset = GraphDataset(dset)

# Create a PyG data loader
train_loader = DataLoader(dset, batch_size=16, shuffle=True)

# Create a SAT model
dim_hidden = 16
gnn_type = 'gcn'  # use GCN as the structure extractor
k_hop = 2         # use a 2-layer GCN

model = GraphTransformer(
    in_size=28,          # number of node labels for ZINC
    num_class=1,         # regression task
    d_model=dim_hidden,
    dim_feedforward=2 * dim_hidden,
    num_layers=2,
    batch_norm=True,
    gnn_type=gnn_type,   # use GCN as the structure extractor
    use_edge_attr=True,
    num_edge_features=4, # number of edge labels
    edge_dim=dim_hidden,
    k_hop=k_hop,
    se='gnn',            # use the k-subtree structure extractor
    global_pool='add'
)

for data in train_loader:
    output = model(data)  # batch_size x 1
    break
```

The dependencies are managed by miniconda:
```
python=3.9
numpy
scipy
pytorch=1.9.1
pytorch-geometric=2.0.2
einops
ogb
```
Once you have activated the environment and installed all dependencies, run:

```bash
source s
```

Datasets will be downloaded via PyTorch Geometric and the OGB package.
All our experimental scripts are in the `experiments` folder. So to start, after having run `source s`, run `cd experiments`. The hyperparameters used below were selected as optimal.
Train a k-subtree SAT with PNA:
```bash
python train_zinc.py --abs-pe rw --se gnn --gnn-type pna2 --dropout 0.3 --k-hop 3 --use-edge-attr
```

Train a k-subgraph SAT with PNA:

```bash
python train_zinc.py --abs-pe rw --se khopgnn --gnn-type pna2 --dropout 0.2 --k-hop 3 --use-edge-attr
```

Train a k-subtree SAT on PATTERN:

```bash
python train_SBMs.py --dataset PATTERN --weight-class --abs-pe rw --abs-pe-dim 7 --se gnn --gnn-type pna3 --dropout 0.2 --k-hop 3 --num-layers 6 --lr 0.0003
```

and on CLUSTER:

```bash
python train_SBMs.py --dataset CLUSTER --weight-class --abs-pe rw --abs-pe-dim 3 --se gnn --gnn-type pna2 --dropout 0.4 --k-hop 3 --num-layers 16 --dim-hidden 48 --lr 0.0005
```

`--gnn-type` can be `gcn`, `gine`, or `pna`, where `pna` obtains the best performance.
Train SAT on OGBG-PPA:

```bash
python train_ppa.py --gnn-type gcn --use-edge-attr
```

Train SAT on OGBG-CODE2:

```bash
python train_code2.py --gnn-type gcn --use-edge-attr
```

We showcase here how to visualize the attention weights of the [CLS] node learned by SAT and by the vanilla Transformer with the random walk positional encoding. We provide pre-trained models on the Mutagenicity dataset. To visualize the pre-trained models, you need to install the networkx and matplotlib packages, then run:

```bash
python model_visu.py --graph-idx 2003
```

This will generate the following image, the same as Figure 4 in our paper:
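To illustrate what such a visualization involves, here is a minimal, self-contained sketch using a toy graph and random weights (the repository's `model_visu.py` computes the actual [CLS] attention weights from a trained model):

```python
import matplotlib
matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Toy stand-in for a molecular graph from Mutagenicity
G = nx.cycle_graph(6)

# Hypothetical [CLS] attention weights, one per node, summing to 1
rng = np.random.default_rng(0)
attn = rng.random(G.number_of_nodes())
attn = attn / attn.sum()

# Color each node by its attention weight
pos = nx.spring_layout(G, seed=0)
nx.draw_networkx(G, pos, node_color=attn, cmap=plt.cm.viridis)
plt.axis("off")
plt.savefig("attention_toy.png", bbox_inches="tight")
```

Nodes that receive more attention from the [CLS] token are drawn in brighter colors, which is the same reading as in the figure produced by `model_visu.py`.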


