diff --git a/.gitignore b/.gitignore
index ee0c1a3640..ec94badfc6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,7 @@ word_language_model/model.pt
 fast_neural_style/saved_models
 fast_neural_style/saved_models.zip
 gcn/cora/
+gat/cora/
 docs/build
 docs/venv
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b0f2524c35..ddfd45b3fc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -176,4 +176,4 @@ experiment with PyTorch.
    This example implements the `Semi-Supervised Classification with Graph Convolutional Networks `__ paper on the CORA database.
 
-   `GO TO EXAMPLE `__ :opticon:`link-external`
+   `GO TO EXAMPLE `__ :opticon:`link-external`
\ No newline at end of file
diff --git a/gat/README.md b/gat/README.md
new file mode 100644
index 0000000000..41106a1f76
--- /dev/null
+++ b/gat/README.md
@@ -0,0 +1,114 @@
+# Graph Attention Network
+
+This repository contains a PyTorch implementation of **Graph Attention Networks (GAT)** based on the paper ["Graph Attention Networks" by Velickovic et al.](https://arxiv.org/abs/1710.10903v3).
+
+The Graph Attention Network is a powerful graph neural network model for learning representations on graph-structured data, and it has shown excellent performance in tasks such as node classification, link prediction, and graph classification.
+
+
+## Overview
+The Graph Attention Network (GAT) is a graph neural network architecture designed specifically for graph-structured data. It uses a multi-head attention mechanism to aggregate information from neighboring nodes and learn a representation for each node. This attention mechanism lets the model focus on relevant neighbors and adaptively weight their contributions during message passing.
+
+Check out the following resources for more info on GATs:
+- [Blog post by the lead author, Petar Velickovic](https://petar-v.com/GAT/)
+- [Main paper](https://doi.org/10.48550/arXiv.1710.10903)
+
+This repository provides a clean and concise implementation of the official GAT model in PyTorch. The code is well documented and easy to follow, making it a useful resource for researchers and practitioners interested in graph deep learning.
+
+
+## Key Features
+
+- **GAT Model**: Implementation of the Graph Attention Network model with multi-head attention, based on the paper "Graph Attention Networks" by Velickovic et al.
+- **Graph Attention Layers**: Implementation of graph attention layers that aggregate information from neighboring nodes using a self-attention mechanism to learn node importance weights.
+- **Training and Evaluation**: Code for training GAT models on graph-structured data and evaluating their performance on node classification on the *Cora* benchmark dataset.
+
+---
+
+# Requirements
+- Python 3.7 or higher
+- PyTorch 2.0 or higher
+- Requests 2.31 or higher
+- NumPy 1.24 or higher
+
+
+
+# Dataset
+The implementation includes support for the Cora dataset, a standard benchmark for graph-based machine learning. The Cora dataset consists of scientific publications, where nodes represent papers and edges represent citation relationships. Each paper is associated with a class label indicating one of seven classes. The dataset is downloaded and preprocessed automatically the first time `main.py` is run.
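+
+For a quick look at what the loader returns, here is a minimal sketch (not part of `main.py`, run from inside the `gat/` directory) that uses the `load_cora` helper defined in this example. It assumes the raw files are already in `./cora` (the training script downloads them on first run); the shapes in the comments are the commonly reported Cora statistics and may differ slightly for other versions of the dataset.
+
+```python
+from main import load_cora
+
+# Read and preprocess the raw Cora files that main.py downloads on first run
+features, labels, adj_mat = load_cora(path='./cora', device='cpu')
+
+print(features.shape)  # e.g. torch.Size([2708, 1433]) - one row-normalized bag-of-words vector per paper
+print(labels.shape)    # e.g. torch.Size([2708])       - one of 7 class indices per paper
+print(adj_mat.shape)   # e.g. torch.Size([2708, 2708]) - dense citation adjacency with self-loops
+```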
+
+# Model Architecture
+The official architecture (used in this project) proposed in the paper "Graph Attention Networks" by Velickovic et al. consists of two graph attention layers, each of which uses multi-head attention during message transformation and aggregation. Each graph attention layer applies a shared self-attention mechanism to every node in the graph, allowing each node to learn a representation that reflects the importance of its neighbors.
+
+In terms of activation functions, the GAT model employs both the **Exponential Linear Unit (ELU)** and the **Leaky Rectified Linear Unit (LeakyReLU)** activations, which introduce non-linearity into the model. ELU is used as the activation function for the **hidden layer**, while LeakyReLU is applied to the raw **attention scores** before the softmax to ensure non-zero gradients for negative values.
+
+Following the official implementation, the first GAT layer consists of **K = 8 attention heads** computing **F' = 8 features** each (for a **total of 64 features**), followed by an exponential linear unit (ELU) activation on the layer outputs. The second GAT layer is used for classification: a **single attention head** that computes C features (where C is the number of classes), followed by a softmax activation for probabilistic outputs. (We use log-softmax instead, for numerical convenience with `NLLLoss`.)
+
+*Note: because this is an educational example, the implementation uses the full dense form of the adjacency matrix of the graph rather than its sparse form, so all operations in the model are performed on dense tensors. This does not affect the model's accuracy; however, a sparse-friendly implementation would be more efficient in terms of memory, storage, and speed.*
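+
+As an illustration (this snippet is not part of the training script), the configuration described above corresponds to instantiating the `GAT` module from `main.py` roughly as follows; the toy feature and adjacency tensors are made up purely to check shapes:
+
+```python
+import torch
+from main import GAT
+
+# Paper-style Cora configuration (matches the defaults in main.py)
+model = GAT(
+    in_features=1433,   # size of each node's bag-of-words feature vector
+    n_hidden=64,        # output size of the first graph attention layer (--hidden-dim)
+    n_heads=8,          # K = 8 attention heads in the first layer (--num-heads)
+    num_classes=7,      # C = 7 classes in Cora
+    concat=False,       # average the heads (the script's default); pass True to concatenate them instead
+    dropout=0.6,        # dropout probability used in the paper
+)
+
+x = torch.rand(4, 1433)   # toy node features for a 4-node graph
+adj = torch.ones(4, 4)    # toy fully-connected adjacency matrix
+out = model(x, adj)       # log-probabilities per class, shape (4, 7)
+```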
+
+# Usage
+Training and evaluating the GAT model on the Cora dataset is done by running the `main.py` script as follows:
+
+1. Clone the PyTorch examples repository:
+
+```
+git clone https://github.com/pytorch/examples.git
+cd examples/gat
+```
+
+2. Install the required dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+3. Train the GAT model by running the `main.py` script (the example below uses the default parameters):
+
+```bash
+python main.py --epochs 300 --lr 0.005 --l2 5e-4 --dropout-p 0.6 --num-heads 8 --hidden-dim 64 --val-every 20
+```
+
+In more detail, the `main.py` script accepts the following arguments:
+```
+usage: main.py [-h] [--epochs EPOCHS] [--lr LR] [--l2 L2] [--dropout-p DROPOUT_P] [--hidden-dim HIDDEN_DIM] [--num-heads NUM_HEADS] [--concat-heads] [--val-every VAL_EVERY]
+               [--no-cuda] [--no-mps] [--dry-run] [--seed S]
+
+PyTorch Graph Attention Network
+
+options:
+  -h, --help            show this help message and exit
+  --epochs EPOCHS       number of epochs to train (default: 300)
+  --lr LR               learning rate (default: 0.005)
+  --l2 L2               weight decay (default: 5e-4)
+  --dropout-p DROPOUT_P
+                        dropout probability (default: 0.6)
+  --hidden-dim HIDDEN_DIM
+                        dimension of the hidden representation (default: 64)
+  --num-heads NUM_HEADS
+                        number of the attention heads (default: 8)
+  --concat-heads        whether to concatenate attention heads, or average over them (default: False)
+  --val-every VAL_EVERY
+                        epochs to wait between printing training and validation evaluation (default: 20)
+  --no-cuda             disables CUDA training
+  --no-mps              disables macOS GPU training
+  --dry-run             quickly check a single pass
+  --seed S              random seed (default: 13)
+```
+
+
+
+# Results
+After training for **300 epochs** with the default hyperparameters on random train/val/test splits, the GAT model achieves around **81.25%** classification accuracy on the test split. This is comparable to the performance reported in the original paper. However, the results can vary due to the randomness of the train/val/test split.
+
+# Reference
+
+```
+@article{
+  velickovic2018graph,
+  title="{Graph Attention Networks}",
+  author={Veli{\v{c}}kovi{\'{c}}, Petar and Cucurull, Guillem and Casanova, Arantxa and Romero, Adriana and Li{\`{o}}, Pietro and Bengio, Yoshua},
+  journal={International Conference on Learning Representations},
+  year={2018},
+  url={https://openreview.net/forum?id=rJXMpikCZ},
+}
+```
+- Paper on arXiv: [arXiv:1710.10903v3](https://doi.org/10.48550/arXiv.1710.10903)
+- Original paper repository: [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT)
diff --git a/gat/main.py b/gat/main.py
new file mode 100644
index 0000000000..9c143af8ec
--- /dev/null
+++ b/gat/main.py
@@ -0,0 +1,375 @@
+import os
+import time
+import requests
+import tarfile
+import numpy as np
+import argparse
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.optim import Adam
+
+
+################################
+###  GAT LAYER DEFINITION   ###
+################################
+
+class GraphAttentionLayer(nn.Module):
+    """
+    Graph Attention Layer (GAT) as described in the paper `"Graph Attention Networks" `.
+
+    This operation can be mathematically described as:
+
+        e_ij = a(W h_i, W h_j)
+        α_ij = softmax_j(e_ij) = exp(e_ij) / Σ_k(exp(e_ik))
+        h_i' = σ(Σ_j(α_ij W h_j))
+
+    where h_i and h_j are the feature vectors of nodes i and j respectively, W is a learnable weight matrix,
+    a is an attention mechanism that computes the attention coefficients e_ij, and σ is an activation function.
+
+    """
+    def __init__(self, in_features: int, out_features: int, n_heads: int, concat: bool = False, dropout: float = 0.4, leaky_relu_slope: float = 0.2):
+        super(GraphAttentionLayer, self).__init__()
+
+        self.n_heads = n_heads      # Number of attention heads
+        self.concat = concat        # Whether to concatenate the final attention heads
+        self.dropout = dropout      # Dropout rate
+
+        if concat:                  # concatenating the attention heads
+            self.out_features = out_features        # Number of output features per node
+            assert out_features % n_heads == 0      # Ensure that out_features is a multiple of n_heads
+            self.n_hidden = out_features // n_heads
+        else:                       # averaging output over the attention heads (used in the main paper)
+            self.n_hidden = out_features
+
+        # A shared linear transformation, parametrized by the weight matrix W, is applied to every node
+        # Initialize the weight matrix W
+        self.W = nn.Parameter(torch.empty(size=(in_features, self.n_hidden * n_heads)))
+
+        # Initialize the attention weights a
+        self.a = nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
+
+        self.leakyrelu = nn.LeakyReLU(leaky_relu_slope)     # LeakyReLU activation applied to the raw attention scores
+        self.softmax = nn.Softmax(dim=1)                    # softmax activation for the attention coefficients
+
+        self.reset_parameters()                             # Initialize the parameters W and a
+
+
+    def reset_parameters(self):
+        """
+        Reinitialize learnable parameters.
+        """
+        nn.init.xavier_normal_(self.W)
+        nn.init.xavier_normal_(self.a)
+
+
+    def _get_attention_scores(self, h_transformed: torch.Tensor):
+        """Calculates the attention scores e_ij for all pairs of nodes (i, j) in the graph
+        in vectorized parallel form. For each pair of source and target nodes (i, j),
+        the attention score e_ij is computed as follows:
+
+            e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
+
+        where || denotes the concatenation operation, and a and W are the learnable parameters.
+
+        Args:
+            h_transformed (torch.Tensor): Transformed feature matrix with shape (n_nodes, n_heads, n_hidden),
+                where n_nodes is the number of nodes and n_hidden is the number of hidden features per head.
+
+        Returns:
+            torch.Tensor: Attention score matrix with shape (n_heads, n_nodes, n_nodes).
+        """
+
+        source_scores = torch.matmul(h_transformed, self.a[:, :self.n_hidden, :])
+        target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden:, :])
+
+        # broadcast add
+        # (n_heads, n_nodes, 1) + (n_heads, 1, n_nodes) = (n_heads, n_nodes, n_nodes)
+        e = source_scores + target_scores.mT
+        return self.leakyrelu(e)
+
+    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
+        """
+        Performs a graph attention layer operation.
+
+        Args:
+            h (torch.Tensor): Input tensor representing node features.
+            adj_mat (torch.Tensor): Adjacency matrix representing the graph structure.
+
+        Returns:
+            torch.Tensor: Output tensor after the graph attention operation.
+ """ + n_nodes = h.shape[0] + + # Apply linear transformation to node feature -> W h + # output shape (n_nodes, n_hidden * n_heads) + h_transformed = torch.mm(h, self.W) + h_transformed = F.dropout(h_transformed, self.dropout, training=self.training) + + # splitting the heads by reshaping the tensor and putting heads dim first + # output shape (n_heads, n_nodes, n_hidden) + h_transformed = h_transformed.view(n_nodes, self.n_heads, self.n_hidden).permute(1, 0, 2) + + # getting the attention scores + # output shape (n_heads, n_nodes, n_nodes) + e = self._get_attention_scores(h_transformed) + + # Set the attention score for non-existent edges to -9e15 (MASKING NON-EXISTENT EDGES) + connectivity_mask = -9e16 * torch.ones_like(e) + e = torch.where(adj_mat > 0, e, connectivity_mask) # masked attention scores + + # attention coefficients are computed as a softmax over the rows + # for each column j in the attention score matrix e + attention = F.softmax(e, dim=-1) + attention = F.dropout(attention, self.dropout, training=self.training) + + # final node embeddings are computed as a weighted average of the features of its neighbors + h_prime = torch.matmul(attention, h_transformed) + + # concatenating/averaging the attention heads + # output shape (n_nodes, out_features) + if self.concat: + h_prime = h_prime.permute(1, 0, 2).contiguous().view(n_nodes, self.out_features) + else: + h_prime = h_prime.mean(dim=0) + + return h_prime + +################################ +### MAIN GAT NETWORK MODULE ### +################################ + +class GAT(nn.Module): + """ + Graph Attention Network (GAT) as described in the paper `"Graph Attention Networks" `. + Consists of a 2-layer stack of Graph Attention Layers (GATs). The fist GAT Layer is followed by an ELU activation. + And the second (final) layer is a GAT layer with a single attention head and softmax activation function. + """ + def __init__(self, + in_features, + n_hidden, + n_heads, + num_classes, + concat=False, + dropout=0.4, + leaky_relu_slope=0.2): + """ Initializes the GAT model. + + Args: + in_features (int): number of input features per node. + n_hidden (int): output size of the first Graph Attention Layer. + n_heads (int): number of attention heads in the first Graph Attention Layer. + num_classes (int): number of classes to predict for each node. + concat (bool, optional): Wether to concatinate attention heads or take an average over them for the + output of the first Graph Attention Layer. Defaults to False. + dropout (float, optional): dropout rate. Defaults to 0.4. + leaky_relu_slope (float, optional): alpha (slope) of the leaky relu activation. Defaults to 0.2. + """ + + super(GAT, self).__init__() + + # Define the Graph Attention layers + self.gat1 = GraphAttentionLayer( + in_features=in_features, out_features=n_hidden, n_heads=n_heads, + concat=concat, dropout=dropout, leaky_relu_slope=leaky_relu_slope + ) + + self.gat2 = GraphAttentionLayer( + in_features=n_hidden, out_features=num_classes, n_heads=1, + concat=False, dropout=dropout, leaky_relu_slope=leaky_relu_slope + ) + + + def forward(self, input_tensor: torch.Tensor , adj_mat: torch.Tensor): + """ + Performs a forward pass through the network. + + Args: + input_tensor (torch.Tensor): Input tensor representing node features. + adj_mat (torch.Tensor): Adjacency matrix representing graph structure. + + Returns: + torch.Tensor: Output tensor after the forward pass. 
+ """ + + # Apply the first Graph Attention layer + x = self.gat1(input_tensor, adj_mat) + x = F.elu(x) # Apply ELU activation function to the output of the first layer + + # Apply the second Graph Attention layer + x = self.gat2(x, adj_mat) + + return F.log_softmax(x, dim=1) # Apply log softmax activation function + +################################ +### LOADING THE CORA DATASET ### +################################ + +def load_cora(path='./cora', device='cpu'): + """ + Loads the Cora dataset. The dataset is downloaded from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz. + + """ + + # Set the paths to the data files + content_path = os.path.join(path, 'cora.content') + cites_path = os.path.join(path, 'cora.cites') + + # Load data from files + content_tensor = np.genfromtxt(content_path, dtype=np.dtype(str)) + cites_tensor = np.genfromtxt(cites_path, dtype=np.int32) + + # Process features + features = torch.FloatTensor(content_tensor[:, 1:-1].astype(np.int32)) # Extract feature values + scale_vector = torch.sum(features, dim=1) # Compute sum of features for each node + scale_vector = 1 / scale_vector # Compute reciprocal of the sums + scale_vector[scale_vector == float('inf')] = 0 # Handle division by zero cases + scale_vector = torch.diag(scale_vector).to_sparse() # Convert the scale vector to a sparse diagonal matrix + features = scale_vector @ features # Scale the features using the scale vector + + # Process labels + classes, labels = np.unique(content_tensor[:, -1], return_inverse=True) # Extract unique classes and map labels to indices + labels = torch.LongTensor(labels) # Convert labels to a tensor + + # Process adjacency matrix + idx = content_tensor[:, 0].astype(np.int32) # Extract node indices + idx_map = {id: pos for pos, id in enumerate(idx)} # Create a dictionary to map indices to positions + + # Map node indices to positions in the adjacency matrix + edges = np.array( + list(map(lambda edge: [idx_map[edge[0]], idx_map[edge[1]]], + cites_tensor)), dtype=np.int32) + + V = len(idx) # Number of nodes + E = edges.shape[0] # Number of edges + adj_mat = torch.sparse_coo_tensor(edges.T, torch.ones(E), (V, V), dtype=torch.int64) # Create the initial adjacency matrix as a sparse tensor + adj_mat = torch.eye(V) + adj_mat # Add self-loops to the adjacency matrix + + # return features.to_sparse().to(device), labels.to(device), adj_mat.to_sparse().to(device) + return features.to(device), labels.to(device), adj_mat.to(device) + +################################# +### TRAIN AND TEST FUNCTIONS ### +################################# + +def train_iter(epoch, model, optimizer, criterion, input, target, mask_train, mask_val, print_every=10): + start_t = time.time() + model.train() + optimizer.zero_grad() + + # Forward pass + output = model(*input) + loss = criterion(output[mask_train], target[mask_train]) # Compute the loss using the training mask + + loss.backward() + optimizer.step() + + # Evaluate the model performance on training and validation sets + loss_train, acc_train = test(model, criterion, input, target, mask_train) + loss_val, acc_val = test(model, criterion, input, target, mask_val) + + if epoch % print_every == 0: + # Print the training progress at specified intervals + print(f'Epoch: {epoch:04d} ({(time.time() - start_t):.4f}s) loss_train: {loss_train:.4f} acc_train: {acc_train:.4f} loss_val: {loss_val:.4f} acc_val: {acc_val:.4f}') + + +def test(model, criterion, input, target, mask): + model.eval() + with torch.no_grad(): + output = model(*input) + output, target = 
+
+        loss = criterion(output, target)
+        acc = (output.argmax(dim=1) == target).float().sum() / len(target)
+        return loss.item(), acc.item()
+
+
+if __name__ == '__main__':
+
+    # Training settings
+    # All default values are the same as in the config used in the main paper
+
+    parser = argparse.ArgumentParser(description='PyTorch Graph Attention Network')
+    parser.add_argument('--epochs', type=int, default=300,
+                        help='number of epochs to train (default: 300)')
+    parser.add_argument('--lr', type=float, default=0.005,
+                        help='learning rate (default: 0.005)')
+    parser.add_argument('--l2', type=float, default=5e-4,
+                        help='weight decay (default: 5e-4)')
+    parser.add_argument('--dropout-p', type=float, default=0.6,
+                        help='dropout probability (default: 0.6)')
+    parser.add_argument('--hidden-dim', type=int, default=64,
+                        help='dimension of the hidden representation (default: 64)')
+    parser.add_argument('--num-heads', type=int, default=8,
+                        help='number of the attention heads (default: 8)')
+    parser.add_argument('--concat-heads', action='store_true', default=False,
+                        help='whether to concatenate attention heads, or average over them (default: False)')
+    parser.add_argument('--val-every', type=int, default=20,
+                        help='epochs to wait between printing training and validation evaluation (default: 20)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--no-mps', action='store_true', default=False,
+                        help='disables macOS GPU training')
+    parser.add_argument('--dry-run', action='store_true', default=False,
+                        help='quickly check a single pass')
+    parser.add_argument('--seed', type=int, default=13, metavar='S',
+                        help='random seed (default: 13)')
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+    # Set the device to run on
+    if use_cuda:
+        device = torch.device('cuda')
+    elif use_mps:
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
+    print(f'Using {device} device')
+
+    # Load the dataset
+    cora_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'
+    path = './cora'
+
+    if os.path.isfile(os.path.join(path, 'cora.content')) and os.path.isfile(os.path.join(path, 'cora.cites')):
+        print('Dataset already downloaded...')
+    else:
+        print('Downloading dataset...')
+        with requests.get(cora_url, stream=True) as tgz_file:
+            with tarfile.open(fileobj=tgz_file.raw, mode='r:gz') as tgz_object:
+                tgz_object.extractall()
+
+    print('Loading dataset...')
+    # Load the dataset
+    features, labels, adj_mat = load_cora(device=device)
+
+    # Split the dataset into training, validation, and test sets
+    idx = torch.randperm(len(labels)).to(device)
+    idx_test, idx_val, idx_train = idx[:1200], idx[1200:1600], idx[1600:]
+
+
+    # Create the model
+    # The model consists of a two-layer stack of Graph Attention Layers.
+    gat_net = GAT(
+        in_features=features.shape[1],          # Number of input features per node
+        n_hidden=args.hidden_dim,               # Output size of the first Graph Attention Layer
+        n_heads=args.num_heads,                 # Number of attention heads in the first Graph Attention Layer
+        num_classes=labels.max().item() + 1,    # Number of classes to predict for each node
+        concat=args.concat_heads,               # Whether to concatenate attention heads
+        dropout=args.dropout_p,                 # Dropout rate
+        leaky_relu_slope=0.2                    # Alpha (slope) of the leaky relu activation
+    ).to(device)
+
+    # Configure the optimizer and loss function
+    optimizer = Adam(gat_net.parameters(), lr=args.lr, weight_decay=args.l2)
+    criterion = nn.NLLLoss()
+
+    # Train and evaluate the model
+    for epoch in range(args.epochs):
+        train_iter(epoch + 1, gat_net, optimizer, criterion, (features, adj_mat), labels, idx_train, idx_val, args.val_every)
+        if args.dry_run:
+            break
+
+    loss_test, acc_test = test(gat_net, criterion, (features, adj_mat), labels, idx_test)
+    print(f'Test set results: loss {loss_test:.4f} accuracy {acc_test:.4f}')
\ No newline at end of file
diff --git a/gat/requirements.txt b/gat/requirements.txt
new file mode 100644
index 0000000000..a47e3fee28
--- /dev/null
+++ b/gat/requirements.txt
@@ -0,0 +1,3 @@
+torch
+requests
+numpy
\ No newline at end of file
diff --git a/run_python_examples.sh b/run_python_examples.sh
index 93142527af..1b45a281cf 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -177,6 +177,11 @@ function gcn() {
   python main.py --epochs 1 --dry-run || error "graph convolutional network failed"
 }
 
+function gat() {
+  start
+  python main.py --epochs 1 --dry-run || error "graph attention network failed"
+}
+
 function clean() {
   cd $BASE_DIR
   echo "running clean to remove cruft"
@@ -198,7 +203,8 @@ function clean() {
     time_sequence_prediction/predict*.pdf \
     time_sequence_prediction/traindata.pt \
     word_language_model/model.pt \
-    gcn/cora/ || error "couldn't clean up some files"
+    gcn/cora/ \
+    gat/cora/ || error "couldn't clean up some files"
 
   git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
 }
@@ -224,6 +230,7 @@ function run_all() {
   word_language_model
   fx
   gcn
+  gat
 }
 
 # by default, run all examples