diff --git a/programl/task/graph_level_classification/README.md b/programl/task/graph_level_classification/README.md
new file mode 100644
index 000000000..5bdd6d7e5
--- /dev/null
+++ b/programl/task/graph_level_classification/README.md
@@ -0,0 +1,53 @@
+# Graph Level Classification
+
+Two subtasks are of particular interest: classifyapp (a.k.a. poj104) and devmap (a.k.a. heterogeneous device mapping).
+
+## Quickstart
+`python run.py --help` will print this help:
+```
+Usage:
+    run.py [options]
+
+Options:
+    -h --help                     Show this screen.
+    --data_dir DATA_DIR           Directory(*) of the dataset. (*)=relative to repository root ProGraML/.
+                                  Overrides the per-dataset defaults if provided.
+
+    --log_dir LOG_DIR             Directory(*) in which to store logfiles and trained models, relative to the repository dir.
+                                  [default: programl/task/graph_level_classification/logs/unspecified]
+    --model MODEL                 The model to run.
+    --dataset DATASET             The dataset to use.
+    --config CONFIG               Path(*) to a config json dump with params.
+    --config_json CONFIG_JSON     Config json with params.
+    --restore CHECKPOINT          Path(*) to a model file to restore from.
+    --skip_restore_config         Whether to skip restoring the config from CHECKPOINT.
+    --test                        Test the model without training.
+    --restore_by_pattern PATTERN  Restore the newest model of this name from log_dir and
+                                  continue training. (AULT specific!)
+                                  PATTERN is a string that can be grep'ed for.
+    --kfold                       Run k-fold cross-validation.
+                                  Splits are currently dataset-specific.
+    --transfer MODEL              The model class to transfer to.
+                                  The args specified will be applied to the transferred model to the extent applicable, e.g.
+                                  training params and Readout module specifications, but not to the transferred model trunk.
+                                  However, we strongly recommend making all trunk parameters match, so that transferred
+                                  checkpoints can be restored without having to pass a matching config manually.
+    --transfer_mode MODE          One of frozen, finetune (finetune is not yet implemented). [default: frozen]
+                                  Mode frozen also sets all dropout in the restored model to zero (the newly initialized
+                                  readout function can have dropout nonetheless, depending on the config provided).
+    --skip_save_every_epoch       Skip saving the latest model after every epoch (by default it is saved on a rolling basis).
+```
+An example command could look like this:
+```
+Reproduce the Transformer result for the rebuttal:
+
+python run.py --model transformer_poj104 --dataset poj104 --data_dir ~/rebuttal_datasets/classifyapp/ --log_dir logs/classifyapp_logs/rebuttal_transformer_poj104/ --config_json="{'train_subset': [0, 100], 'batch_size': 48, 'max_num_nodes': 40000, 'num_epochs': 70, 'vocab_size': 2231, 'message_weight_sharing': 2, 'update_weight_sharing': 2, 'lr': 1e-4, 'gnn_layers': 10}"
+```
+NB: `--config_json` takes a double-quoted string of config options in JSON format, except that single quotes may be used inside it (they are parsed as double quotes, turning this almost-JSON format into valid JSON), e.g. `--config_json="{'lr': 1e-4}"`.
+
+## How to reproduce results from the paper?
+
+```
+more run commands / another script that does it for us.
+```
+
diff --git a/programl/task/graph_level_classification/configs.py b/programl/task/graph_level_classification/configs.py
new file mode 100644
index 000000000..48ccc00a2
--- /dev/null
+++ b/programl/task/graph_level_classification/configs.py
@@ -0,0 +1,294 @@
+# Copyright 2019 the ProGraML authors.
+#
+# Contact Chris Cummins <chrisc.101@gmail.com>.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configs""" +from .dataset import AblationVocab + + +class ProGraMLBaseConfig(object): + def __init__(self): + self.name = self.__class__.__name__ + + # Training Hyperparameters + self.num_epochs = 25 + self.batch_size = 128 + # limit the number of nodes per batch to a sensible maximum + # by possibly discarding certain samples from the batch. + self.max_num_nodes = 200000 + self.lr: float = 0.00025 + self.patience = 10000 + self.clip_grad_norm: float = 0.0 + self.train_subset = [0, 100] + self.random_seed: int = 42 + + # Readout + self.output_dropout: float = 0.0 + + # Model Hyperparameters + self.emb_size: int = 200 + self.edge_type_count: int = 3 + + self.vocab_size: int = 8568 + self.cdfg_vocab: bool = False + + # ABLATION OPTIONS + # NONE = 0 No ablation - use the full vocabulary (default). + # NO_VOCAB = 1 Ignore the vocabulary - every node has an x value of 0. + # NODE_TYPE_ONLY = 2 Use a 3-element vocabulary based on the node type: + # 0 - Instruction node + # 1 - Variable node + # 2 - Constant node + self.ablation_vocab: AblationVocab = 0 # 0 NONE, 1 NO_VOCAB, 2 NODE_TYPE_ONLY + + # inst2vec_embeddings can now be 'none' as well! + # this reduces the tokens that the network sees to only + # !IDENTIFIERs and !UNK statements + # One of {zero, constant, random, random_const, finetune, none} + self.inst2vec_embeddings = "random" + + self.ablate_structure = None # one of {control,data,call} + + @classmethod + def from_dict(cls, params): + """instantiate Config from params dict that overrides default values where given.""" + config = cls() + if params is None: + return config + + for key in params: + if hasattr(config, key): + setattr(config, key, params[key]) + else: + print( + f"(*CONFIG FROM DICT* Default {config.name} doesn't have a key {key}. Will add key to config anyway!" + ) + setattr(config, key, params[key]) + return config + + def to_dict(self): + config_dict = { + a: getattr(self, a) + for a in dir(self) + if not a.startswith("__") and not callable(getattr(self, a)) + } + return config_dict + + def check_equal(self, other): + # take either config object or config_dict + other_dict = other if isinstance(other, dict) else other.to_dict() + if not self.to_dict() == other_dict: + print( + f"WARNING: GGNNConfig.check_equal() FAILED:\nself and other are unequal: " + f"The difference is {set(self.to_dict()) ^ set(other.to_dict())}.\n self={self.to_dict()}\n other={other_dict}" + ) + + +class GGNN_POJ104_Config(ProGraMLBaseConfig): + def __init__(self): + super().__init__() + ############### + # Model Hyperparameters + self.gnn_layers: int = 8 + self.message_weight_sharing: int = 2 + self.update_weight_sharing: int = 2 + # self.message_timesteps: List[int] = [2, 2, 2, 2] + # self.update_timesteps: List[int] = [2, 2, 2, 2] + + # currently only admits node types 0 and 1 for statements and identifiers. 
+ self.use_node_types = True + self.use_edge_bias: bool = True + self.position_embeddings: bool = True + + # Aggregate by mean or by sum + self.msg_mean_aggregation: bool = True + self.backward_edges: bool = True + + ############### + # Regularization + self.edge_weight_dropout: float = 0.0 + self.graph_state_dropout: float = 0.2 + + ############### + # Dataset inherent, don't change! + self.num_classes: int = 104 + self.has_graph_labels: bool = True + self.has_aux_input: bool = False + + # self.use_selector_embeddings: bool = False + # self.selector_size: int = 2 if getattr(self, 'use_selector_embeddings', False) else 0 + # TODO(Zach) Maybe refactor non-rectangular edge passing matrices for independent hidden size. + # hidden size of the whole model + self.hidden_size: int = self.emb_size + getattr(self, "selector_size", 0) + + +class GGNN_Devmap_Config(GGNN_POJ104_Config): + def __init__(self): + super().__init__() + # change default + self.batch_size = 64 + self.lr = 2.5e-4 + self.num_epochs = 150 + self.graph_state_dropout = 0.0 + + # Auxiliary Readout + self.aux_use_better = False + self.intermediate_loss_weight = 0.2 + self.aux_in_size = 2 + self.aux_in_layer_size = 32 + self.aux_in_log1p = True + + # Dataset inherent, don't change! + self.num_classes: int = 2 + self.has_graph_labels: bool = True + self.has_aux_input: bool = True + + +class GGNN_Threadcoarsening_Config(GGNN_POJ104_Config): + def __init__(self): + super().__init__() + # Dataset inherent, don't change! + self.num_classes: int = 6 + self.has_graph_labels: bool = True + # self.has_aux_input: bool = False + + +class GGNN_ForPretraining_Config(GGNN_POJ104_Config): + def __init__(self): + super().__init__() + # Pretraining Parameters + self.mlm_probability = 0.15 + self.mlm_statements_only = True + self.mlm_exclude_unk_tokens = True + self.mlm_mask_token_id = 8568 + self.unk_token_id = 8564 + + # set for pretraining to vocab_size + 1 [MASK] + self.vocab_size = self.vocab_size + 1 + self.num_classes = self.vocab_size + self.has_graph_labels: bool = False + + +class GraphTransformer_POJ104_Config(ProGraMLBaseConfig): + def __init__(self): + super().__init__() + ###### borrowed for debugging ########## + + # GGNNMessage Layer + # self.msg_mean_aggregation: bool = True + # self.use_edge_bias: bool = True + + ############### + self.backward_edges: bool = True + self.gnn_layers: int = 8 + self.message_weight_sharing: int = 2 + self.update_weight_sharing: int = 2 + # self.layer_timesteps: List[int] = [1, 1, 1, 1, 1, 1, 1, 1] #[2, 2, 2, 2] + self.use_node_types: bool = False + + # Dataset Specific, don't change! 
+ self.num_classes: int = 104 + self.has_graph_labels: bool = True + self.hidden_size: int = self.emb_size + getattr(self, "selector_size", 0) + + # Message: + self.position_embeddings: bool = True + # Self-Attn Layer + self.attn_bias = True + self.attn_num_heads = 5 # 8 # choose among 4,5,8,10 for emb_sz 200 + self.attn_dropout = 0.1 + self.attn_v_pos = False + + # Update: + + # Transformer Update Layer + self.update_layer: str = "ff" # or 'gru' + self.tfmr_act = "gelu" # relu or gelu, default relu + self.tfmr_dropout = 0.2 # default 0.1 + self.tfmr_ff_sz = 512 # 512 # ~ 2.5 model_dim (Bert: 768 - 2048, Trfm: base 512 - 2048, big 1024 - 4096) + + # Optionally: GGNN Update Layer + # self.update_layer: str = 'gru' # or 'ff' + # self.edge_weight_dropout: float = 0.0 + # self.graph_state_dropout: float = 0.2 + + +class GraphTransformer_Devmap_Config(GraphTransformer_POJ104_Config): + def __init__(self): + super().__init__() + # change default + self.batch_size = 64 + self.lr = 2.5e-4 + self.num_epochs = 600 + # self.graph_state_dropout = 0.0 #GGNN only + + # self.output_dropout # <- applies to Readout func! + + # Auxiliary Readout + self.aux_use_better = False + self.intermediate_loss_weight = 0.2 + self.aux_in_size = 2 + self.aux_in_layer_size = 32 + self.aux_in_log1p = True + + # Dataset inherent, don't change! + self.num_classes: int = 2 + self.has_graph_labels: bool = True + self.has_aux_input: bool = True + + +class GraphTransformer_Threadcoarsening_Config(GraphTransformer_POJ104_Config): + def __init__(self): + super().__init__() + self.lr = 5e-5 # 2.5-4? + self.num_epochs = 600 + # Dataset inherent, don't change! + self.num_classes: int = 6 + self.has_graph_labels: bool = True + # self.has_aux_input: bool = False + + +class GraphTransformer_ForPretraining_Config(GraphTransformer_POJ104_Config): + def __init__(self): + super().__init__() + self.num_of_splits = 2 + # Pretraining Parameters + self.mlm_probability = 0.15 + self.mlm_statements_only = True + self.mlm_exclude_unk_tokens = True + self.mlm_mask_token_id = 8568 + self.unk_token_id = 8564 + + # set for pretraining to vocab_size + 1 [MASK] + self.vocab_size = self.vocab_size + 1 + self.num_classes = self.vocab_size + self.has_graph_labels: bool = False + + +class GGNN_BranchPrediction_Config(GGNN_POJ104_Config): + def __init__(self): + super().__init__() + self.batch_size = 4 + # self.use_tanh_readout = False ! + self.num_classes = 1 + self.has_graph_labels = False + + +class GraphTransformer_BranchPrediction_Config(GraphTransformer_POJ104_Config): + def __init__(self): + super().__init__() + self.batch_size = 4 + # self.use_tanh_readout = False ! + self.num_classes = 1 + self.has_graph_labels = False diff --git a/programl/task/graph_level_classification/dataloader.py b/programl/task/graph_level_classification/dataloader.py new file mode 100644 index 000000000..7f3c4414b --- /dev/null +++ b/programl/task/graph_level_classification/dataloader.py @@ -0,0 +1,123 @@ +import torch.utils.data +from torch._six import container_abcs, int_classes, string_classes +from torch.utils.data.dataloader import default_collate +from torch_geometric.data import Batch, Data + + +class DataLoader(torch.utils.data.DataLoader): + r"""Data loader which merges data objects from a + :class:`torch_geometric.data.dataset` to a mini-batch. + + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. 
+ (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (list or tuple, optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`[]`) + """ + + def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[], **kwargs): + def collate(batch): + elem = batch[0] + if isinstance(elem, Data): + return Batch.from_data_list(batch, follow_batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int_classes): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, container_abcs.Mapping): + return {key: collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(collate(s) for s in zip(*batch))) + elif isinstance(elem, container_abcs.Sequence): + return [collate(s) for s in zip(*batch)] + raise TypeError( + "DataLoader found invalid type: {}".format(type(elem).__name__) + ) + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=lambda batch: collate(batch), + **kwargs, + ) + + +class NodeLimitedDataLoader(torch.utils.data.DataLoader): + r"""Data loader which merges data objects from a + :class:`torch_geometric.data.dataset` to a mini-batch. + + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (list or tuple, optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`[]`) + """ + + def __init__( + self, + dataset, + batch_size=1, + shuffle=False, + follow_batch=[], + max_num_nodes=None, + warn_on_limit=False, + **kwargs, + ): + self.max_num_nodes = max_num_nodes + + def collate(batch): + elem = batch[0] + if isinstance(elem, Data): + # greedily add all samples that fit within self.max_num_nodes + # and silently discard all others + if max_num_nodes is not None: + num_nodes = 0 + limited_batch = [] + for elem in batch: + if num_nodes + elem.num_nodes <= self.max_num_nodes: + limited_batch.append(elem) + num_nodes += elem.num_nodes + else: # for debugging + pass + if len(limited_batch) < len(batch): + if warn_on_limit: + print( + f"dropped {len(batch) - len(limited_batch)} graphs from batch!" + ) + assert ( + limited_batch != [] + ), f"limited batch is empty! 
original batch was {batch}" + return Batch.from_data_list(limited_batch, follow_batch) + else: + return Batch.from_data_list(batch, follow_batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int_classes): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, container_abcs.Mapping): + return {key: collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(collate(s) for s in zip(*batch))) + elif isinstance(elem, container_abcs.Sequence): + return [collate(s) for s in zip(*batch)] + + raise TypeError("DataLoader found invalid type: {}".format(type(elem))) + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=lambda batch: collate(batch), + **kwargs, + ) diff --git a/programl/task/graph_level_classification/dataset.py b/programl/task/graph_level_classification/dataset.py new file mode 100644 index 000000000..7b664a6db --- /dev/null +++ b/programl/task/graph_level_classification/dataset.py @@ -0,0 +1,1408 @@ +# better dataloader +import csv +import enum +import math +import os +import pickle +import subprocess +import sys +from pathlib import Path +from typing import Dict, Optional + +import numpy as np +import pandas as pd +import torch +import tqdm +from sklearn.model_selection import KFold, StratifiedKFold +from torch_geometric.data import Data, InMemoryDataset + +from programl.proto.program_graph_pb2 import ProgramGraph + +# make this file executable from anywhere + +full_path = os.path.realpath(__file__) +# print(full_path) +REPO_ROOT = full_path.rsplit("ProGraML", maxsplit=1)[0] + "ProGraML" +# print(REPO_ROOT) +# insert at 1, 0 is the script path (or '' in REPL) +sys.path.insert(1, REPO_ROOT) +REPO_ROOT = Path(REPO_ROOT) + + +# The vocabulary files used in the dataflow experiments. +PROGRAML_VOCABULARY = REPO_ROOT / "deeplearning/ml4pl/poj104/programl_vocabulary.csv" +CDFG_VOCABULARY = REPO_ROOT / "deeplearning/ml4pl/poj104/cdfg_vocabulary.csv" +assert PROGRAML_VOCABULARY.is_file(), f"File not found: {PROGRAML_VOCABULARY}" +assert CDFG_VOCABULARY.is_file(), f"File not found: {CDFG_VOCABULARY}" + +# The path of the graph2cdfg binary which converts ProGraML graphs to the CDFG +# representation. +# +# To build this file, clone the ProGraML repo and build +# //programl/cmd:graph2cdfg: +# +# 1. git clone https://github.com/ChrisCummins/ProGraML.git +# 2. cd ProGraML +# 3. git checkout 2d93e5e14bf321336f1928d3364e9d7196cee995 +# 4. bazel build -c opt //programl/cmd:graph2cdfg +# 5. cp -v bazel-bin/programl/cmd/graph2cdfg ${THIS_DIR} +# +GRAPH2CDFG = REPO_ROOT / "deeplearning/ml4pl/poj104/graph2cdfg" +assert GRAPH2CDFG.is_file(), f"File not found: {GRAPH2CDFG}" + + +def load(file: str, cdfg: bool = False) -> ProgramGraph: + """Read a ProgramGraph protocol buffer from file. + + Args: + file: The path of the ProgramGraph protocol buffer to load. + cdfg: If true, convert the graph to CDFG during load. 
+ Returns: + graph: the proto of the programl / CDFG graph + orig_graph: the original programl proto (that contains graph level labels) + """ + graph = ProgramGraph() + with open(file, "rb") as f: + proto = f.read() + + if cdfg: + # hotfix missing graph labels in cdfg proto + orig_graph = ProgramGraph() + orig_graph.ParseFromString(proto) + + graph2cdfg = subprocess.Popen( + [str(GRAPH2CDFG), "--stdin_fmt=pb", "--stdout_fmt=pb"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + proto, _ = graph2cdfg.communicate(proto) + assert not graph2cdfg.returncode, f"CDFG conversion failed: {file}" + + graph.ParseFromString(proto) + + if not cdfg: + orig_graph = graph + return graph, orig_graph + + +def load_vocabulary(path: Path): + """Read the vocabulary file used in the dataflow experiments.""" + vocab = {} + with open(path) as f: + vocab_file = csv.reader(f.readlines(), delimiter="\t") + for i, row in enumerate(vocab_file, start=-1): + if i == -1: # Skip the header. + continue + (_, _, _, text) = row + vocab[text] = i + + return vocab + + +class AblationVocab(enum.IntEnum): + # No ablation - use the full vocabulary (default). + NONE = 0 + # Ignore the vocabulary - every node has an x value of 0. + NO_VOCAB = 1 + # Use a 3-element vocabulary based on the node type: + # 0 - Instruction node + # 1 - Variable node + # 2 - Constant node + NODE_TYPE_ONLY = 2 + + +def filename( + split: str, cdfg: bool = False, ablation_vocab: AblationVocab = AblationVocab.NONE +) -> str: + """Generate the name for a data file. + + Args: + split: The name of the split. + cdfg: Whether using CDFG representation. + ablate_vocab: The ablation vocab type. + + Returns: + A file name which uniquely identifies this combination of + split/cdfg/ablation. + """ + name = str(split) + if cdfg: + name = f"{name}_cdfg" + if ablation_vocab != AblationVocab.NONE: + # transform if ablation_vocab was passed as int. + if type(ablation_vocab) == int: + ablation_vocab = AblationVocab(ablation_vocab) + name = f"{name}_{ablation_vocab.name.lower()}" + return f"{name}_data.pt" + + +def nx2data( + graph: ProgramGraph, + vocabulary: Dict[str, int], + y_feature_name: Optional[str] = None, + ignore_profile_info=True, + ablate_vocab=AblationVocab.NONE, + orig_graph: ProgramGraph = None, +): + r"""Converts a program graph protocol buffer to a + :class:`torch_geometric.data.Data` instance. + + Args: + graph A program graph protocol buffer. + vocabulary A map from node text to vocabulary indices. + y_feature_name The name of the graph-level feature to use as class label. + ablate_vocab Whether to use an ablation vocabulary. + orig_graph A program graph protocol buffer that has graph level labels. 
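+        ignore_profile_info Whether to skip collecting per-node branch
+            profile features (mask, true/false/total weights).
+
+    Returns:
+        A :class:`torch_geometric.data.Data` instance with ``x`` (vocabulary
+        index and node type per node), ``edge_index``, and ``edge_attr``
+        (flow and position per edge); ``y`` is set when y_feature_name is
+        given.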
+ """ + + # collect edge_index + edge_tuples = [(edge.source, edge.target) for edge in graph.edge] + edge_index = torch.tensor(edge_tuples).t().contiguous() + + # collect edge_attr + positions = torch.tensor([edge.position for edge in graph.edge]) + flows = torch.tensor([int(edge.flow) for edge in graph.edge]) + + edge_attr = torch.cat([flows, positions]).view(2, -1).t().contiguous() + + # collect x + if ablate_vocab == AblationVocab.NONE: + vocabulary_indices = vocab_ids = [ + vocabulary.get(node.text, len(vocabulary)) for node in graph.node + ] + elif ablate_vocab == AblationVocab.NO_VOCAB: + vocabulary_indices = [0] * len(graph.node) + elif ablate_vocab == AblationVocab.NODE_TYPE_ONLY: + vocabulary_indices = [int(node.type) for node in graph.node] + else: + raise NotImplementedError("unreachable") + + xs = torch.tensor(vocabulary_indices) + types = torch.tensor([int(node.type) for node in graph.node]) + + x = torch.cat([xs, types]).view(2, -1).t().contiguous() + + assert ( + edge_attr.size()[0] == edge_index.size()[1] + ), f"edge_attr={edge_attr.size()} size mismatch with edge_index={edge_index.size()}" + + data_dict = { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + } + + # maybe collect these data too + if y_feature_name is not None: + assert orig_graph is not None, "need orig_graph to retrieve graph level labels!" + y = torch.tensor( + orig_graph.features.feature[y_feature_name].int64_list.value[0] + ).view( + 1 + ) # <1> + if y_feature_name == "poj104_label": + y -= 1 + data_dict["y"] = y + + # branch prediction / profile info specific + if not ignore_profile_info: + raise NotImplementedError( + "profile info is not supported with the new nx2data (from programgraph) adaptation." + ) + profile_info = [] + for i, node_data in nx_graph.nodes(data=True): + # default to -1, -1, -1 if not all profile info is given. + if not ( + node_data.get("llvm_profile_true_weight") is not None + and node_data.get("llvm_profile_false_weight") is not None + and node_data.get("llvm_profile_total_weight") is not None + ): + mask = 0 + true_weight = -1 + false_weight = -1 + total_weight = -1 + else: + mask = 1 + true_weight = node_data["llvm_profile_true_weight"] + false_weight = node_data["llvm_profile_false_weight"] + total_weight = node_data["llvm_profile_total_weight"] + + profile_info.append([mask, true_weight, false_weight, total_weight]) + + data_dict["profile_info"] = torch.tensor(profile_info) + + # make Data + data = Data(**data_dict) + + return data + + +class BranchPredictionDataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/branch_prediction_data", + split="train", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + """ + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + super().__init__(root, transform, pre_transform) + + assert split in [ + "train" + ], "The BranchPrediction dataset only has a 'train' split. use train_subset=[0,x] and [x, 100] for training and testing." 
+ self.data, self.slices = torch.load(self.processed_paths[0]) + pass + + @property + def raw_file_names(self): + """A list of files that need to be found in the raw_dir in order to skip the download""" + return [] # not implemented here + + @property + def processed_file_names(self): + """A list of files in the processed_dir which needs to be found in order to skip the processing.""" + base = f"{self.split}_data.pt" + + if tuple(self.train_subset) == (0, 100) or self.split in ["val", "test"]: + return [base] + else: + assert self.split == "train" + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + """Download raw data to `self.raw_dir`""" + pass # not implemented + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. + """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(30, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def return_cross_validation_splits(self, split): + assert self.train_subset == [ + 0, + 100, + ], "Do cross-validation on the whole dataset!" + # num_samples = len(self) + # perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # 10-fold cross-validation + n_splits = 10 + kf = KFold(n_splits=n_splits, shuffle=True, random_state=42) + (train_index, test_index) = list(kf.split(range(len(self))))[split] + train_data = self.__indexing__(train_index) + test_data = self.__indexing__(test_index) + return train_data, test_data + + def filter_max_num_nodes(self, max_num_nodes): + idx = [] + for i, d in enumerate(self): + if d.num_nodes <= max_num_nodes: + idx.append(i) + dataset = self.__indexing__(idx) + print( + f"Filtering out graphs larger than {max_num_nodes} yields a dataset with {len(dataset)}/{len(self)} samples remaining." + ) + return dataset + + def process(self): + """Processes raw data and saves it into the `processed_dir`. + New implementation: + Here specifically it will collect all '*.ll.pickle' files recursively from subdirectories of `root` + and process the loaded nx graphs to Data. + Old implementation: + Instead of looking for .ll.pickle (nx graphs), we directly look for '*.data.p' files. + """ + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / f"{self.split}_data.pt" + if full_dataset.is_file(): + assert self.split == "train", "here shouldnt be reachable." + print( + f"Full dataset found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + # read data into huge `Data` list. 
+ data_list = [] + + ds_base = Path(self.root) + print(f"Creating {self.split} dataset at {str(ds_base)}") + # TODO change this line to go to the new format + # out_base = ds_base / ('ir_' + self.split + '_programl') + # assert out_base.exists(), f"{out_base} doesn't exist!" + # TODO collect .ll.pickle instead and call nx2data on the fly! + print(f"=== DATASET {str(ds_base)}: Collecting .data.p files into dataset") + + # files = list(ds_base.rglob('*.data.p')) + # files = list(ds_base.rglob('*.ll.pickle')) + files = list(ds_base.rglob("*.ll.p")) + + for file in tqdm.tqdm(files): + if not file.is_file(): + continue + try: + nx_graph = load(file) + except EOFError: + print(f"Failing to unpickle bc. EOFError on {file}! Skipping ...") + continue + try: + data = nx2data(nx_graph, ignore_profile_info=False) + data_list.append(data) + except IndexError: + print( + f"Failing nx2data bc IndexError (prob. empty graph) on {file}! Skipping ..." + ) + continue + + print(f" * COMPLETED * === DATASET {ds_base}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET {ds_base}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + print(f" * COMPLETED * === DATASET {ds_base}: saving to disk...") + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +class NCCDataset(InMemoryDataset): + def __init__( + self, + root=REPO_ROOT / "deeplearning/ml4pl/poj104/ncc_data", + split="train", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + ): + """ + NCC dataset + + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + + """ + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + super().__init__(root, transform, pre_transform) + + assert split in [ + "train" + ], "The NCC dataset only has a 'train' split. use train_subset=[0,x] and [x, 100] for training and testing." + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + """A list of files that need to be found in the raw_dir in order to skip the download""" + return [] # not implemented here + + @property + def processed_file_names(self): + """A list of files in the processed_dir which needs to be found in order to skip the processing.""" + base = f"{self.split}_data.pt" + + if tuple(self.train_subset) == (0, 100) or self.split in ["val", "test"]: + return [base] + else: + assert self.split == "train" + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + """Download raw data to `self.raw_dir`""" + pass # not implemented + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. 
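+
+        E.g., train_subset=[0, 90] keeps a fixed random 90% of the samples,
+        while [90, 100] with the same train_subset_seed yields the
+        complementary 10%, since the permutation is identical.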
+ """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(30, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def filter_max_num_nodes(self, max_num_nodes): + idx = [] + for i, d in enumerate(self): + if d.num_nodes <= max_num_nodes: + idx.append(i) + dataset = self.__indexing__(idx) + print( + f"Filtering out graphs larger than {max_num_nodes} yields a dataset with {len(dataset)}/{len(self)} samples remaining." + ) + return dataset + + def process(self): + """Processes raw data and saves it into the `processed_dir`. + New implementation: + Here specifically it will collect all '*.ll.pickle' files recursively from subdirectories of `root` + and process the loaded nx graphs to Data. + Old implementation: + Instead of looking for .ll.pickle (nx graphs), we directly look for '*.data.p' files. + """ + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / f"{self.split}_data.pt" + if full_dataset.is_file(): + assert self.split == "train", "here shouldnt be reachable." + print( + f"Full dataset found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + # read data into huge `Data` list. + data_list = [] + + ds_base = Path(self.root) + print(f"Creating {self.split} dataset at {str(ds_base)}") + # TODO change this line to go to the new format + # out_base = ds_base / ('ir_' + self.split + '_programl') + # assert out_base.exists(), f"{out_base} doesn't exist!" + # TODO collect .ll.pickle instead and call nx2data on the fly! + print(f"=== DATASET {str(ds_base)}: Collecting .data.p files into dataset") + + # files = list(ds_base.rglob('*.data.p')) + # files = list(ds_base.rglob('*.ll.pickle')) + files = list(ds_base.rglob("*.ll.p")) + + for file in tqdm.tqdm(files): + if not file.is_file(): + continue + try: + nx_graph = load(file) + except EOFError: + print(f"Failing to unpickle bc. EOFError on {file}! Skipping ...") + continue + try: + data = nx2data(nx_graph) + data_list.append(data) + except IndexError: + print( + f"Failing nx2data bc IndexError (prob. empty graph) on {file}! Skipping ..." + ) + continue + + print(f" * COMPLETED * === DATASET {ds_base}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET {ds_base}: Completed filtering, now pre_transforming..." 
+ ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + print(f" * COMPLETED * === DATASET {ds_base}: saving to disk...") + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +class LegacyNCCDataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/unsupervised_ncc_data", + split="train", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + + """ + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + super().__init__(root, transform, pre_transform) + + assert split in [ + "train" + ], "The NCC dataset only has a 'train' split. use train_subset=[0,x] and [x, 100] for training and testing." + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + """A list of files that need to be found in the raw_dir in order to skip the download""" + return [] # not implemented here + + @property + def processed_file_names(self): + """A list of files in the processed_dir which needs to be found in order to skip the processing.""" + base = f"{self.split}_data.pt" + + if tuple(self.train_subset) == (0, 100) or self.split in ["val", "test"]: + return [base] + else: + assert self.split == "train" + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + """Download raw data to `self.raw_dir`""" + pass # not implemented + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. + """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(30, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def filter_max_num_nodes(self, max_num_nodes): + idx = [] + for i, d in enumerate(self): + if d.num_nodes <= max_num_nodes: + idx.append(i) + dataset = self.__indexing__(idx) + print( + f"Filtering out graphs larger than {max_num_nodes} yields a dataset with {len(dataset)}/{len(self)} samples remaining." + ) + return dataset + + def process(self): + """Processes raw data and saves it into the `processed_dir`. + New implementation: + Here specifically it will collect all '*.ll.pickle' files recursively from subdirectories of `root` + and process the loaded nx graphs to Data. + Old implementation: + Instead of looking for .ll.pickle (nx graphs), we directly look for '*.data.p' files. 
+ """ + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / f"{self.split}_data.pt" + if full_dataset.is_file(): + assert self.split == "train", "here shouldnt be reachable." + print( + f"Full dataset found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + # read data into huge `Data` list. + data_list = [] + + ds_base = Path(self.root) + print(f"Creating {self.split} dataset at {str(ds_base)}") + # TODO change this line to go to the new format + # out_base = ds_base / ('ir_' + self.split + '_programl') + # assert out_base.exists(), f"{out_base} doesn't exist!" + # TODO collect .ll.pickle instead and call nx2data on the fly! + print(f"=== DATASET {str(ds_base)}: Collecting .data.p files into dataset") + + files = list(ds_base.rglob("*.data.p")) + for file in tqdm.tqdm(files): + if not file.is_file(): + continue + data = load(file) + data_list.append(data) + + print(f" * COMPLETED * === DATASET {ds_base}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET {ds_base}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + print(f" * COMPLETED * === DATASET {ds_base}: saving to disk...") + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +class ThreadcoarseningDataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/threadcoarsening_data", + split="fail_fast", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + split: 'amd' or 'nvidia' + + """ + assert split in [ + "Cypress", + "Tahiti", + "Fermi", + "Kepler", + ], f"Split is {split}, but has to be 'Cypress', 'Tahiti', 'Fermi', or 'Kepler'" + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + super().__init__(root, transform, pre_transform) + + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return "threadcoarsening_data.zip" + + @property + def processed_file_names(self): + base = f"{self.split}_data.pt" + + if tuple(self.train_subset) == (0, 100): + return [base] + else: + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + # download to self.raw_dir + pass + + def return_cross_validation_splits(self, split): + assert self.train_subset == [ + 0, + 100, + ], "Do cross-validation on the whole dataset!" 
+ assert ( + split <= 16 and split >= 0 + ), f"This dataset shall be 17-fold (leave one out) cross-validated, but split={split}." + # leave one out + n_splits = 17 + train_idx = list(range(n_splits)) + train_idx.remove(split) + train_data = self.__indexing__(train_idx) + test_data = self.__indexing__([split]) + return train_data, test_data + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. + """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(100, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def platform2str(self, platform): + if platform == "Fermi": + return "NVIDIA GTX 480" + elif platform == "Kepler": + return "NVIDIA Tesla K20c" + elif platform == "Cypress": + return "AMD Radeon HD 5900" + elif platform == "Tahiti": + return "AMD Tahiti 7970" + else: + raise LookupError + + def _get_all_runtimes(self, platform, df, oracles): + all_runtimes = {} + for kernel in oracles["kernel"]: + kernel_r = [] + for cf in [1, 2, 4, 8, 16, 32]: + row = df[(df["kernel"] == kernel) & (df["cf"] == cf)] + if len(row) == 1: + value = float(row[f"runtime_{platform}"].values) + if math.isnan(value): + print( + f"WARNING: Dataset contain NaN value (missing entry in runtimes most likely). kernel={kernel}, cf={cf}, value={row}.Replacing by infinity!." + ) + value = float("inf") + kernel_r.append(value) + elif len(row) == 0: + print( + f" kernel={kernel:>20} is missing cf={cf}. Ad-hoc inserting result from last existing coarsening factor." + ) + kernel_r.append(kernel_r[-1]) + else: + raise + all_runtimes[kernel] = kernel_r + return all_runtimes + + def process(self): + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / f"{self.split}_data.pt" + if full_dataset.is_file(): + print( + f"Full dataset {full_dataset.name} found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + root = Path(self.root) + # Load runtime data + oracle_file = root / "pact-2014-oracles.csv" + oracles = pd.read_csv(oracle_file) + + runtimes_file = root / "pact-2014-runtimes.csv" + df = pd.read_csv(runtimes_file) + + print("\tReading data from", oracle_file, "\n\tand", runtimes_file) + + # get all runtime info per kernel + runtimes = self._get_all_runtimes(self.split, df=df, oracles=oracles) + + # get oracle labels + cfs = [1, 2, 4, 8, 16, 32] + y = np.array( + [cfs.index(int(x)) for x in oracles["cf_" + self.split]], dtype=np.int64 + ) + + # sanity check oracles against min runtimes + for i, (k, v) in enumerate(runtimes.items()): + assert int(y[i]) == np.argmin( + v + ), f"{i}: {k} {v}, argmin(v): {np.argmin(v)} vs. 
oracles data {int(y[i])}." + + # Add attributes to graphs + data_list = [] + + kernels = oracles["kernel"].values # list of strings of kernel names + + for kernel in kernels: + # legacy + # file = root / 'kernels_ir_programl' / (kernel + '.data.p') + file = root / "kernels_ir" / (kernel + ".ll.p") + assert file.exists(), f"input file not found: {file}" + # with open(file, 'rb') as f: + # data = pickle.load(f) + g = load(file) + data = nx2data(g) + # add attributes + data["y"] = torch.tensor([np.argmin(runtimes[kernel])], dtype=torch.long) + data["runtimes"] = torch.tensor([runtimes[kernel]]) + data_list.append(data) + + ################################## + + print( + f" * COMPLETED * === DATASET Threadcoarsening-{self.split}: now pre-filtering..." + ) + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET Threadcoarsening-{self.split}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +class DevmapDataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/devmap_data", + split="fail", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + cdfg: bool = False, + ablation_vocab: AblationVocab = AblationVocab.NONE, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + split: 'amd' or 'nvidia' + cdfg: Use CDFG graph representation. + """ + assert split in [ + "amd", + "nvidia", + ], f"Split is {split}, but has to be 'amd' or 'nvidia'" + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + self.cdfg = cdfg + self.ablation_vocab = ablation_vocab + super().__init__(root, transform, pre_transform) + + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return "devmap_data.zip" + + @property + def processed_file_names(self): + base = filename(self.split, self.cdfg, self.ablation_vocab) + + if tuple(self.train_subset) == (0, 100): + return [base] + else: + return [ + f"{name}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + # download to self.raw_dir + pass + + def return_cross_validation_splits(self, split): + assert self.train_subset == [ + 0, + 100, + ], "Do cross-validation on the whole dataset!" + # num_samples = len(self) + # perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # 10-fold cross-validation + n_splits = 10 + kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) + (train_index, test_index) = list(kf.split(self.data.y, self.data.y))[split] + train_data = self.__indexing__(train_index) + test_data = self.__indexing__(test_index) + return train_data, test_data + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. 
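+
+        The subset is written to self.processed_paths[0], which resolves to
+        the subset-specific file name whenever train_subset != [0, 100]
+        (see processed_file_names).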
+ """ + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(100, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def process(self): + # check if we need to create the full dataset: + name = filename(self.split, self.cdfg, self.ablation_vocab) + full_dataset = Path(self.processed_dir) / name + if full_dataset.is_file(): + print( + f"Full dataset {full_dataset.name} found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + vocab = load_vocabulary(CDFG_VOCABULARY if self.cdfg else PROGRAML_VOCABULARY) + assert len(vocab) > 0, "vocab is empty :|" + + root = Path(self.root) + + # Get list of source file names and attributes + input_files = list((root / f"graphs_{self.split}").iterdir()) + + num_files = len(input_files) + print("\n--- Preparing to read", num_files, "input files") + + # read data into huge `Data` list. + + data_list = [] + for i in tqdm.tqdm(range(num_files)): + filename = input_files[i] + + proto, _ = load(filename, cdfg=self.cdfg) + data = nx2data(proto, vocabulary=vocab, ablate_vocab=self.ablation_vocab) + + # graph2cdfg conversion drops the graph features, so we may have to + # reload the graph. + if self.cdfg: + proto = load(filename) + + # Add the features and label. + proto_features = proto.features.feature + data["y"] = torch.tensor( + proto_features["devmap_label"].int64_list.value[0] + ).view(1) + data["aux_in"] = torch.tensor( + [ + proto_features["transfer_bytes"].int64_list.value[0], + proto_features["wgsize"].int64_list.value[0], + ] + ) + + data_list.append(data) + + ################################## + + print(f" * COMPLETED * === DATASET Devmap-{name}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET Devmap-{name}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100): + self._save_train_subset() + + +class POJ104Dataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/classifyapp_data", + split="fail", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + cdfg: bool = False, + ablation_vocab: AblationVocab = AblationVocab.NONE, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. 
+ cdfg: Use the CDFG graph format and vocabulary. + """ + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + self.cdfg = cdfg + self.ablation_vocab = ablation_vocab + super().__init__(root, transform, pre_transform) + + assert split in ["train", "val", "test"] + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return "classifyapp_data.zip" # ['ir_val', 'ir_val_programl'] + + @property + def processed_file_names(self): + base = filename(self.split, self.cdfg, self.ablation_vocab) + + if tuple(self.train_subset) == (0, 100) or self.split in ["val", "test"]: + return [base] + else: + assert self.split == "train" + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + # download to self.raw_dir + pass + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. + """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(100, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def process(self): + # hardcoded + num_classes = 104 + + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / filename( + self.split, self.cdfg, self.ablation_vocab + ) + if full_dataset.is_file(): + assert self.split == "train", "here shouldnt be reachable." + print( + f"Full dataset found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." + ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + # get vocab first + vocab = load_vocabulary(CDFG_VOCABULARY if self.cdfg else PROGRAML_VOCABULARY) + assert len(vocab) > 0, "vocab is empty :|" + # read data into huge `Data` list. + data_list = [] + + ds_base = Path(self.root) + print(f"Creating {self.split} dataset at {str(ds_base)}") + + split_folder = ds_base / (self.split) + assert split_folder.exists(), f"{split_folder} doesn't exist!" + + # collect .pb and call nx2data on the fly! + print( + f"=== DATASET {split_folder}: Collecting ProgramGraph.pb files into dataset" + ) + + # only take files from subfolders (with class names!) recursively + files = [x for x in split_folder.rglob("*ProgramGraph.pb")] + assert len(files) > 0, "no files collected. error." + for file in tqdm.tqdm(files): + # skip classes that are larger than what config says to enable debugging with less data + # class_label = int(file.parent.name) - 1 # let classes start from 0. 
+ # if class_label >= num_classes: + # continue + + g, orig_graph = load(file, cdfg=self.cdfg) + data = nx2data( + graph=g, + vocabulary=vocab, + ablate_vocab=self.ablation_vocab, + y_feature_name="poj104_label", + orig_graph=orig_graph, + ) + data_list.append(data) + + print(f" * COMPLETED * === DATASET {split_folder}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET {split_folder}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +class LegacyPOJ104Dataset(InMemoryDataset): + def __init__( + self, + root="deeplearning/ml4pl/poj104/classifyapp_data", + split="fail", + transform=None, + pre_transform=None, + train_subset=[0, 100], + train_subset_seed=0, + ): + """ + Args: + train_subset: [start_percentile, stop_percentile) default [0,100). + sample a random (but fixed) train set of data in slice by percent, with given seed. + train_subset_seed: seed for the train_subset fixed random permutation. + + """ + self.split = split + self.train_subset = train_subset + self.train_subset_seed = train_subset_seed + super().__init__(root, transform, pre_transform) + + assert split in ["train", "val", "test"] + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return "classifyapp_data.zip" # ['ir_val', 'ir_val_programl'] + + @property + def processed_file_names(self): + base = f"{self.split}_data.pt" + + if tuple(self.train_subset) == (0, 100) or self.split in ["val", "test"]: + return [base] + else: + assert self.split == "train" + return [ + f"{self.split}_data_subset_{self.train_subset[0]}_{self.train_subset[1]}_seed_{self.train_subset_seed}.pt" + ] + + def download(self): + # download to self.raw_dir + pass + + def _save_train_subset(self): + """saves a train_subset of self to file. + Percentile slice is taken according to self.train_subset + with a fixed random permutation with self.train_subset_seed. + """ + import numpy as np + + perm = np.random.RandomState(self.train_subset_seed).permutation(len(self)) + + # take slice of perm according to self.train_subset + start = np.math.floor(len(self) / 100 * self.train_subset[0]) + stop = np.math.floor(len(self) / 100 * self.train_subset[1]) + perm = perm[start:stop] + print(f"Fixed permutation starts with: {perm[:min(100, len(perm))]}") + + dataset = self.__indexing__(perm) + + data, slices = dataset.data, dataset.slices + torch.save((data, slices), self.processed_paths[0]) + return + + def process(self): + # hardcoded + num_classes = 104 + + # check if we need to create the full dataset: + full_dataset = Path(self.processed_dir) / f"{self.split}_data.pt" + if full_dataset.is_file(): + assert self.split == "train", "here shouldnt be reachable." + print( + f"Full dataset found. Generating train_subset={self.train_subset} with seed={self.train_subset_seed}" + ) + # just get the split and save it + self.data, self.slices = torch.load(full_dataset) + self._save_train_subset() + print( + f"Saved train_subset={self.train_subset} with seed={self.train_subset_seed} to disk." 
+ ) + return + + # ~~~~~ we need to create the full dataset ~~~~~~~~~~~ + assert not full_dataset.is_file(), "shouldnt be" + processed_path = str(full_dataset) + + # read data into huge `Data` list. + data_list = [] + + ds_base = Path(self.root) + print(f"Creating {self.split} dataset at {str(ds_base)}") + # TODO change this line to go to the new format + out_base = ds_base / ("ir_" + self.split + "_programl") + assert out_base.exists(), f"{out_base} doesn't exist!" + # TODO collect .ll.pickle instead and call nx2data on the fly! + print(f"=== DATASET {out_base}: Collecting .data.p files into dataset") + + folders = [ + x + for x in out_base.glob("*") + if x.is_dir() and x.name not in ["_nx", "_tuples"] + ] + for folder in tqdm.tqdm(folders): + # skip classes that are larger than what config says to enable debugging with less data + if int(folder.name) > num_classes: + continue + for k, file in enumerate(folder.glob("*.data.p")): + with open(file, "rb") as f: + data = pickle.load(f) + data_list.append(data) + + print(f" * COMPLETED * === DATASET {out_base}: now pre-filtering...") + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + print( + f" * COMPLETED * === DATASET {out_base}: Completed filtering, now pre_transforming..." + ) + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.data, self.slices = self.collate(data_list) + torch.save((self.data, self.slices), processed_path) + + # maybe save train_subset as well + if not tuple(self.train_subset) == (0, 100) and self.split not in [ + "val", + "test", + ]: + self._save_train_subset() + + +if __name__ == "__main__": + # d = NewNCCDataset() + # print(d.data) + root = "/home/zacharias/llvm_datasets/threadcoarsening_data/" + a = ThreadcoarseningDataset(root, "Cypress") + b = ThreadcoarseningDataset(root, "Tahiti") + c = ThreadcoarseningDataset(root, "Fermi") + d = ThreadcoarseningDataset(root, "Kepler") diff --git a/programl/task/graph_level_classification/modeling.py b/programl/task/graph_level_classification/modeling.py new file mode 100644 index 000000000..af076703f --- /dev/null +++ b/programl/task/graph_level_classification/modeling.py @@ -0,0 +1,1550 @@ +# Copyright 2019 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Modules that make up the pytorch GNN models.""" +import math + +import torch +import torch.nn.functional as F +from torch import nn, optim + +# Dependency moved into SelfAttention Message Layer +from torch_geometric.utils import softmax as scatter_softmax + +SMALL_NUMBER = 1e-8 + + +def print_state_dict(mod): + for n, t in mod.state_dict().items(): + print(n, t.size()) + + +def num_parameters(mod) -> int: + """Compute the number of trainable parameters in a nn.Module and its children. + OBS: + This function misses some parameters, i.e. in pytorch's official MultiheadAttention layer, + while the state dict doesn't miss any! 
+ """ + num_params = sum( + param.numel() for param in mod.parameters(recurse=True) if param.requires_grad + ) + return f"{num_params:,} params, weights size: {num_params * 4 / 1e6:.3f}MB." + + +def assert_no_nan(tensor_list): + for i, t in enumerate(tensor_list): + assert not torch.isnan(t).any(), f"{i}: {tensor_list}" + + +################################################ +# Main Model classes +################################################ +class BaseGNNModel(nn.Module): + def __init__(self): + super().__init__() + + def setup(self, config, test_only): + self.loss = Loss(config) + # move model to device before making optimizer! + self.dev = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + + self.to(self.dev) + print(f"Moved model to {self.dev}") + + if test_only: + self.opt = None + self.eval() + else: + self.opt = self.get_optimizer(self.config) + + def get_optimizer(self, config): + return optim.AdamW(self.parameters(), lr=config.lr) + + def num_parameters(self) -> int: + """Compute the number of trainable parameters in this nn.Module and its children.""" + num_params = sum( + param.numel() + for param in self.parameters(recurse=True) + if param.requires_grad + ) + return f"{num_params:,} params, weights size: ~{num_params * 4 // 1e6:,}MB." + + def forward( + self, + vocab_ids, + labels, + edge_lists, + selector_ids=None, + pos_lists=None, + num_graphs=None, + graph_nodes_list=None, + node_types=None, + aux_in=None, + test_time_steps=None, + readout_mask=None, + runtimes=None, + ): + # Input + # selector_ids are ignored anyway by the NodeEmbeddings module that doesn't support them. + raw_in = self.node_embeddings(vocab_ids, selector_ids) + + # GNN + raw_out, raw_in, *unroll_stats = self.gnn( + edge_lists, raw_in, pos_lists, node_types + ) # OBS! self.gnn might change raw_in inplace, so use the two outputs + # instead! + + # Readout + if getattr(self.config, "has_graph_labels", False): + assert ( + graph_nodes_list is not None and num_graphs is not None + ), "has_graph_labels requires graph_nodes_list and num_graphs tensors." + nodewise_readout, graphwise_readout = self.readout( + raw_in, + raw_out, + graph_nodes_list=graph_nodes_list, + num_graphs=num_graphs, + auxiliary_features=aux_in, + readout_mask=readout_mask, + ) + logits = graphwise_readout + else: # nodewise only + nodewise_readout, _ = self.readout( + raw_in, raw_out, readout_mask=readout_mask + ) + graphwise_readout = None + logits = nodewise_readout + + # do the old style aux_readout if not aux_use_better is set + if getattr(self.config, "has_aux_input", False) and not getattr( + self.config, "aux_use_better", False + ): + assert ( + self.config.has_graph_labels is True + ), "Implementation hasn't been checked for use with aux_input and nodewise prediction! It could work or fail silently." + assert aux_in is not None + logits, graphwise_readout = self.aux_readout(logits, aux_in) + + if readout_mask is not None: # need to mask labels in the same fashion. + assert readout_mask.dtype == torch.bool, "Readout mask should be boolean!" 
+ labels = labels[readout_mask] + + # Metrics + # accuracy, correct?, targets, maybe runtimes: actual, optimal + metrics_tuple = self.metrics(logits, labels, runtimes) + + outputs = (logits,) + metrics_tuple + (graphwise_readout,) + tuple(unroll_stats) + + return outputs + + +class GraphTransformerModel(BaseGNNModel): + """Transformer Encoder for Graphs.""" + + def __init__(self, config, pretrained_embeddings=None, test_only=False): + super().__init__() + self.config = config + self.node_embeddings = NodeEmbeddings(config) + self.gnn = GraphTransformerEncoder(config) + + # get readout and maybe tack on the aux readout + self.has_aux_input = getattr(self.config, "has_aux_input", False) + self.aux_use_better = getattr(self.config, "aux_use_better", False) + + if self.has_aux_input and self.aux_use_better: + self.readout = BetterAuxiliaryReadout(config) + elif self.has_aux_input: + self.readout = Readout(config) + self.aux_readout = AuxiliaryReadout(config) + else: + assert not self.aux_use_better, "aux_use_better only with has_aux_input!" + self.readout = Readout(config) + + self.metrics = Metrics() + + self.setup(config, test_only) + print(self) + print( + f"Number of trainable params in GraphTransformerModel: {self.num_parameters()}" + ) + + +class GGNNModel(BaseGNNModel): + def __init__(self, config, pretrained_embeddings=None, test_only=False): + super().__init__() + self.config = config + + # input layer + if getattr(config, "use_selector_embeddings", False): + self.node_embeddings = NodeEmbeddingsWithSelectors( + config, pretrained_embeddings + ) + else: + self.node_embeddings = NodeEmbeddings(config, pretrained_embeddings) + + # Readout layer + # get readout and maybe tack on the aux readout + self.has_aux_input = getattr(self.config, "has_aux_input", False) + self.aux_use_better = getattr(self.config, "aux_use_better", False) + if self.has_aux_input and self.aux_use_better: + self.readout = BetterAuxiliaryReadout(config) + elif self.has_aux_input: + self.readout = Readout(config) + self.aux_readout = AuxiliaryReadout(config) + else: + assert not self.aux_use_better, "aux_use_better only with has_aux_input!" + self.readout = Readout(config) + + # GNN + # make readout available to label_convergence tests in GGNN Proper (at runtime) + self.gnn = GGNNEncoder(config, readout=self.readout) + + # eval and training + self.metrics = Metrics() + + self.setup(config, test_only) + print(self) + print(f"Number of trainable params in GGNNModel: {self.num_parameters()}") + + +################################################ +# GNN Encoder: Message+Aggregate, Update +################################################ + +# GNN Encoder, i.e. everything between input and readout. +# Will rely on the different msg+aggr and update modules to build up a GNN. + + +class GGNNEncoder(nn.Module): + def __init__(self, config, readout=None): + super().__init__() + self.backward_edges = config.backward_edges + + self.gnn_layers = config.gnn_layers + self.message_weight_sharing = config.message_weight_sharing + self.update_weight_sharing = config.update_weight_sharing + message_layers = self.gnn_layers // self.message_weight_sharing + update_layers = self.gnn_layers // self.update_weight_sharing + assert ( + message_layers * self.message_weight_sharing == self.gnn_layers + ), "layer number and reuse mismatch." + assert ( + update_layers * self.update_weight_sharing == self.gnn_layers + ), "layer number and reuse mismatch." 
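+        # Worked example of the weight-sharing scheme (numbers are illustrative):
+        # with gnn_layers=8 and message_weight_sharing=2, only 4 distinct message
+        # modules are instantiated, and unrolled layer i reuses module i // 2 in
+        # the propagation loop of forward():
+        #
+        #   layer:   0   1   2   3   4   5   6   7
+        #   module: m0  m0  m1  m1  m2  m2  m3  m3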
+ # self.layer_timesteps = config.layer_timesteps + + self.position_embeddings = config.position_embeddings + + # optional eval time unrolling parameter + self.test_layer_timesteps = getattr(config, "test_layer_timesteps", 0) + self.unroll_strategy = getattr(config, "unroll_strategy", "none") + self.max_timesteps = getattr(config, "max_timesteps", 1000) + self.label_conv_threshold = getattr(config, "label_conv_threshold", 0.995) + self.label_conv_stable_steps = getattr(config, "label_conv_stable_steps", 1) + + # make readout avalable for label_convergence tests + if self.unroll_strategy == "label_convergence": + assert ( + not self.config.has_aux_input + ), "aux_input is not supported with label_convergence" + assert ( + readout + ), "Gotta pass instantiated readout module for label_convergence tests!" + self.readout = readout + + # Message and update layers + self.message = nn.ModuleList() + # for i in range(len(self.layer_timesteps)):§ + for i in range(message_layers): + self.message.append(GGNNMessageLayer(config)) + + self.update = nn.ModuleList() + # for i in range(len(self.layer_timesteps)): + for i in range(update_layers): + self.update.append(GGNNUpdateLayer(config)) + + def forward( + self, + edge_lists, + node_states, + pos_lists=None, + node_types=None, + test_time_steps=None, + ): + old_node_states = node_states.clone() + + if self.backward_edges: + back_edge_lists = [x.flip([1]) for x in edge_lists] + edge_lists.extend(back_edge_lists) + + # For backward edges we keep the positions of the forward edge! + if self.position_embeddings: + pos_lists.extend(pos_lists) + + # we allow for some fancy unrolling strategies. + # Currently only at eval time, but there is really no good reason for this. + assert ( + self.unroll_strategy == "none" + ), "New layer_timesteps not implemented for this unroll_strategy." + # if self.training or self.unroll_strategy == "none": + # #layer_timesteps = + # #layer_timesteps = self.layer_timesteps + # elif self.unroll_strategy == "constant": + # layer_timesteps = self.test_layer_timesteps + # elif self.unroll_strategy == "edge_count": + # assert ( + # test_time_steps is not None + # ), f"You need to pass test_time_steps or not use unroll_strategy '{self.unroll_strategy}''" + # layer_timesteps = [min(test_time_steps, self.max_timesteps)] + # elif self.unroll_strategy == "data_flow_max_steps": + # assert ( + # test_time_steps is not None + # ), f"You need to pass test_time_steps or not use unroll_strategy '{self.unroll_strategy}''" + # layer_timesteps = [min(test_time_steps, self.max_timesteps)] + # elif self.unroll_strategy == "label_convergence": + # node_states, unroll_steps, converged = self.label_convergence_forward( + # edge_lists, node_states, pos_lists, node_types, initial_node_states=old_node_states + # ) + # return node_states, old_node_states, unroll_steps, converged + # else: + # raise TypeError( + # "Unreachable! 
" + # f"Unroll strategy: {self.unroll_strategy}, training: {self.training}" + # ) + + # for (layer_idx, num_timesteps) in enumerate(layer_timesteps): + # for t in range(num_timesteps): + # messages = self.message[layer_idx](edge_lists, node_states, pos_lists) + # node_states = self.update[layer_idx](messages, node_states, node_types) + + for i in range(self.gnn_layers): + m_idx = i // self.message_weight_sharing + u_idx = i // self.update_weight_sharing + messages = self.message[m_idx](edge_lists, node_states, pos_lists) + node_states = self.update[u_idx](messages, node_states, node_types) + return node_states, old_node_states + + def label_convergence_forward( + self, edge_lists, node_states, pos_lists, node_types, initial_node_states + ): + assert ( + len(self.layer_timesteps) == 1 + ), f"Label convergence only supports one-layer GGNNs, but {len(self.layer_timesteps)} are configured in layer_timesteps: {self.layer_timesteps}" + + stable_steps, i = 0, 0 + old_tentative_labels = self.tentative_labels(initial_node_states, node_states) + + while True: + messages = self.message[0](edge_lists, node_states, pos_lists) + node_states = self.update[0](messages, node_states, node_types) + new_tentative_labels = self.tentative_labels( + initial_node_states, node_states + ) + i += 1 + + # return the new node states if their predictions match the old node states' predictions. + # It doesn't matter during testing since the predictions are the same anyway. + stability = ( + (new_tentative_labels == old_tentative_labels) + .to(dtype=torch.get_default_dtype()) + .mean() + ) + if stability >= self.label_conv_threshold: + stable_steps += 1 + + if stable_steps >= self.label_conv_stable_steps: + return node_states, i, True + + if i >= self.max_timesteps: # maybe escape + return node_states, i, False + + old_tentative_labels = new_tentative_labels + + raise ValueError("Serious Design Error: Unreachable code!") + + def tentative_labels(self, initial_node_states, node_states): + logits, _ = self.readout(initial_node_states, node_states) + preds = F.softmax(logits, dim=1) + predicted_labels = torch.argmax(preds, dim=1) + return predicted_labels + + +class GraphTransformerEncoder(nn.Module): + def __init__(self, config, readout=None): + super().__init__() + self.backward_edges = config.backward_edges + + self.gnn_layers = config.gnn_layers + self.message_weight_sharing = config.message_weight_sharing + self.update_weight_sharing = config.update_weight_sharing + message_layers = self.gnn_layers // self.message_weight_sharing + update_layers = self.gnn_layers // self.update_weight_sharing + assert ( + message_layers * self.message_weight_sharing == self.gnn_layers + ), "layer number and reuse mismatch." + assert ( + update_layers * self.update_weight_sharing == self.gnn_layers + ), "layer number and reuse mismatch." 
+ # self.layer_timesteps = config.layer_timesteps + + self.use_node_types = getattr(config, "use_node_types", False) + assert not self.use_node_types, "not implemented" + + # Position Embeddings + if getattr(config, "position_embeddings", False): + self.selector_size = getattr(config, "selector_size", 0) + self.emb_size = config.emb_size + # we are going to lookup the pos embs only once per batch instead of within every message layer + self.position_embs = PositionEmbeddings() + # self.register_buffer("position_embs", + # PositionEmbeddings()( + # torch.arange(512, dtype=torch.get_default_dtype()), + # config.emb_size, + # dpad=getattr(config, 'selector_size', 0), + # ), + # ) + else: + self.position_embs = None + + # Message and update layers + self.message = nn.ModuleList() + # for i in range(len(self.layer_timesteps)): + for i in range(message_layers): + self.message.append(TypedSelfAttentionMessageLayer(config)) + + update_layer = getattr(config, "update_layer", "ff") + if update_layer == "ff": + UpdateLayer = TransformerUpdateLayer + elif update_layer == "gru": + UpdateLayer = GGNNUpdateLayer + else: + raise ValueError("config.update_layer has to be 'gru' or 'ff'!") + + self.update = nn.ModuleList() + # for i in range(len(self.layer_timesteps)): + for i in range(update_layers): + self.update.append(UpdateLayer(config)) + + def forward( + self, + edge_lists, + node_states, + pos_lists=None, + node_types=None, + test_time_steps=None, + ): + old_node_states = node_states.clone() + + # gather position embeddings for each edge + pos_emb_lists = None + if getattr(self, "position_embs") is not None: + pos_emb_lists = [] + for i, pl in enumerate(pos_lists): + # p_emb = torch.index_select(self.position_embs, dim=0, index=pl) + p_emb = self.position_embs( + pl.to(dtype=torch.get_default_dtype()), + self.emb_size, + dpad=self.selector_size, + ) + pos_emb_lists.append(p_emb) + + # Prepare for backward edges + if self.backward_edges: + back_edge_lists = [x.flip([1]) for x in edge_lists] + edge_lists.extend(back_edge_lists) + + # For backward edges we keep the positions of the forward edge! + if pos_emb_lists: + pos_emb_lists.extend(pos_emb_lists) + assert len(pos_emb_lists) == len(edge_lists) + + # Actual work + for i in range(self.gnn_layers): + m_idx = i // self.message_weight_sharing + u_idx = i // self.update_weight_sharing + messages = self.message[m_idx](edge_lists, node_states, pos_emb_lists) + node_states = self.update[u_idx](messages, node_states, node_types) + return node_states, old_node_states + + +###### Message Layers + + +class SelfAttentionMessageLayer(nn.Module): + """Implements transformer scaled dot-product self-attention, cf. Vaswani et al. 2017, + in a sparse setting on a graph. This reduces the time and space complexity + from O(N^2 * D) to O(M * D), which is much better if the graph has an average degree + that is << O(n), e.g. M \in O(n) instead of O(n^2)! + + NB: + The layer shares the weight-layout with pytorch's dense implementation, + i.e. makes them interoperable. + + Position information must be added to the node_states beforehand, + just like in the original. + + Args: + edge_lists list of edge_index tensors of shape + node_states + edges alternatively a single edge_index <2, M>! + (<2, M> is the torch geometric format!) 
+ Returns: + attn_out: messages + attn_weights: optionally the attention weights + """ + + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + + self.bias = config.attn_bias + self.num_heads = config.attn_num_heads + self.dropout_p = config.attn_dropout + + head_dim = self.embed_dim // self.num_heads + assert ( + head_dim * self.num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + # projection from input to q, k, v + # Myle Ott et al. apparently observed that initializing the qkv_projection (in one matrix) with xavier_uni and gain 1/sqrt(2) to be much better than 1. + self.qkv_in_proj = LinearNet( + self.embed_dim, self.embed_dim * 3, bias=self.bias, gain=1 / math.sqrt(2) + ) + self.out_proj = LinearNet(self.embed_dim, self.embed_dim, bias=self.bias) + self.dropout = nn.Dropout(p=self.dropout_p, inplace=True) + + def forward( + self, + edge_lists=None, + node_states=None, + pos_lists=None, + edges=None, + need_weights=False, + ): + """NB: pos_lists are ignored.""" + + # Glue Code: + assert node_states is not None + + # since we don't support edge-types in this layer, we just concatenate them here. + if edge_lists is not None: + assert edges is None + edges = torch.cat(edge_lists, dim=0).t() # t()! + else: + assert edges is not None + edge_sources = edges[0, :] + edge_targets = edges[1, :] + + # ~~~ Sparse Self-Attention ~~~ + # The implementation follows the official pytorch implementation, but sparse. + # Legend: + # Model hidden size D, + # number of attention heads h, + # number of edges M, + # number of nodes N + num_nodes, embed_dim = node_states.size() + assert embed_dim == self.embed_dim + + head_dim = embed_dim // self.num_heads + assert ( + head_dim * self.num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # 1) get Q, K, V from node_states + # (needs to be merged with step 2 if we want to use positions..., bc + # they need to be added before retrieving Q, K, V) + + q, k, v = self.qkv_in_proj(node_states).chunk(3, dim=1) + + # 2) get Q', K', V' \in by doing an F.emb lookup on Q, K, V (maybe transposed) + # according to index + # edge_target for Q, and + # edge_sources for K, V + # since the receiver of messages is querying her neighbors. + q_prime = torch.index_select(q, dim=0, index=edge_targets) + k_prime = torch.index_select(k, dim=0, index=edge_sources) + v_prime = torch.index_select(v, dim=0, index=edge_sources) + + messages, attn_weights = self.sparse_attn_forward( + q_prime, k_prime, v_prime, num_nodes, edge_targets, need_weights + ) + if need_weights: + return messages, attn_weights + return messages + + def sparse_attn_forward( + self, q_prime, k_prime, v_prime, num_nodes, edge_targets, need_weights + ): + """Differently to dense self-attention, we expect q', k', v', + which are the query, key and value projected node_states [+pos embs] + index_selected by edge_targets, edge_sources and edge_source. 
+ Args: + q_prime, k_prime, v_prime: + num_nodes: int(N) + edge_targets: + Returns: + attn_out: messages + attn_out_weights: optionally: + """ + embed_dim = q_prime.size()[1] + head_dim = embed_dim // self.num_heads + + # some checks + assert ( + head_dim * self.num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + assert ( + q_prime.size() == k_prime.size() + ), f"q_prime, k_prime size mismatch: {q_prime.size()}, {k_prime.size()}" + assert ( + q_prime.size()[0] == v_prime.size()[0] + ), "number of queries and values mismatch" + + # ~~~ Sparse Self-Attention ~~~ + # The implementation follows the official pytorch implementation, but sparse. + # Legend: + # Model hidden size D, + # number of attention heads h, + # number of edges M, + # number of nodes N + + # 3) Q' * K' (hadamard) and sum over D dimension, + # then scaled by sqrt(D) + # 3*) multi-head: If we want h multiple heads, then we should only sum the h segments of size D//h here. + # We will end up with unnormalized attention scores. + scores_prime = q_prime * k_prime + # sum segments of head_dim size into num_head chunks + scores = ( + scores_prime.transpose(0, 1) + .view(self.num_heads, head_dim, -1) + .sum(dim=1) + .t() + .contiguous() + ) + scaling = float(head_dim) ** -0.5 + scores = scores * scaling + assert scores.size() == (q_prime.size()[0], self.num_heads) # + + # 4) Scattered Softmax: + # Perform a softmax by normalizing scores with the sum of those scores + # where edge_targets coincide (meaning incoming edges to the same target are normalized) + # we end up with normalized self-attention scores + # 4*) multi-head: here we run the scattered_softmax in parallel over the h dimensions independently. + + # + attn_output_weights = scatter_softmax( + scores, index=edge_targets, num_nodes=num_nodes + ) # noqa: F821 + attn_output_weights = self.dropout(attn_output_weights) + + # 5) V' * %4: weight values V' by attention. + # The result up to here are the messages traveling across edges. + # 5* a) multi-head: get a view of V' with dim D_v split into + # then get back the old view + v_prime = v_prime.transpose(0, 1) + v_prime = v_prime.view(self.num_heads, head_dim, -1) + v_prime = v_prime.permute(2, 0, 1) # v_prime now: + + attn_out_per_edge = v_prime * attn_output_weights.unsqueeze(2) + attn_out_per_edge = attn_out_per_edge.view(-1, embed_dim) + + # 6) Scatter Add: aggregate messages via index_add with index edge_target + # to end up with + # messages + attn_out = torch.zeros( + num_nodes, embed_dim, dtype=torch.get_default_dtype(), device=q_prime.device + ) + attn_out.index_add_(0, edge_targets, attn_out_per_edge) + + # 5* b) Additionally project from the concatenation back to D. cf. vaswani et al. 2017 + attn_out = self.out_proj(attn_out) + + # now we have messages_by_targets! finally... + + if need_weights: + # average attention weights over heads (sorted like the edges) + attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads + return attn_out, attn_output_weights + return attn_out, None + + +class TypedSelfAttentionMessageLayer(SelfAttentionMessageLayer): + """Implements transformer scaled dot-product self-attention, cf. Vaswani et al. 2017, + in a sparse setting on a graph. This reduces the time and space complexity + from O(N^2 * D) to O(M * D), which is much better if the graph has an average degree + that is << O(n), e.g. M \in O(n) instead of O(n^2)! 
+ + Graph Neural Network adaptations: + The layer supports different edge_types: + Each edge type gets their own k, v projection, but queries are shared. + The layer supports embedding edge-position information: + The position embedding is added to the attention keys only or + optionally both to k and v, but not to q. + + Forward Args: + edge_lists: list of edge_index tensors of size + node_states: + pos_lists: OBS: We expect these to be pos_emb_lists each of size + need_weights: optionally return avg attention weights per edge of size + Returns: + incoming messages per node of shape + """ + + def __init__(self, config): + # init as a module without running parent __init__. + nn.Module.__init__(self) + + self.edge_type_count = ( + config.edge_type_count * 2 + if config.backward_edges + else config.edge_type_count + ) + self.embed_dim = config.hidden_size + self.bias = config.attn_bias + self.num_heads = config.attn_num_heads + self.dropout_p = config.attn_dropout + + head_dim = self.embed_dim // self.num_heads + assert ( + head_dim * self.num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.position_embs = getattr(config, "position_embeddings", False) + self.attn_v_pos = getattr(config, "attn_v_pos", False) + if not self.position_embs: + assert ( + not self.attn_v_pos + ), "Use position_embeddings if you want attn_v_pos!" + + # projection from input to q, k, v + # Myle Ott et al. apparently observed that initializing the qkv_projection (in one matrix) + # with + # nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + # to be much better than only xavier. + # self.qkv_in_proj = LinearNet(self.embed_dim, self.embed_dim * 3, bias=self.bias, gain=1 / math.sqrt(2)) + + # in projection per edge type. + self.q_proj = LinearNet(self.embed_dim, self.embed_dim, bias=self.bias) + self.k_proj = nn.ModuleList() + self.v_proj = nn.ModuleList() + for i in range(self.edge_type_count): + self.k_proj.append( + LinearNet(self.embed_dim, self.embed_dim, bias=self.bias) + ) + self.v_proj.append( + LinearNet(self.embed_dim, self.embed_dim, bias=self.bias) + ) + + self.out_proj = LinearNet(self.embed_dim, self.embed_dim, bias=self.bias) + self.dropout = nn.Dropout(p=self.dropout_p, inplace=True) + + def forward(self, edge_lists, node_states, pos_lists=None, need_weights=False): + """Args: + edge_lists: list of edge_index tensors of size + node_states: + pos_lists: OBS: We expect these to be pos_emb_lists each of size + need_weights: optionally return avg attention weights per edge of size + """ + assert len(edge_lists) == self.edge_type_count + + num_nodes, embed_dim = node_states.size() + assert embed_dim == self.embed_dim + + head_dim = embed_dim // self.num_heads + assert ( + head_dim * self.num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # 1) get Q', K', V' \in from node_states + # by index_select according to index + # edge_target for Q', and + # edge_sources for K', V' + # since the receiver of messages is querying her neighbors. + # 2) Optionally add pos_embs to K', V' + # 2) Then project from node_states into the q,k,v subspaces + + q = self.q_proj(node_states) + + q_primes, k_primes, v_primes, edge_targets_list = [], [], [], [] + + # carefully obtain keys and values and collect queries. 
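+        # Gather sketch for a single edge type i (values are illustrative):
+        #
+        #   el = torch.tensor([[0, 1], [2, 1]])    # two edges: 0 -> 1 and 2 -> 1
+        #   edge_sources, edge_targets = el[:, 0], el[:, 1]
+        #   q_prime = q[edge_targets]              # shared queries of the receivers
+        #   k_prime = self.k_proj[i](node_states[edge_sources])   # typed keys
+        #   v_prime = self.v_proj[i](node_states[edge_sources])   # typed values
+        #
+        # i.e. keys and values are projected per edge type, queries are shared.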
+ for i, el in enumerate(edge_lists): + edge_sources = el[:, 0] # el + edge_targets = el[:, 1] + edge_targets_list.append(edge_targets) + + q_prime = torch.index_select(q, dim=0, index=edge_targets) + + selected_nodes = torch.index_select(node_states, dim=0, index=edge_sources) + # maybe add position embeddings + if self.position_embs and self.attn_v_pos: + selected_nodes = selected_nodes + pos_lists[i] + v_prime = self.v_proj[i](selected_nodes) + k_prime = self.k_proj[i](selected_nodes) + elif self.position_embs: # but not on v + v_prime = self.v_proj[i](selected_nodes) + selected_nodes = selected_nodes + pos_lists[i] + k_prime = self.k_proj[i](selected_nodes) + else: + v_prime = self.v_proj[i](selected_nodes) + k_prime = self.k_proj[i](selected_nodes) + + q_primes.append(q_prime) + k_primes.append(k_prime) + v_primes.append(v_prime) + + edge_targets = torch.cat(edge_targets_list, dim=0) + q_prime = torch.cat(q_primes, dim=0) + k_prime = torch.cat(k_primes, dim=0) + v_prime = torch.cat(v_primes, dim=0) + + # ~~~~ From here, we are back in the general sparse self-attention setting ~~~~~ + messages, attn_weights = self.sparse_attn_forward( + q_prime, k_prime, v_prime, num_nodes, edge_targets, need_weights + ) + if need_weights: + return messages, attn_weights + return messages + + +class GGNNMessageLayer(nn.Module): + """Implements the MLP message function of the GGNN architecture, + optionally with position information embedded on edges. + Args: + edge_lists (for each edge type) + node_states + pos_lists (optionally) + Returns: + incoming messages per node of shape """ + + def __init__(self, config): + super().__init__() + self.edge_type_count = ( + config.edge_type_count * 2 + if config.backward_edges + else config.edge_type_count + ) + self.msg_mean_aggregation = config.msg_mean_aggregation + self.dim = config.hidden_size + + self.transform = LinearNet( + self.dim, + self.dim * self.edge_type_count, + bias=config.use_edge_bias, + dropout=config.edge_weight_dropout, + ) + + self.pos_transform = None + if getattr(config, "position_embeddings", False): + self.selector_size = getattr(config, "selector_size", 0) + self.emb_size = config.emb_size + self.position_embs = PositionEmbeddings() + + # legacy + # self.register_buffer( + # "position_embs", + # PositionEmbeddings()( + # torch.arange(512, dtype=torch.get_default_dtype()), + # config.emb_size, + # dpad=getattr(config, 'selector_size', 0), + # ), + # ) + self.pos_transform = LinearNet( + self.dim, + self.dim, + bias=config.use_edge_bias, + dropout=config.edge_weight_dropout, + ) + + def forward(self, edge_lists, node_states, pos_lists=None): + """edge_lists: [, ...]""" + + # all edge types are handled in one matrix, but we + # let propagated_states[i] be equal to the case with only edge_type i + # propagated_states = ( + # self.transform(node_states) + # .transpose(0, 1) + # .view(self.edge_type_count, self.dim, -1) + # ) + propagated_states = self.transform(node_states).chunk( + self.edge_type_count, dim=1 + ) + + messages_by_targets = torch.zeros_like(node_states) + if self.msg_mean_aggregation: + device = node_states.device + bincount = torch.zeros( + node_states.size()[0], dtype=torch.long, device=device + ) + + for i, edge_list in enumerate(edge_lists): + edge_targets = edge_list[:, 1] + edge_sources = edge_list[:, 0] + + # messages_by_source = F.embedding( + # edge_sources, propagated_states[i].transpose(0, 1) + # ) + messages_by_source = torch.index_select( + propagated_states[i], dim=0, index=edge_sources + ) + + if 
self.pos_transform:
+                pos_list = pos_lists[i]
+                # torch.index_select(pos_gating, dim=0, index=pos_list)
+                pos_by_source = self.position_embs(
+                    pos_list.to(dtype=torch.get_default_dtype()),
+                    self.emb_size,
+                    dpad=self.selector_size,
+                )
+
+                pos_gating_by_source = 2 * torch.sigmoid(
+                    self.pos_transform(pos_by_source)
+                )
+
+                # messages_by_source.mul_(pos_by_source)
+                messages_by_source = messages_by_source * pos_gating_by_source
+
+            messages_by_targets.index_add_(0, edge_targets, messages_by_source)
+
+            if self.msg_mean_aggregation:
+                bins = edge_targets.bincount(minlength=node_states.size()[0])
+                bincount += bins
+
+        if self.msg_mean_aggregation:
+            divisor = bincount.float()
+            divisor[bincount == 0] = 1.0  # avoid div by zero for lonely nodes
+            # guard the division with SMALL_NUMBER instead of adding it to the result
+            messages_by_targets = messages_by_targets / (
+                divisor.unsqueeze_(1) + SMALL_NUMBER
+            )
+
+        return messages_by_targets
+
+
+class PositionEmbeddings(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, positions, demb, dpad: int = 0):
+        """Transformer-like sinusoidal positional embeddings.
+        Args:
+            positions: 1d Tensor of positions (float dtype),
+            demb: int size of embedding vector
+        """
+        inv_freq = 1 / (
+            10000 ** (torch.arange(0.0, demb, 2.0, device=positions.device) / demb)
+        )
+
+        sinusoid_inp = torch.ger(positions, inv_freq)
+        pos_emb = torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+
+        if dpad > 0:
+            in_length = positions.size()[0]
+            # create the padding on the same device as the positions!
+            pad = torch.zeros((in_length, dpad), device=positions.device)
+            pos_emb = torch.cat([pos_emb, pad], dim=1)
+            assert torch.all(
+                pos_emb[:, -1] == 0
+            ), f"test failed. pos_emb: \n{pos_emb}"
+
+        return pos_emb
+
+    # def forward(self, positions, dim, out):
+    #     assert dim > 0, f'dim of position embs has to be > 0'
+    #     power = 2 * (positions / 2) / dim
+    #     position_enc = np.array(
+    #         [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
+    #          for pos in range(n_pos)])
+    #     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+    #     out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+    #     out.detach_()
+    #     out.requires_grad = False
+
+
+####### Update Layers
+
+
+class GGNNUpdateLayer(nn.Module):
+    """GRU update function of GGNN architecture, optionally distinguishing two kinds of node types.
+    Args:
+        incoming messages <N, D> (from message layer),
+        node_states <N, D>,
+        node_types <N> (optional)
+    Returns:
+        updated node_states <N, D>
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.dropout = config.graph_state_dropout
+        # TODO(github.com/ChrisCummins/ProGraML/issues/27): Maybe decouple hidden
+        # GRU size: make hidden GRU size larger and EdgeTrafo size non-square
+        # instead? Or implement stacking gru layers between message passing steps.
+
+        self.gru = nn.GRUCell(
+            input_size=config.hidden_size, hidden_size=config.hidden_size
+        )
+
+        # currently only admits node types 0 and 1 for statements and identifiers.
+        self.use_node_types = getattr(config, "use_node_types", False)
+        if self.use_node_types:
+            self.id_gru = nn.GRUCell(
+                input_size=config.hidden_size, hidden_size=config.hidden_size
+            )
+
+    def forward(self, messages, node_states, node_types=None):
+        if self.use_node_types:
+            assert (
+                node_types is not None
+            ), "Need to provide node_types if config.use_node_types!"
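+            # Dispatch sketch (illustrative): with node_types = tensor([0, 1, 0]),
+            # statement nodes (type 0) are updated by self.gru and identifier
+            # nodes (type 1) by self.id_gru:
+            #
+            #   stmt_mask = node_types == 0   # tensor([ True, False,  True])
+            #   id_mask   = node_types == 1   # tensor([False,  True, False])
+            #
+            # and the two partial updates are scattered back into `output` below.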
+ output = torch.zeros_like(node_states, device=node_states.device) + stmt_mask = node_types == 0 + output[stmt_mask] = self.gru(messages[stmt_mask], node_states[stmt_mask]) + id_mask = node_types == 1 + output[id_mask] = self.id_gru(messages[id_mask], node_states[id_mask]) + else: + output = self.gru(messages, node_states) + + if self.dropout > 0.0: + F.dropout(output, p=self.dropout, training=self.training, inplace=True) + return output + + +class TransformerUpdateLayer(nn.Module): + """Represents the residual MLP around the self-attention in the transformer + encoder layer. The implementation is sparse for usage in GNNs. + + Args: + messages (from self-attention layer) + node_states + node_types (optional and not yet implemented!) + Returns: + updated node_states + """ + + def __init__(self, config): + super().__init__() + self.use_node_types = getattr(config, "use_node_types", False) + assert not self.use_node_types, "not implemented" + + activation = config.tfmr_act # relu or gelu, default relu + dropout = config.tfmr_dropout # default 0.1 + dim_feedforward = config.tfmr_ff_sz # ~ 2.5 * model dim + + # Implementation of Feedforward model + self.linear1 = nn.Linear(config.hidden_size, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, config.hidden_size) + + self.norm1 = nn.LayerNorm(config.hidden_size) + self.norm2 = nn.LayerNorm(config.hidden_size) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = self.get_activation_fn(activation) + + def get_activation_fn(self, activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + else: + raise RuntimeError("activation should be relu/gelu, not %s." % activation) + + def forward(self, messages, node_states, node_types=None): + + # message layer is elsewhere! + # messages = self.self_attn(src, src, src)[0] + + # 1st 'Add & Norm' block (cf. vaswani et al. 2017, fig. 1) + node_states = node_states + self.dropout1(messages) + node_states = self.norm1(node_states) + + # 'Feed Forward' block + messages = self.linear2( + self.dropout(self.activation(self.linear1(node_states))) + ) + + # 2nd 'Add & Norm' block + node_states = node_states + self.dropout2(messages) + node_states = self.norm2(node_states) + + return node_states + + +######################################## +# GNN Output Layers +######################################## + + +class Readout(nn.Module): + """aka GatedRegression. See Eq. 4 in Gilmer et al. 2017 MPNN.""" + + def __init__(self, config): + super().__init__() + self.has_graph_labels = config.has_graph_labels + self.num_classes = config.num_classes + self.use_tanh_readout = getattr(config, "use_tanh_readout", False) + + self.regression_gate = LinearNet( + 2 * config.hidden_size, + self.num_classes, + dropout=config.output_dropout, + ) + self.regression_transform = LinearNet( + config.hidden_size, + self.num_classes, + dropout=config.output_dropout, + ) + + def forward( + self, + raw_node_in, + raw_node_out, + graph_nodes_list=None, + num_graphs=None, + auxiliary_features=None, + readout_mask=None, + ): + if readout_mask is not None: + # mask first to only process the stuff that goes into the loss function! 
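+            # The gated readout below follows Eq. 4 in Gilmer et al. 2017:
+            #
+            #   r_v = sigmoid(gate([h_v^T, h_v^0])) * transform(h_v^T)
+            #
+            # with `gate` = self.regression_gate on the concatenated final and
+            # initial node states and `transform` = self.regression_transform.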
+ raw_node_in = raw_node_in[readout_mask] + raw_node_out = raw_node_out[readout_mask] + if graph_nodes_list is not None: + graph_nodes_list = graph_nodes_list[readout_mask] + + gate_input = torch.cat((raw_node_in, raw_node_out), dim=-1) + gating = torch.sigmoid(self.regression_gate(gate_input)) + if not self.use_tanh_readout: + nodewise_readout = gating * self.regression_transform(raw_node_out) + else: + nodewise_readout = gating * torch.tanh( + self.regression_transform(raw_node_out) + ) + + graph_readout = None + if self.has_graph_labels: + assert ( + graph_nodes_list is not None and num_graphs is not None + ), "has_graph_labels requires graph_nodes_list and num_graphs tensors." + # aggregate via sums over graphs + device = raw_node_out.device + graph_readout = torch.zeros(num_graphs, self.num_classes, device=device) + graph_readout.index_add_( + dim=0, index=graph_nodes_list, source=nodewise_readout + ) + if self.use_tanh_readout: + graph_readout = torch.tanh(graph_readout) + return nodewise_readout, graph_readout + + +class LinearNet(nn.Module): + """Single Linear layer with WeightDropout, ReLU and Xavier Uniform + initialization. Applies a linear transformation to the incoming data: + :math:`y = xA^T + b` + + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of + additional dimensions and :math:`H_{in} = \text{in\_features}` + - Output: :math:`(N, *, H_{out})` where all but the last dimension + are the same shape as the input and :math:`H_{out} = \text{out\_features}`. + """ + + def __init__(self, in_features, out_features, bias=True, dropout=0.0, gain=1.0): + super().__init__() + self.dropout = dropout + self.in_features = in_features + self.out_features = out_features + self.gain = gain + self.test = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.test, gain=self.gain) + if self.bias is not None: + # fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + # bound = 1 / math.sqrt(fan_in) + # nn.init.uniform_(self.bias, -bound, bound) + nn.init.zeros_(self.bias) + + def forward(self, input): + if self.dropout > 0.0: + w = F.dropout(self.test, p=self.dropout, training=self.training) + else: + w = self.test + return F.linear(input, w, self.bias) + + def extra_repr(self): + return "in_features={}, out_features={}, bias={}, dropout={}".format( + self.in_features, + self.out_features, + self.bias is not None, + self.dropout, + ) + + +########################################### +# Mixing in graph-level features to readout + + +class AuxiliaryReadout(nn.Module): + """Produces per-graph predictions by combining + the per-graph predictions with auxiliary features. 
+ Note that this AuxiliaryReadout after Readout is probably a bad idea + and BetterAuxiliaryReadout should be used instead.""" + + def __init__(self, config): + super().__init__() + self.num_classes = config.num_classes + self.aux_in_log1p = getattr(config, "aux_in_log1p", False) + assert ( + config.has_graph_labels + ), "We expect aux readout in combination with graph labels, not node labels" + self.feed_forward = None + + self.batch_norm = nn.BatchNorm1d(config.num_classes + config.aux_in_size) + self.feed_forward = nn.Sequential( + nn.Linear( + config.num_classes + config.aux_in_size, + config.aux_in_layer_size, + ), + nn.ReLU(), + nn.Dropout(config.output_dropout), + nn.Linear(config.aux_in_layer_size, config.num_classes), + ) + + def forward(self, graph_features, auxiliary_features): + assert ( + graph_features.size()[0] == auxiliary_features.size()[0] + ), "every graph needs aux_features. Dimension mismatch." + if self.aux_in_log1p: + auxiliary_features.log1p_() + + aggregate_features = torch.cat((graph_features, auxiliary_features), dim=1) + + normed_features = self.batch_norm(aggregate_features) + out = self.feed_forward(normed_features) + return out, graph_features + + +class BetterAuxiliaryReadout(nn.Module): + """Produces per-graph predictions by combining + the raw GNN Encoder output with auxiliary features. + The difference to AuxReadout(Readout()) is that the aux info + is concat'ed before the nodewise readout and not after the + reduction to graphwise predictions. + """ + + def __init__(self, config): + super().__init__() + + self.aux_in_log1p = getattr(config, "aux_in_log1p", False) + assert ( + config.has_graph_labels + ), "We expect aux readout in combination with graph labels, not node labels" + + self.has_graph_labels = config.has_graph_labels + self.num_classes = config.num_classes + + # now with aux_in concat'ed and batchnorm + self.regression_gate = nn.Sequential( + nn.BatchNorm1d(2 * config.hidden_size + config.aux_in_size), + LinearNet( + 2 * config.hidden_size + config.aux_in_size, + self.num_classes, + dropout=config.output_dropout, + ), + ) + # now with aux_in concat'ed and with intermediate layer + self.regression_transform = nn.Sequential( + nn.BatchNorm1d(config.hidden_size + config.aux_in_size), + LinearNet( + config.hidden_size + config.aux_in_size, + config.aux_in_layer_size, + dropout=config.output_dropout, + ), + nn.ReLU(), + LinearNet(config.aux_in_layer_size, config.num_classes), + ) + + def forward( + self, + raw_node_in, + raw_node_out, + graph_nodes_list, + num_graphs, + auxiliary_features, + readout_mask=None, + ): + assert ( + graph_nodes_list is not None and auxiliary_features is not None + ), "need those" + if readout_mask is not None: + # mask first to only process the stuff that goes into the loss function! + raw_node_in = raw_node_in[readout_mask] + raw_node_out = raw_node_out[readout_mask] + if graph_nodes_list is not None: + graph_nodes_list = graph_nodes_list[readout_mask] + + if self.aux_in_log1p: + auxiliary_features.log1p_() + aux_by_node = torch.index_select( + auxiliary_features, dim=0, index=graph_nodes_list + ) + + # info: the gate and regression include batch norm inside! 
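+        # Broadcast sketch (illustrative): with two graphs and
+        # graph_nodes_list = tensor([0, 0, 1]), the index_select above repeats
+        # each graph's auxiliary vector once per node:
+        #
+        #   auxiliary_features = torch.tensor([[1.0], [2.0]])  # one row per graph
+        #   aux_by_node = auxiliary_features[graph_nodes_list] # [[1.], [1.], [2.]]
+        #
+        # so every node is concatenated with its own graph's auxiliary features.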
+ gate_input = torch.cat((raw_node_in, raw_node_out, aux_by_node), dim=-1) + gating = torch.sigmoid(self.regression_gate(gate_input)) + trafo_input = torch.cat((raw_node_out, aux_by_node), dim=-1) + nodewise_readout = gating * self.regression_transform(trafo_input) + + graph_readout = None + if self.has_graph_labels: + assert ( + graph_nodes_list is not None and num_graphs is not None + ), "has_graph_labels requires graph_nodes_list and num_graphs tensors." + # aggregate via sums over graphs + device = raw_node_out.device + graph_readout = torch.zeros(num_graphs, self.num_classes, device=device) + graph_readout.index_add_( + dim=0, index=graph_nodes_list, source=nodewise_readout + ) + return nodewise_readout, graph_readout + + +############################ +# GNN Input: Embedding Layers +############################ +# class NodeEmbeddingsForPretraining(nn.Module): +# """NodeEmbeddings with added embedding for [MASK] token.""" +# +# def __init__(self, config): +# super().__init__() +# +# print("Initializing with random embeddings for pretraining.") +# self.node_embs = nn.Embedding(config.vocab_size + 1, config.emb_size) +# +# def forward(self, vocab_ids): +# embs = self.node_embs(vocab_ids) +# return embs + + +class NodeEmbeddings(nn.Module): + """Construct node embeddings from node ids + Args: + pretrained_embeddings (Tensor, optional) – FloatTensor containing weights for + the Embedding. First dimension is being passed to Embedding as + num_embeddings, second as embedding_dim. + + Forward + Args: + vocab_ids: + Returns: + node_states: + """ + + # TODO(github.com/ChrisCummins/ProGraML/issues/27):: Maybe LayerNorm and + # Dropout on node_embeddings? + # TODO(github.com/ChrisCummins/ProGraML/issues/27):: Make selector embs + # trainable? + + def __init__(self, config, pretrained_embeddings=None): + super().__init__() + self.inst2vec_embeddings = config.inst2vec_embeddings + self.emb_size = config.emb_size + + if config.inst2vec_embeddings == "constant": + print("Using pre-trained inst2vec embeddings frozen.") + assert pretrained_embeddings is not None + assert ( + pretrained_embeddings.size()[0] == 8568 + ), "Wrong number of embs; don't come here with MLM models!" + self.node_embs = nn.Embedding.from_pretrained( + pretrained_embeddings, freeze=True + ) + elif config.inst2vec_embeddings == "zero": + init = torch.zeros(config.vocab_size, config.emb_size) + self.node_embs = nn.Embedding.from_pretrained(init, freeze=True) + elif config.inst2vec_embeddings == "constant_random": + init = torch.rand(config.vocab_size, config.emb_size) + self.node_embs = nn.Embedding.from_pretrained(init, freeze=True) + elif config.inst2vec_embeddings == "finetune": + print("Fine-tuning inst2vec embeddings") + assert pretrained_embeddings is not None + assert ( + pretrained_embeddings.size()[0] == 8568 + ), "Wrong number of embs; don't come here with MLM models!" 
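+            # freeze=False keeps the pre-trained weights but leaves them
+            # trainable, e.g. (illustrative):
+            #
+            #   embs = nn.Embedding.from_pretrained(torch.randn(8568, 200), freeze=False)
+            #   embs.weight.requires_grad  # -> True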
+ self.node_embs = nn.Embedding.from_pretrained( + pretrained_embeddings, freeze=False + ) + elif config.inst2vec_embeddings == "random": + print("Initializing with random embeddings") + self.node_embs = nn.Embedding(config.vocab_size, config.emb_size) + elif config.inst2vec_embeddings == "none": + print("Initializing with a embedding for statements and identifiers each.") + self.node_embs = nn.Embedding(2, config.emb_size) + else: + raise NotImplementedError(config.inst2vec_embeddings) + + def forward(self, vocab_ids, *ignored_args, **ignored_kwargs): + if self.inst2vec_embeddings == "none": + # map IDs to 1 and everything else to 0 + ids = (vocab_ids == 8565).to(torch.long) # !IDENTIFIER token id + embs = self.node_embs(ids) + else: # normal embeddings + embs = self.node_embs(vocab_ids) + + return embs + + +class NodeEmbeddingsWithSelectors(NodeEmbeddings): + """Construct node embeddings as content embeddings + selector embeddings. + + Args: + pretrained_embeddings (Tensor, optional) – FloatTensor containing weights for + the Embedding. First dimension is being passed to Embedding as + num_embeddings, second as embedding_dim. + + Forward + Args: + vocab_ids: + selector_ids: + Returns: + node_states: + """ + + def __init__(self, config, pretrained_embeddings=None): + super().__init__(config, pretrained_embeddings) + + self.node_embs = super().forward + assert ( + config.use_selector_embeddings + ), "This Module is for use with use_selector_embeddings!" + + selector_init = torch.tensor( + # TODO(github.com/ChrisCummins/ProGraML/issues/27): x50 is maybe a + # problem for unrolling (for selector_embs)? + [[0, 50.0], [50.0, 0]], + dtype=torch.get_default_dtype(), + ) + self.selector_embs = nn.Embedding.from_pretrained(selector_init, freeze=True) + + def forward(self, vocab_ids, selector_ids): + node_embs = self.node_embs(vocab_ids) + selector_embs = self.selector_embs(selector_ids) + embs = torch.cat((node_embs, selector_embs), dim=1) + return embs + + +############################# +# Loss Accuracy Prediction +############################# + + +class Loss(nn.Module): + """Cross Entropy loss with weighted intermediate loss, and + L2 loss if num_classes is just 1. + """ + + def __init__(self, config): + super().__init__() + self.config = config + if config.num_classes == 1: + # self.loss = nn.BCEWithLogitsLoss() # in: (N, *), target: (N, *) + self.loss = nn.MSELoss() + # self.loss = nn.L1Loss() + else: + # class labels '-1' don't contribute to the gradient! + # however in most cases it will be more efficient to gather + # the relevant data into a dense tensor + self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction="mean") + # loss = F.nll_loss( + # F.log_softmax(logits, dim=-1, dtype=torch.float32), + # targets, + # reduction='mean', + # ignore_index=-1, + # ) + + def forward(self, logits, targets): + """inputs: (logits) or (logits, intermediate_logits)""" + if self.config.num_classes == 1: + l = torch.sigmoid(logits[0]) + logits = (l, logits[1]) + loss = self.loss(logits[0].squeeze(dim=1), targets) + if getattr(self.config, "has_aux_input", False): + loss = loss + self.config.intermediate_loss_weight * self.loss( + logits[1], targets + ) + return loss + + +class Metrics(nn.Module): + """Common metrics and info for inspection of results. 
+    Args:
+        logits, labels
+    Returns:
+        (accuracy, pred_targets, correct_preds, targets)"""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, logits, labels, runtimes=None):
+        # be flexible with 1hot labels vs indices
+        if len(labels.size()) == 2:
+            targets = labels.argmax(dim=1)
+        elif len(labels.size()) == 1:
+            targets = labels
+        else:
+            raise ValueError(
+                f"labels={labels.size()} tensor is neither 1 nor 2-dimensional. :/"
+            )
+
+        pred_targets = logits.argmax(dim=1)
+        correct_preds = targets.eq(pred_targets).float()
+        accuracy = torch.mean(correct_preds)
+
+        ret = accuracy, correct_preds, targets
+
+        if runtimes is not None:
+            assert runtimes.size() == logits.size(), (
+                f"We need to have a runtime for each sample and every possible label! "
+                f"runtimes={runtimes.size()}, logits={logits.size()}."
+            )
+            # actual = torch.index_select(runtimes, dim=1, index=pred_targets)
+            actual = torch.gather(
+                runtimes, dim=1, index=pred_targets.view(-1, 1)
+            ).squeeze()
+            # actual = runtimes[:, pred_targets]
+            optimal = torch.gather(runtimes, dim=1, index=targets.view(-1, 1)).squeeze()
+            # optimal = runtimes[:, targets]
+            ret += (actual, optimal)
+
+        return ret
+
+
+# Huggingface implementation
+# perplexity = torch.exp(torch.tensor(eval_loss)), where loss is just the average
diff --git a/programl/task/graph_level_classification/run.py b/programl/task/graph_level_classification/run.py
new file mode 100644
index 000000000..1369fc944
--- /dev/null
+++ b/programl/task/graph_level_classification/run.py
@@ -0,0 +1,1067 @@
+# TODO: decide on default log dir in docstring below.
+"""
+Usage:
+    run.py [options]
+
+Options:
+    -h --help                      Show this screen.
+    --data_dir DATA_DIR            Directory(*) of the dataset. (*)=relative to repository root ProGraML/.
+                                   Will overwrite the per-dataset defaults if provided.
+
+    --log_dir LOG_DIR              Directory(*) to store logfiles and trained models relative to repository dir.
+                                   [default: programl/task/graph_level_classification/logs/unspecified]
+    --model MODEL                  The model to run.
+    --dataset DATASET              The dataset to use.
+    --config CONFIG                Path(*) to a config json dump with params.
+    --config_json CONFIG_JSON      Config json with params.
+    --restore CHECKPOINT           Path(*) to a model file to restore from.
+    --skip_restore_config          Whether to skip restoring the config from CHECKPOINT.
+    --test                         Test the model without training.
+    --restore_by_pattern PATTERN   Restore newest model of this name from log_dir and
+                                   continue training. (AULT specific!)
+                                   PATTERN is a string that can be grep'ed for.
+    --kfold                        Run kfold cross-validation if set.
+                                   Splits are currently dataset specific.
+    --transfer MODEL               The model-class to transfer to.
+                                   The args specified will be applied to the transferred model to the extent applicable, e.g.
+                                   training params and Readout module specifications, but not to the transferred model trunk.
+                                   However, we strongly recommend making all trunk-parameters match, in order to be able
+                                   to restore from transferred checkpoints without having to pass a matching config manually.
+    --transfer_mode MODE           One of frozen, finetune (but not yet implemented) [default: frozen]
+                                   Mode frozen also sets all dropout in the restored model to zero (the newly initialized
+                                   readout function can have dropout nonetheless, depending on the config provided).
+    --skip_save_every_epoch        Skip saving the latest model after every epoch (otherwise it is saved on a rolling basis).
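+
+Example (hypothetical invocation; the model and dataset names are keys of
+MODEL_CLASSES and DATASET_CLASSES below):
+
+    python run.py --model ggnn_devmap --dataset devmap_amd --kfold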
+""" + + +import json +import os +import sys +import time +from pathlib import Path + +import numpy as np +import torch +import tqdm +from docopt import docopt +from torch_geometric.data import DataLoader # (see below) + +# make this file executable from anywhere +full_path = os.path.realpath(__file__) +print(full_path) +REPO_ROOT = full_path.rsplit("ProGraML", maxsplit=1)[0] + "ProGraML" +print(REPO_ROOT) +# insert at 1, 0 is the script path (or '' in REPL) +sys.path.insert(1, REPO_ROOT) +REPO_ROOT = Path(REPO_ROOT) + +# Importing twice like this enables restoring +from . import configs, modeling +from .configs import ( + GGNN_BranchPrediction_Config, + GGNN_Devmap_Config, + GGNN_ForPretraining_Config, + GGNN_POJ104_Config, + GGNN_Threadcoarsening_Config, + GraphTransformer_BranchPrediction_Config, + GraphTransformer_Devmap_Config, + GraphTransformer_ForPretraining_Config, + GraphTransformer_POJ104_Config, + GraphTransformer_Threadcoarsening_Config, + ProGraMLBaseConfig, +) +from .dataloader import NodeLimitedDataLoader +from .dataset import ( + BranchPredictionDataset, + DevmapDataset, + NCCDataset, + POJ104Dataset, + ThreadcoarseningDataset, +) +from .modeling import GGNNModel, GraphTransformerModel + +# Slurm gives us among others: SLURM_JOBID, SLURM_JOB_NAME, +# SLURM_JOB_DEPENDENCY (set to the value of the --dependency option) +if os.environ.get("SLURM_JOBID"): + print("SLURM_JOB_NAME", os.environ.get("SLURM_JOB_NAME", "")) + print("SLURM_JOBID", os.environ.get("SLURM_JOBID", "")) + RUN_ID = "_".join( + [os.environ.get("SLURM_JOB_NAME", ""), os.environ.get("SLURM_JOBID")] + ) +else: + RUN_ID = str(os.getpid()) + + +MODEL_CLASSES = { + "ggnn_poj104": (GGNNModel, GGNN_POJ104_Config), + "ggnn_devmap": (GGNNModel, GGNN_Devmap_Config), + "ggnn_threadcoarsening": (GGNNModel, GGNN_Threadcoarsening_Config), + "ggnn_branch_prediction": (GGNNModel, GGNN_BranchPrediction_Config), + "ggnn_pretraining": (GGNNModel, GGNN_ForPretraining_Config), + "transformer_poj104": (GraphTransformerModel, GraphTransformer_POJ104_Config), + "transformer_devmap": (GraphTransformerModel, GraphTransformer_Devmap_Config), + "transformer_threadcoarsening": ( + GraphTransformerModel, + GraphTransformer_Threadcoarsening_Config, + ), + "transformer_branch_prediction": ( + GraphTransformerModel, + GraphTransformer_BranchPrediction_Config, + ), + "transformer_pretraining": ( + GraphTransformerModel, + GraphTransformer_ForPretraining_Config, + ), +} + +DATASET_CLASSES = { # DS, default data_dir, + "poj104": (POJ104Dataset, "deeplearning/ml4pl/poj104/classifyapp_data"), + "ncc": (NCCDataset, "deeplearning/ml4pl/poj104/ncc_data"), + "devmap_amd": (DevmapDataset, "deeplearning/ml4pl/poj104/devmap_data"), + "devmap_nvidia": (DevmapDataset, "deeplearning/ml4pl/poj104/devmap_data"), + "threadcoarsening_Cypress": ( + ThreadcoarseningDataset, + "deeplearning/ml4pl/poj104/threadcoarsening_data", + ), + "threadcoarsening_Tahiti": ( + ThreadcoarseningDataset, + "deeplearning/ml4pl/poj104/threadcoarsening_data", + ), + "threadcoarsening_Fermi": ( + ThreadcoarseningDataset, + "deeplearning/ml4pl/poj104/threadcoarsening_data", + ), + "threadcoarsening_Kepler": ( + ThreadcoarseningDataset, + "deeplearning/ml4pl/poj104/threadcoarsening_data", + ), + "branch_prediction": ( + BranchPredictionDataset, + "deeplearning/ml4pl/poj104/branch_prediction_data", + ), +} + +DEBUG = False +if DEBUG: + torch.autograd.set_detect_anomaly(True) + + +class Learner(object): + def __init__(self, model, dataset, args=None, current_kfold_split=None): + # Make 
the Learner class usable without the file being run as __main__.
+        self.args = docopt(__doc__, argv=[])
+        if args:
+            self.args.update(args)
+
+        # prepare logging
+        self.parent_run_id = None  # for restored models
+        self.run_id = f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{RUN_ID}"
+        if self.args["--kfold"]:
+            self.run_id += f"_{current_kfold_split}"
+
+        log_dir = REPO_ROOT / self.args.get("--log_dir", ".")
+        log_dir.mkdir(parents=True, exist_ok=True)
+        self.log_file = log_dir / f"{self.run_id}_log.json"
+        self.best_model_file = log_dir / f"{self.run_id}_model_best.pickle"
+        self.last_model_file = log_dir / f"{self.run_id}_model_last.pickle"
+
+        # ~~~~~~~~~~ load model ~~~~~~~~~~~~~
+        if self.args.get("--restore"):
+            self.model = self.restore_model(path=REPO_ROOT / self.args["--restore"])
+        elif self.args.get("--restore_by_pattern"):
+            self.model = self.restore_by_pattern(
+                pattern=self.args["--restore_by_pattern"],
+                log_dir=log_dir,
+                current_kfold_split=current_kfold_split,
+            )
+        else:  # initialize fresh model
+            # get model and dataset
+            assert model, "Need to provide --model to initialize freshly."
+            Model, Config = MODEL_CLASSES[model]
+
+            self.global_training_step = 0
+            self.current_epoch = 1
+
+            # get config
+            params = self.parse_config_params(self.args)
+            self.config = Config.from_dict(params=params)
+
+            test_only = self.args.get("--test", False)
+            self.model = Model(self.config, test_only=test_only)
+
+        # set seeds; NB: the NN on CUDA is partially non-deterministic!
+        torch.manual_seed(self.config.random_seed)
+        np.random.seed(self.config.random_seed)
+
+        # ~~~~~~~~~~ transfer model ~~~~~~~~
+        if self.args["--transfer"] is not None:
+            self.transfer_model(self.args["--transfer"], self.args["--transfer_mode"])
+
+        # ~~~~~~~~~~ load data ~~~~~~~~~~~~~
+        self.load_data(dataset, self.args["--kfold"], current_kfold_split)
+
+        # log config to file
+        config_dict = self.config.to_dict()
+        with open(log_dir / f"{self.run_id}_params.json", "w") as f:
+            json.dump(config_dict, f)
+
+        # log parent run to file if run was restored
+        if self.parent_run_id:
+            with open(log_dir / f"{self.run_id}_parent.json", "w") as f:
+                json.dump(
+                    {
+                        "parent": self.parent_run_id,
+                        "self": self.run_id,
+                        "self_config": config_dict,
+                    },
+                    f,
+                )
+
+        print(
+            "Run %s starting with the following parameters:\n%s"
+            % (self.run_id, json.dumps(config_dict))
+        )
+
+    def load_data(self, dataset, kfold, current_kfold_split):
+        """Set self.train_data, self.valid_data and self.test_data depending on the dataset used."""
+        if not kfold:
+            assert current_kfold_split is None
+        if "_" in dataset:
+            split = dataset.rsplit("_", maxsplit=1)[-1]
+        Dataset, data_dir = DATASET_CLASSES[dataset]
+        if self.args.get("--data_dir"):
+            self.data_dir = REPO_ROOT / self.args["--data_dir"]
+        else:
+            self.data_dir = REPO_ROOT / data_dir
+
+        # Switch cases by dataset
+        # ~~~~~~~~~~ NCC ~~~~~~~~~~~~~~~~~~~~~
+        if dataset == "ncc":
+            # train set
+            if not self.args.get("--test"):
+                # take train_subset=[90, 100] as validation data
+                if self.config.train_subset == [0, 100]:
+                    print(f"!!!!!!!! WARNING !!!!!!!!!!!!")
+                    print(f"SETTING TRAIN_SUBSET FROM [0, 100] TO [0, 90]")
+                    print(f"!!!!!!!! 
WARNING !!!!!!!!!!!!") + self.config.train_subset = [0, 90] + train_dataset = Dataset( + root=self.data_dir, + split="train", + train_subset=self.config.train_subset, + ) + train_dataset = train_dataset.filter_max_num_nodes( + self.config.max_num_nodes + ) + self.train_data = NodeLimitedDataLoader( + train_dataset, + batch_size=self.config.batch_size, + shuffle=True, + max_num_nodes=self.config.max_num_nodes, + warn_on_limit=True, + ) + # valid set (and test set) + valid_dataset = Dataset( + root=self.data_dir, split="train", train_subset=[90, 100] + ) + valid_dataset = valid_dataset.filter_max_num_nodes( + self.config.max_num_nodes + ) + self.valid_data = DataLoader( + valid_dataset, batch_size=self.config.batch_size * 2, shuffle=False + ) + self.test_data = None + # ~~~~~~~~~~ POJ 104 ~~~~~~~~~~~~~~~~~~~~~ + elif dataset == "poj104": + if not self.args.get("--test"): + train_dataset = Dataset( + root=self.data_dir, + split="train", + train_subset=self.config.train_subset, + cdfg=self.config.cdfg_vocab, + ablation_vocab=self.config.ablation_vocab, + ) + self.train_data = DataLoader( + train_dataset, + batch_size=self.config.batch_size, + shuffle=True, + # max_num_nodes=self.config.max_num_nodes + ) + + self.valid_data = DataLoader( + Dataset( + root=self.data_dir, + split="val", + cdfg=self.config.cdfg_vocab, + ablation_vocab=self.config.ablation_vocab, + ), + batch_size=self.config.batch_size * 2, + shuffle=False, + ) + self.test_data = DataLoader( + Dataset( + root=self.data_dir, + split="test", + cdfg=self.config.cdfg_vocab, + ablation_vocab=self.config.ablation_vocab, + ), + batch_size=self.config.batch_size * 2, + shuffle=False, + ) + + # ~~~~~~~~~~ DEVMAP ~~~~~~~~~~~~~~~~~~~~~ + elif dataset in ["devmap_amd", "devmap_nvidia"]: + assert ( + kfold and current_kfold_split is not None + ), "Devmap only supported with kfold flag!" + assert current_kfold_split < 10 + # get the whole dataset then get the correct split + ds = Dataset( + root=self.data_dir, + split=split, + train_subset=self.config.train_subset, + cdfg=self.config.cdfg, + ablation_vocab=self.config.ablation_vocab, + ) + train_dataset, valid_dataset = ds.return_cross_validation_splits( + current_kfold_split + ) + + self.train_data = None + self.valid_data = DataLoader( + valid_dataset, batch_size=self.config.batch_size * 2, shuffle=False + ) + + # only maybe set train_data. + if not self.args.get("--test"): + self.train_data = DataLoader( + train_dataset, + batch_size=self.config.batch_size, + shuffle=True, + ) + self.test_data = None + + # ~~~~~~~~~~ THREADCOARSENING ~~~~~~~~~~~~~~~~~~~~~ + elif dataset in [ + "threadcoarsening" + "_" + s + for s in ["Cypress", "Tahiti", "Fermi", "Kepler"] + ]: + assert ( + kfold and current_kfold_split is not None + ), "Threadcoarsening only supported with kfold flag!" + assert current_kfold_split < 17 and current_kfold_split >= 0 + if not self.args.get("--test"): + pass + # get the whole dataset then get the correct split + ds = Dataset( + root=self.data_dir, split=split, train_subset=self.config.train_subset + ) + train_dataset, valid_dataset = ds.return_cross_validation_splits( + current_kfold_split + ) + + self.train_data = None + self.valid_data = DataLoader( + valid_dataset, batch_size=self.config.batch_size * 2, shuffle=False + ) + + # only maybe set train_data. 
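+            # Illustrative note -- assumed semantics of the dataset helper above
+            # (the dataset class is not shown in this file):
+            # return_cross_validation_splits(k) is expected to hold out fold k
+            # for validation and train on the remaining folds, roughly:
+            #   val_idx = [i for i in range(len(ds)) if i % num_folds == k]
+            #   train_idx = [i for i in range(len(ds)) if i % num_folds != k]
+            #   return ds[train_idx], ds[val_idx]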
+            if not self.args.get("--test"):
+                self.train_data = DataLoader(
+                    train_dataset,
+                    batch_size=self.config.batch_size,
+                    shuffle=True,
+                )
+            self.test_data = None
+
+        # ~~~~~~~~~~~~ Branch Prediction ~~~~~~~~~~~~~~~~~~~~
+        elif dataset in ["branch_prediction"]:
+            assert (
+                kfold and current_kfold_split is not None
+            ), "Branch Prediction only supported with kfold flag!"
+            assert current_kfold_split < 10
+            # train set
+            ds = Dataset(
+                root=self.data_dir, split="train", train_subset=self.config.train_subset
+            )
+            ds = ds.filter_max_num_nodes(self.config.max_num_nodes)
+
+            train_dataset, valid_dataset = ds.return_cross_validation_splits(
+                current_kfold_split
+            )
+            # train_dataset.filter_max_num_nodes(self.config.max_num_nodes)
+            # valid_dataset.filter_max_num_nodes(self.config.max_num_nodes)
+            self.train_data = NodeLimitedDataLoader(
+                train_dataset,
+                batch_size=self.config.batch_size,
+                shuffle=True,
+                max_num_nodes=self.config.max_num_nodes,
+                warn_on_limit=False,
+            )
+            self.valid_data = DataLoader(
+                valid_dataset, batch_size=self.config.batch_size, shuffle=False
+            )
+
+            self.test_data = None
+        # ~~~~~~~~~~~ Unknown Dataset ~~~~~~~~~~~~~~~~~
+        else:
+            raise NotImplementedError(f"Unknown dataset: {dataset}")
+
+    def parse_config_params(self, args):
+        """Parse config params from the --config or --config_json flags in args."""
+        params = None
+        if args.get("--config"):
+            with open(REPO_ROOT / args["--config"], "r") as f:
+                params = json.load(f)
+        elif args.get("--config_json"):
+            config_string = args["--config_json"]
+            # Accept single-quoted 'json'. This only works because our json
+            # strings are simple enough.
+            config_string = (
+                config_string.replace("\\'", "'")
+                .replace("'", '"')
+                .replace("True", "true")
+                .replace("False", "false")
+            )
+            params = json.loads(config_string)
+        return params
+
+    def data2input(self, batch):
+        """Glue method that converts a batch from the dataloader into the input format of the model."""
+        num_graphs = batch.batch[-1].item() + 1
+
+        edge_lists = []
+        edge_positions = (
+            [] if getattr(self.config, "position_embeddings", False) else None
+        )
+
+        edge_indices = list(range(3))
+        if self.config.ablate_structure:
+            if self.config.ablate_structure == "control":
+                edge_indices[0] = -1
+            elif self.config.ablate_structure == "data":
+                edge_indices[1] = -1
+            elif self.config.ablate_structure == "call":
+                edge_indices[2] = -1
+            else:
+                raise ValueError("unreachable")
+
+        for i in edge_indices:
+            # mask by edge type
+            mask = batch.edge_attr[:, 0] == i  # boolean mask over all edges
+            edge_list = batch.edge_index[:, mask].t()  # edges of type i, transposed
+            edge_lists.append(edge_list)
+
+            if getattr(self.config, "position_embeddings", False):
+                edge_pos = batch.edge_attr[mask, 1]  # positions of the selected edges
+                edge_positions.append(edge_pos)
+
+        inputs = {
+            "vocab_ids": batch.x[:, 0],
+            "edge_lists": edge_lists,
+            "pos_lists": edge_positions,
+            "num_graphs": num_graphs,
+            "graph_nodes_list": batch.batch,
+            "node_types": batch.x[:, 1],
+        }
+
+        # maybe add labels
+        if batch.y is not None:
+            inputs.update(
+                {
+                    "labels": batch.y,
+                }
+            )
+
+        # add other stuff
+        if hasattr(batch, "aux_in"):
+            inputs.update({"aux_in": batch.aux_in.to(dtype=torch.float)})
+        if hasattr(batch, "runtimes"):
+            inputs.update({"runtimes": batch.runtimes.to(dtype=torch.float)})
+        return inputs
+
+    def make_branch_labels(self, batch):
+        """Takes a batch and maps the profile info to branch labels for regression:
+        a branch has (true_weight+1, false_weight+1, total_weight+2) and we map to [0, 1] as
+        p(true) = true_weight / total_weight.
+        Note that the profile info adds 1 to both the true and the false weight!
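+
+        Illustrative worked example (the exact column layout beyond the indices
+        used below is an assumption): a profile_info row of (1, 8, 4, 12) encodes
+        true_weight+1 = 8 and total_weight+2 = 12, so
+        p(true) = (8 - 1) / (12 - 2) = 0.7.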
+ """ + mask = batch.profile_info[:, 0].bool() + # clamp to be robust against 0 counts from problems with the data + yes = torch.clamp( + batch.profile_info[:, 1].to(dtype=torch.get_default_dtype()) - 1, min=0.0 + ) + total = 1e-7 + torch.clamp( + batch.profile_info[:, 3].to(torch.get_default_dtype()) - 2, min=0.0 + ) + p_yes = yes / total # true / total + p_yes = torch.clamp(p_yes, min=0.0, max=1.0) + # print([str(a) for a in p_yes[mask].clone().detach().to('cpu').numpy()]) + return p_yes, mask + + def bertify_batch(self, batch, config): + """takes a batch and returns the bertified input, labels and corresponding mask, + indicating what part of the input to predict.""" + vocab_ids = batch.x[:, 0] + labels = vocab_ids.clone() + node_types = batch.x[:, 1] + device = vocab_ids.device + + # we create a tensor that carries the probability of being masked for each node + probabilities = torch.full( + vocab_ids.size(), config.mlm_probability, device=device + ) + # set to 0.0 where nodes are !IDENTIFIERS, i.e. node_types == 1 + if config.mlm_statements_only: + probabilities.masked_fill_(node_types.bool(), 0.0) + # set to 0.0 where statements are !UNK + if config.mlm_exclude_unk_tokens: + probabilities.masked_fill_(vocab_ids == config.unk_token_id, 0.0) + + # get the node mask that determines the nodes we use as targets + mlm_target_mask = torch.bernoulli(probabilities).bool() + # of those, get the 80% where the input is masked + masked_out_nodes = ( + torch.bernoulli(torch.full(vocab_ids.size(), 0.8, device=device)).bool() + & mlm_target_mask + ) + + # the 10% where it's set to a random token + # (as 50% of the target nodes that are not masked out) + random_nodes = ( + torch.bernoulli(torch.full(vocab_ids.size(), 0.5, device=device)).bool() + & mlm_target_mask + & ~masked_out_nodes + ) + # and the 10% where it's the original id, we just leave alone. 
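+        # Illustrative numbers: with mlm_probability = 0.15, about 15% of the
+        # eligible nodes become prediction targets; of those, ~80% have their
+        # input replaced by mlm_mask_token_id, ~10% (0.5 of the remaining 20%)
+        # by a random vocab id, and ~10% keep their original id -- the same
+        # 80/10/10 recipe as BERT's masked language modeling objective.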
+ + # apply the changes + random_ids = torch.randint( + config.vocab_size, vocab_ids.shape, dtype=torch.long, device=device + ) + vocab_ids[masked_out_nodes] = config.mlm_mask_token_id + vocab_ids[random_nodes] = random_ids[random_nodes] + # the loss function can ignore -1 labels for gradients, + # although it's more efficient to gather according to the mlm_target_mask mask + labels[~mlm_target_mask] = -1 + + return vocab_ids, labels, mlm_target_mask + + def run_epoch(self, loader, epoch_type, analysis_mode=False): + """ + args: + loader: a pytorch-geometric dataset loader, + epoch_type: 'train' or 'eval' + returns: + loss, accuracy, instance_per_second + """ + + bar = tqdm.tqdm(total=len(loader.dataset), smoothing=0.01, unit="inst") + if analysis_mode: + saved_outputs = [] + + epoch_loss, epoch_accuracy = 0, 0 + epoch_actual_rt, epoch_optimal_rt = 0, 0 + start_time = time.time() + processed_graphs = 0 + predicted_targets = 0 + + for step, batch in enumerate(loader): + ######### prepare input + # move batch to gpu and prepare input tensors: + batch.to(self.model.dev) + + inputs = self.data2input(batch) + num_graphs = inputs["num_graphs"] + + # only implemented nodewise model are for pretraining currently + if self.config.name in [ + "GGNN_ForPretraining_Config", + "GraphTransformer_ForPretraining_Config", + ]: + mlm_vocab_ids, mlm_labels, mlm_target_mask = self.bertify_batch( + batch, self.config + ) + inputs.update( + { + "vocab_ids": mlm_vocab_ids, + "labels": mlm_labels, + "readout_mask": mlm_target_mask, + } + ) + num_targets = torch.sum(mlm_target_mask.to(torch.long)).item() + elif self.config.name in [ + "GGNN_BranchPrediction_Config", + "GraphTransformer_BranchPrediction_Config", + ]: + y, mask = self.make_branch_labels(batch) + inputs.update( + { + "labels": y, + "readout_mask": mask, + } + ) + if not torch.any(mask): + print("Warning: batch has no labels! skipping.......") + continue + num_targets = torch.sum(mask.to(torch.long)).item() + # elif: other nodewise configs go here! + elif getattr(self.config, "has_graph_labels", False): # all graph models + num_targets = num_graphs + else: + raise NotImplementedError( + "We don't have other nodewise models currently." + ) + + predicted_targets += num_targets + processed_graphs += num_graphs + + ############# + # RUN MODEL FORWARD PASS + + # enter correct mode of model and fetch output + if epoch_type == "train": + self.global_training_step += 1 + if not self.model.training: + self.model.train() + outputs = self.model(**inputs) + else: # not TRAIN + if self.model.training: + self.model.eval() + self.model.opt.zero_grad() + with torch.no_grad(): # don't trace computation graph! + outputs = self.model(**inputs) + + if analysis_mode: + # TODO I don't know whether the outputs are properly cloned, moved to cpu and detached or not. 
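+                # Untested sketch of one way to address the TODO above: detach
+                # and move tensors to the CPU before storing, e.g.
+                #   saved_outputs.append(tuple(
+                #       t.detach().cpu() if torch.is_tensor(t) else t
+                #       for t in outputs
+                #   ))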
+ saved_outputs.append(outputs) + + if hasattr(batch, "runtimes"): + ( + logits, + accuracy, + correct, + targets, + actual_rt, + optimal_rt, + graph_features, + *unroll_stats, + ) = outputs + epoch_actual_rt += torch.sum(actual_rt).item() + epoch_optimal_rt += torch.sum(optimal_rt).item() + else: + ( + logits, + accuracy, + correct, + targets, + graph_features, + *unroll_stats, + ) = outputs + loss = self.model.loss((logits, graph_features), targets) + + epoch_loss += loss.item() * num_targets + epoch_accuracy += accuracy.item() * num_targets + + # update weights + if epoch_type == "train": + loss.backward() + if self.model.config.clip_grad_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.model.config.clip_grad_norm + ) + self.model.opt.step() + self.model.opt.zero_grad() + + # update bar display + bar_loss = epoch_loss / (predicted_targets + 1e-8) + bar_acc = epoch_accuracy / (predicted_targets + 1e-8) + bar.set_postfix(loss=bar_loss, acc=bar_acc, ppl=np.exp(bar_loss)) + bar.update(num_graphs) + + bar.close() + + # Return epoch stats + mean_loss = epoch_loss / predicted_targets + mean_accuracy = epoch_accuracy / predicted_targets + instance_per_sec = processed_graphs / (time.time() - start_time) + epoch_perplexity = np.exp(mean_loss) + + returns = ( + mean_loss, + mean_accuracy, + instance_per_sec, + epoch_perplexity, + epoch_actual_rt, + epoch_optimal_rt, + ) + + if analysis_mode: + returns += (saved_outputs,) + return returns + + def train(self): + log_to_save = [] + total_time_start = time.time() + + # we enter training after restore + if self.parent_run_id is not None: + print(f"== Epoch pre-validate epoch {self.current_epoch}") + _, valid_acc, _, ppl, _, _ = self.run_epoch(self.valid_data, "val") + best_val_acc = np.sum(valid_acc) + best_val_acc_epoch = self.current_epoch + print( + "\r\x1b[KResumed operation, initial cum. val. 
acc: %.5f, ppl %.5f" + % (best_val_acc, ppl) + ) + self.current_epoch += 1 + else: + (best_val_acc, best_val_acc_epoch) = (0.0, 0) + + # Training loop over epochs + target_epoch = self.current_epoch + self.config.num_epochs + for epoch in range(self.current_epoch, target_epoch): + print(f"== Epoch {epoch}/{target_epoch}") + + ( + train_loss, + train_acc, + train_speed, + train_ppl, + train_art, + train_ort, + ) = self.run_epoch(self.train_data, "train") + print( + "\r\x1b[K Train: loss: %.5f | acc: %s | ppl: %s | instances/sec: %.2f | runtime: %.1f opt: %.1f" + % ( + train_loss, + f"{train_acc:.5f}", + train_ppl, + train_speed, + train_art, + train_ort, + ) + ) + + ( + valid_loss, + valid_acc, + valid_speed, + valid_ppl, + valid_art, + valid_ort, + ) = self.run_epoch(self.valid_data, "eval") + print( + "\r\x1b[K Valid: loss: %.5f | acc: %s | ppl: %s | instances/sec: %.2f | runtime: %.1f opt: %.1f" + % ( + valid_loss, + f"{valid_acc:.5f}", + valid_ppl, + valid_speed, + valid_art, + valid_ort, + ) + ) + + # maybe run test epoch + if self.test_data is not None: + test_loss, test_acc, test_speed, test_ppl, _, _ = self.run_epoch( + self.test_data, "eval" + ) + print( + "\r\x1b[K Test: loss: %.5f | acc: %s | ppl: %s | instances/sec: %.2f" + % (test_loss, f"{test_acc:.5f}", test_ppl, test_speed) + ) + + epoch_time = time.time() - total_time_start + self.current_epoch = epoch + + log_entry = { + "epoch": epoch, + "time": epoch_time, + "train_results": ( + train_loss, + train_acc, + train_speed, + train_ppl, + train_art, + train_ort, + ), + "valid_results": ( + valid_loss, + valid_acc, + valid_speed, + valid_ppl, + valid_art, + valid_ort, + ), + } + + if self.test_data is not None: + log_entry.update( + {"test_results": (test_loss, test_acc, test_speed, test_ppl)} + ) + + log_to_save.append(log_entry) + + with open(self.log_file, "w") as f: + json.dump(log_to_save, f, indent=4) + + # TODO: sum seems redundant if only one task is trained. + val_acc = np.sum(valid_acc) # type: float + if val_acc > best_val_acc: + self.save_model(epoch, self.best_model_file) + print( + " (Best epoch so far, cum. val. acc increased to %.5f from %.5f. Saving to '%s')" + % (val_acc, best_val_acc, self.best_model_file) + ) + best_val_acc = val_acc + best_val_acc_epoch = epoch + elif epoch - best_val_acc_epoch >= self.config.patience: + print( + "Stopping training after %i epochs without improvement on validation accuracy." 
+                    % self.config.patience
+                )
+                break
+            if not self.args["--skip_save_every_epoch"]:
+                self.save_model(epoch, self.last_model_file)
+        # save last model on finish of training
+        self.save_model(epoch, self.last_model_file)
+
+    def test(self):
+        log_to_save = []
+        total_time_start = time.time()
+
+        print("== Epoch: Test-only run.")
+
+        (
+            valid_loss,
+            valid_acc,
+            valid_speed,
+            valid_ppl,
+            valid_art,
+            valid_ort,
+        ) = self.run_epoch(self.valid_data, "eval")
+        print(
+            "\r\x1b[K Valid: loss: %.5f | acc: %s | ppl: %s | instances/sec: %.2f | runtime: %.1f opt: %.1f"
+            % (
+                valid_loss,
+                f"{valid_acc:.5f}",
+                valid_ppl,
+                valid_speed,
+                valid_art,
+                valid_ort,
+            )
+        )
+
+        if self.test_data is not None:
+            test_loss, test_acc, test_speed, test_ppl, _, _ = self.run_epoch(
+                self.test_data, "eval"
+            )
+            print(
+                "\r\x1b[K Test: loss: %.5f | acc: %s | ppl: %s | instances/sec: %.2f"
+                % (test_loss, f"{test_acc:.5f}", test_ppl, test_speed)
+            )
+
+        epoch_time = time.time() - total_time_start
+
+        log_entry = {
+            "epoch": "test_only",
+            "time": epoch_time,
+            "valid_results": (
+                valid_loss,
+                valid_acc,
+                valid_speed,
+                valid_ppl,
+                valid_art,
+                valid_ort,
+            ),
+        }
+        if self.test_data is not None:
+            log_entry.update(
+                {"test_results": (test_loss, test_acc, test_speed, test_ppl)}
+            )
+
+        log_to_save.append(log_entry)
+        with open(self.log_file, "w") as f:
+            json.dump(log_to_save, f, indent=4)
+
+    def save_model(self, epoch, path):
+        checkpoint = {
+            "run_id": self.run_id,
+            "global_training_step": self.global_training_step,
+            "epoch": epoch,
+            "config": self.config.to_dict(),
+            "model_name": self.model.__class__.__name__,
+            "model_state_dict": self.model.state_dict(),
+            "optimizer_state_dict": self.model.opt.state_dict(),
+        }
+        torch.save(checkpoint, path)
+
+    def restore_by_pattern(self, pattern, log_dir, current_kfold_split=None):
+        """Restores the most recently modified checkpoint of a run identified by
+        the given pattern; this can be either a model_last or a model_best file.
+        If current_kfold_split is given, it additionally filters for this split,
+        so the split should not be part of the pattern itself.
+        """
+        if current_kfold_split is not None:
+            checkpoints = list(
+                log_dir.glob(f"*{pattern}*_{current_kfold_split}_model_*.p*")
+            )
+        else:
+            checkpoints = list(log_dir.glob(f"*{pattern}*_model_*.p*"))
+        assert (
+            checkpoints
+        ), f"Couldn't restore by pattern: no model files matching <{pattern}> found."
+        last_mod_checkpoint = sorted(checkpoints, key=os.path.getmtime)[-1]
+        return self.restore_model(last_mod_checkpoint)
+
+    def restore_model(self, path):
+        """Loads and restores a model from file."""
+        checkpoint = torch.load(path)
+        self.parent_run_id = checkpoint["run_id"]
+        self.global_training_step = checkpoint["global_training_step"]
+        self.current_epoch = checkpoint["epoch"]
+
+        config_dict = (
+            checkpoint["config"]
+            if isinstance(checkpoint["config"], dict)
+            else checkpoint["config"].to_dict()
+        )
+
+        if not self.args.get("--skip_restore_config"):
+            # maybe zero out dropout attributes
+            if (
+                self.args["--transfer"] is not None
+                and self.args["--transfer_mode"] == "frozen"
+            ):
+                for key, value in config_dict.items():
+                    if "dropout" in key:
+                        config_dict[key] = 0.0
+                        print(
+                            f"*Restoring Config* Setting {key} from {value} to 0.0 while restoring config from checkpoint for transfer."
+                        )
+            config = getattr(configs, config_dict["name"]).from_dict(config_dict)
+            self.config = config
+            print(
+                f"*RESTORED* self.config = {config.name} from checkpoint {str(path)}."
+            )
+        else:
+            print("Skipped restoring self.config from checkpoint!")
+            assert (
+                self.args.get("--model") is not None
+            ), "Can only use --skip_restore_config if --model is given."
+            # initialize config from --model and compare to the skipped config from the restore.
+            _, Config = MODEL_CLASSES[self.args["--model"]]
+            self.config = Config.from_dict(self.parse_config_params(self.args))
+            self.config.check_equal(config_dict)
+
+        test_only = self.args.get("--test", False)
+        Model = getattr(modeling, checkpoint["model_name"])
+        model = Model(self.config, test_only=test_only)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        print(f"*RESTORED* model parameters from checkpoint {str(path)}.")
+        if not self.args.get(
+            "--test", None
+        ):  # only restore the optimizer if needed; opt should be None otherwise.
+            model.opt.load_state_dict(checkpoint["optimizer_state_dict"])
+            print("*RESTORED* optimizer parameters from checkpoint as well.")
+        return model
+
+    def transfer_model(self, transfer_model_class, mode):
+        """Transfers the current model to a different model class.
+        Resets global_training_step and current_epoch.
+
+        Mode:
+          frozen - only the new readout module will receive gradients.
+          finetune - the whole network will receive gradients.
+        """
+        assert transfer_model_class in MODEL_CLASSES
+        self.global_training_step = 0
+        self.current_epoch = 1
+
+        # freeze layers
+        if mode == "frozen":
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+        # replace config
+        _, Config = MODEL_CLASSES[transfer_model_class]
+        params = self.parse_config_params(self.args)
+        self.config = Config.from_dict(params=params)
+
+        # replace readout
+        if getattr(self.config, "has_aux_input", False) and getattr(
+            self.config, "aux_use_better", False
+        ):
+            self.model.readout = modeling.BetterAuxiliaryReadout(self.config)
+        elif getattr(self.config, "has_aux_input", False):
+            self.model.readout = modeling.Readout(self.config)
+            self.model.aux_readout = modeling.AuxiliaryReadout(self.config)
+        else:
+            assert not getattr(
+                self.config, "aux_use_better", False
+            ), "aux_use_better only with has_aux_input!"
+            self.model.readout = modeling.Readout(self.config)
+
+        # assign config to model
+        self.model.config = self.config
+
+        # re-setup model
+        test_only = self.args.get("--test", False)
+        assert not test_only, (
+            "Transfer requires training; restoring an already transferred model "
+            "should go through --restore instead."
+        )
+        self.model.setup(self.config, test_only)
+        # print info
+        print(self.model)
+        print(
+            f"Number of trainable params in transferred model: {self.model.num_parameters()}"
+        )
+
+
+if __name__ == "__main__":
+    args = docopt(__doc__)
+    print(args)
+    assert not (
+        args["--config"] and args["--config_json"]
+    ), "Pass either --config or --config_json, not both!"
+    if args.get("--model"):
+        assert args["--model"] in MODEL_CLASSES, f"Unknown model: {args['--model']}"
+    if args.get("--dataset"):
+        assert (
+            args["--dataset"] in DATASET_CLASSES
+        ), f"Unknown dataset: {args['--dataset']}"
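+    # Illustrative note (behavior implemented in parse_config_params above):
+    # --config_json accepts an "almost-JSON" string with single quotes, e.g.
+    #   --config_json="{'batch_size': 48, 'lr': 1e-4}"
+    # which is normalized to valid JSON before being parsed.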
+
+    if not args["--kfold"]:
+        learner = Learner(model=args["--model"], dataset=args["--dataset"], args=args)
+        if args.get("--test"):
+            learner.test()
+        else:
+            learner.train()
+    else:  # kfold
+        if args["--dataset"] in ["devmap_amd", "devmap_nvidia"]:
+            num_splits = 10
+        elif args["--dataset"] in [
+            "threadcoarsening_Cypress",
+            "threadcoarsening_Kepler",
+            "threadcoarsening_Fermi",
+            "threadcoarsening_Tahiti",
+        ]:
+            num_splits = 17
+        elif args["--dataset"] in ["branch_prediction"]:
+            num_splits = 10
+        else:
+            raise NotImplementedError("kfold not implemented for this dataset.")
+
+        for split in range(num_splits):
+            print("#######################################")
+            print(f"CURRENT SPLIT: {split + 1}/{num_splits}")
+            print("#######################################")
+            learner = Learner(
+                model=args["--model"],
+                dataset=args["--dataset"],
+                args=args,
+                current_kfold_split=split,
+            )
+            if len(learner.valid_data) == 0:
+                print("***" * 20)
+                print(
+                    f"Validation split is empty! Skipping split {split + 1}/{num_splits}."
+                )
+                print("***" * 20)
+                continue
+            if args.get("--test"):
+                learner.test()
+            else:
+                learner.train()
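+
+# ---------------------------------------------------------------------------
+# Illustrative programmatic usage (a sketch, not part of the CLI contract):
+# Learner re-parses the docopt defaults internally and merges the given dict
+# on top, so only the flags you want to override need to be supplied, e.g.:
+#
+#   from programl.task.graph_level_classification.run import Learner
+#   learner = Learner(
+#       model="ggnn_poj104",
+#       dataset="poj104",
+#       args={"--log_dir": "logs/example", "--kfold": False},
+#   )
+#   learner.train()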