Main #124

5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@

.DS_Store
dig/.DS_Store
dig/xgraph/.DS_Store
dig/xgraph/TAGE/.DS_Store
4 changes: 3 additions & 1 deletion README.md
@@ -83,6 +83,8 @@ Please cite our paper if you find *DIG* useful in your work:
 
 ## Contact
 
-If you have any questions, please submit a new issue or contact us: Meng Liu [[email protected]] or Shuiwang Ji [[email protected]].
+If you have any technical questions, please submit a new issue.
+
+If you have other questions, please contact us: Meng Liu [[email protected]] or Shuiwang Ji [[email protected]].


36 changes: 36 additions & 0 deletions dig/ggraph/JT-VAE/sample.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn

import math, random, sys
import argparse

from jtnn_vae import JTNNVAE
from vocab import Vocab

import rdkit

lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)

parser = argparse.ArgumentParser()
parser.add_argument('--nsample', type=int, required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--model', required=True)

parser.add_argument('--hidden_size', type=int, default=450)
parser.add_argument('--latent_size', type=int, default=56)
parser.add_argument('--depthT', type=int, default=20)
parser.add_argument('--depthG', type=int, default=3)

args = parser.parse_args()

vocab = [x.strip("\r\n ") for x in open(args.vocab)]
vocab = Vocab(vocab)

model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depthT, args.depthG)
model.load_state_dict(torch.load(args.model))
model = model.cuda()

torch.manual_seed(0)
for i in range(args.nsample):
    print(model.sample_prior())
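
Reviewer note: the script is driven entirely by the flags defined above. A usage sketch (the vocabulary file and checkpoint path are illustrative assumptions, not files added in this PR):

    python sample.py --nsample 100 --vocab data/vocab.txt --model saved/model.iter-400000

This loads the fragment vocabulary (one SMILES per line), restores a checkpoint produced by vae_train.py, and prints one SMILES string per molecule decoded from a latent vector drawn from the prior.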
114 changes: 114 additions & 0 deletions dig/ggraph/JT-VAE/vae_train.py
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
import argparse
from collections import deque
import pickle

# from fast_jtnn import *
import rdkit

from vocab import Vocab
from jtnn_vae import JTNNVAE
from datautils import MolTreeFolder, PairTreeFolder, MolTreeDataset

from rdkit import RDLogger


lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

parser = argparse.ArgumentParser()
parser.add_argument('--train', required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--save_dir', required=True)
parser.add_argument('--load_epoch', type=int, default=0)

parser.add_argument('--hidden_size', type=int, default=450)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--latent_size', type=int, default=56)
parser.add_argument('--depthT', type=int, default=20)
parser.add_argument('--depthG', type=int, default=3)

parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--clip_norm', type=float, default=50.0)
parser.add_argument('--beta', type=float, default=0.0)
parser.add_argument('--step_beta', type=float, default=0.002)
parser.add_argument('--max_beta', type=float, default=1.0)
parser.add_argument('--warmup', type=int, default=40000)

parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--anneal_rate', type=float, default=0.9)
parser.add_argument('--anneal_iter', type=int, default=40000)  # interval (in steps) between LR decays; used below
parser.add_argument('--kl_anneal_iter', type=int, default=2000)
parser.add_argument('--print_iter', type=int, default=50)
parser.add_argument('--save_iter', type=int, default=5000)

args = parser.parse_args()
print(args)

vocab = [x.strip("\r\n ") for x in open(args.vocab)]
vocab = Vocab(vocab)

model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depthT, args.depthG).cuda()
print(model)

for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

if args.load_epoch > 0:
    model.load_state_dict(torch.load(args.save_dir + "/model.iter-" + str(args.load_epoch)))

print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate)
scheduler.step()

param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

total_step = args.load_epoch
beta = args.beta
meters = np.zeros(4)

for epoch in range(args.epoch):
    loader = MolTreeFolder(args.train, vocab, args.batch_size, num_workers=4)
    for batch in loader:
        total_step += 1
        try:
            model.zero_grad()
            loss, kl_div, wacc, tacc, sacc = model(batch, beta)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
            optimizer.step()
        except Exception as e:
            print(e)
            continue

        meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100])

        if total_step % args.print_iter == 0:
            meters /= args.print_iter
            print("[%d] Beta: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], param_norm(model), grad_norm(model)))
            sys.stdout.flush()
            meters *= 0

        if total_step % args.save_iter == 0:
            torch.save(model.state_dict(), args.save_dir + "/model.iter-" + str(total_step))

        if total_step % args.anneal_iter == 0:
            scheduler.step()
            print("learning rate: %.6f" % scheduler.get_lr()[0])

        if total_step % args.kl_anneal_iter == 0 and total_step >= args.warmup:
            beta = min(args.max_beta, beta + args.step_beta)
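
Reviewer note: a standalone sketch of the KL-weight schedule that the last block implements (this helper is not part of the PR; it assumes warmup is a multiple of kl_anneal_iter, as holds for the defaults):

    def beta_at(step, beta0=0.0, step_beta=0.002, max_beta=1.0,
                warmup=40000, kl_anneal_iter=2000):
        """KL weight after `step` steps: flat during warmup, then raised by
        `step_beta` every `kl_anneal_iter` steps, capped at `max_beta`."""
        if step < warmup:
            return beta0
        increments = (step - warmup) // kl_anneal_iter + 1
        return min(max_beta, beta0 + increments * step_beta)

    # beta_at(39999) == 0.0, beta_at(40000) == 0.002, beta_at(60000) == 0.022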

30 changes: 30 additions & 0 deletions dig/ggraph/JT-VAE/vocab.py
@@ -0,0 +1,30 @@
import rdkit
import rdkit.Chem as Chem
import copy

def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]

class Vocab(object):
    benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
    penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']

    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x:i for i,x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]
        Vocab.benzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 6] + ['C1=CCNCC1']
        Vocab.penzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 5] + ['C1=NCCN1','C1=NNCC1']

    def get_index(self, smiles):
        return self.vmap[smiles]

    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)
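
Reviewer note: a small usage sketch of the class above (the two-fragment vocabulary is illustrative; real vocabularies are extracted from the training set). Note that constructing a Vocab also rebuilds the benzynes/penzynes class attributes from the supplied list.

    from vocab import Vocab

    smiles_list = ['C1=CC=CC=C1', 'C1=CC=NC=C1']   # benzene and pyridine rings
    vocab = Vocab(smiles_list)
    print(vocab.size())                     # 2
    print(vocab.get_index('C1=CC=NC=C1'))   # 1
    print(vocab.get_smiles(0))              # 'C1=CC=CC=C1'
    print(vocab.get_slots(0)[0])            # ('C', 0, 1): symbol, formal charge, H count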
1 change: 1 addition & 0 deletions dig/sslgraph/sslgraph/contrastive/views_fn/feature.py
@@ -18,6 +18,7 @@ def node_attr_mask(mode='whole', mask_ratio=0.1, mask_mean=0.5, mask_std=0.5):
     def do_trans(data):
         node_num, feat_dim = data.x.size()
         x = data.x.detach().clone()
+        mask = torch.zeros(node_num)
 
         if mode == 'whole':
             mask_num = int(node_num * mask_ratio)
7 changes: 6 additions & 1 deletion dig/xgraph/.gitignore
@@ -1,3 +1,8 @@
 /datasets/*
 !/datasets/Readme.md
-!/datasets/load_datasets.py
+!/datasets/load_datasets.py
+.gitignore
+.idea/
+checkpoint*
+dig/xgraph/SubgraphX/*.zip
+*.zip
2 changes: 1 addition & 1 deletion dig/xgraph/PGExplainer/Configures.py
@@ -6,7 +6,7 @@

 class DataParser(Tap):
     dataset_name: str = 'bbbp'
-    dataset_dir: str = './datasets'
+    dataset_dir: str = '../datasets'
     random_split: bool = True
     data_split_ratio: List = [0.8, 0.1, 0.1]  # the ratio of training, validation and testing set for random split
     seed: int = 1
14 changes: 13 additions & 1 deletion dig/xgraph/PGExplainer/install.sh
@@ -1 +1,13 @@
-textwrap
+#!/bin/bash
+conda create -y -n xgraph python=3.8
+source activate xgraph
+conda install -y pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.2 -c pytorch
+pip install scipy
+CUDA="cu102"
+pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
+pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
+pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
+pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
+pip install torch-geometric
+pip install cilog typed-argument-parser==1.5.4 tqdm
+conda install -y -c conda-forge rdkit
30 changes: 19 additions & 11 deletions dig/xgraph/PGExplainer/metrics.py
@@ -1,9 +1,10 @@
 import torch
 import numpy as np
 from torch_geometric.data import Data, Batch
+from typing import Optional
 
 
-def calculate_selected_nodes(data, edge_mask, top_k):
+def calculate_selected_nodes(data, edge_mask, top_k, node_idx=None):
     threshold = float(edge_mask.reshape(-1).sort(descending=True).values[min(top_k, edge_mask.shape[0]-1)])
     hard_mask = (edge_mask > threshold).cpu()
     edge_idx_list = torch.where(hard_mask == 1)[0]
@@ -12,25 +13,32 @@ def calculate_selected_nodes(data, edge_mask, top_k):
     for edge_idx in edge_idx_list:
         selected_nodes += [edge_index[0][edge_idx], edge_index[1][edge_idx]]
     selected_nodes = list(set(selected_nodes))
+    if node_idx is not None:
+        selected_nodes.append(node_idx)
     return selected_nodes
 
 
 def top_k_fidelity(data: Data, edge_mask: np.array, top_k: int,
-                   gnnNets: torch.nn.Module, label: int, target_id: int = -1):
+                   gnnNets: torch.nn.Module, label: int,
+                   target_id: int = -1, node_idx: Optional[int]=None,
+                   undirected=True):
     """ return the fidelity score of the subgraph with top_k score edges """
+    if undirected:
+        top_k = 2 * top_k
     all_nodes = np.arange(data.x.shape[0]).tolist()
-    selected_nodes = calculate_selected_nodes(data, edge_mask, top_k)
-    score = gnn_score(all_nodes, data, gnnNets, label, target_id,
+    selected_nodes = calculate_selected_nodes(data, edge_mask, top_k, node_idx)
+    score = gnn_score(all_nodes, data, gnnNets, label, target_id, node_idx=node_idx,
                       subgraph_building_method='zero_filling')
 
     unimportant_nodes = [node for node in all_nodes if node not in selected_nodes]
-    score_mask_important = gnn_score(unimportant_nodes, data, gnnNets, label, target_id,
+    score_mask_important = gnn_score(unimportant_nodes, data, gnnNets, label, target_id, node_idx=node_idx,
                                      subgraph_building_method='zero_filling')
     return score - score_mask_important
 
 
-def top_k_sparsity(data: Data, edge_mask: np.array, top_k: int):
+def top_k_sparsity(data: Data, edge_mask: np.array, top_k: int, undirected=True):
     """ return the size ratio of the subgraph with top_k score edges"""
+    if undirected:
+        top_k = 2 * top_k
     selected_nodes = calculate_selected_nodes(data, edge_mask, top_k)
     return 1 - len(selected_nodes) / data.x.shape[0]

@@ -59,7 +67,7 @@ def graph_build_split(X, edge_index, node_mask: np.array):
 
 
 def gnn_score(coalition: list, data: Data, gnnNets, label: int,
-              target_id: int = -1, subgraph_building_method='zero_filling') -> torch.Tensor:
+              target_id: int = -1, node_idx=None, subgraph_building_method='zero_filling') -> torch.Tensor:
     """ the prob of subgraph with selected nodes for required label and target node """
     num_nodes = data.num_nodes
     subgraph_build_func = get_graph_build_func(subgraph_building_method)
@@ -71,9 +79,9 @@ def gnn_score(coalition: list, data: Data, gnnNets, label: int,
     logits, probs, _ = gnnNets(mask_data)
 
     # get the score of predicted class for graph or specific node idx
+    node_idx = 0 if node_idx is None else node_idx
     if target_id == -1:
-        score = probs[0, label].item()
+        score = probs[node_idx, label].item()
     else:
-        score = probs[0, target_id, label].item()
+        score = probs[node_idx, target_id, label].item()
     return score

51 changes: 51 additions & 0 deletions dig/xgraph/PGExplainer/models/GAT.py
@@ -1,8 +1,11 @@
 import torch
 import torch.nn as nn
+from torch import Tensor
 import torch.nn.functional as F
+from torch_sparse import SparseTensor
 from torch_geometric.nn.conv import GATConv
 from torch_geometric.nn.glob import global_mean_pool, global_add_pool, global_max_pool
+from torch_geometric.typing import Adj, Size
 
 
 def get_readout_layers(readout):
@@ -19,6 +22,54 @@ def get_readout_layers(readout):
     return ret_readout
 
 
+class GATConv(GATConv):
+    def __init__(self, *args, **kwargs):
+        super(GATConv, self).__init__(*args, **kwargs)
+
+    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
+        size = self.__check_input__(edge_index, size)
+
+        # Run "fused" message and aggregation (if applicable).
+        if (isinstance(edge_index, SparseTensor) and self.fuse
+                and not self.__explain__):
+            coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
+                                         size, kwargs)
+
+            msg_aggr_kwargs = self.inspector.distribute(
+                'message_and_aggregate', coll_dict)
+            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
+
+            update_kwargs = self.inspector.distribute('update', coll_dict)
+            return self.update(out, **update_kwargs)
+
+        # Otherwise, run both functions in separation.
+        elif isinstance(edge_index, Tensor) or not self.fuse:
+            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
+                                         kwargs)
+
+            msg_kwargs = self.inspector.distribute('message', coll_dict)
+            out = self.message(**msg_kwargs)
+
+            # For `GNNExplainer`, we require a separate message and aggregate
+            # procedure since this allows us to inject the `edge_mask` into the
+            # message passing computation scheme.
+            if self.__explain__:
+                edge_mask = self.__edge_mask__
+                # Some ops add self-loops to `edge_index`. We need to do the
+                # same for `edge_mask` (but do not train those).
+                if out.size(self.node_dim) != edge_mask.size(0):
+                    loop = edge_mask.new_ones(size[0])
+                    edge_mask = torch.cat([edge_mask, loop], dim=0)
+                assert out.size(self.node_dim) == edge_mask.size(0)
+                out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))
+
+            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
+            out = self.aggregate(out, **aggr_kwargs)
+
+            update_kwargs = self.inspector.distribute('update', coll_dict)
+            return self.update(out, **update_kwargs)
+
+
 # GAT
 class GATNet(nn.Module):
     def __init__(self, input_dim, output_dim, model_args):
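
Reviewer note on how the masked branch above gets activated, hedged because it relies on PyTorch Geometric internals (the 1.6 line pinned by install.sh): explainer code flips the private flags on every message-passing layer before the forward pass, roughly as in this sketch (not code from this PR):

    from torch_geometric.nn import MessagePassing

    def set_explain_masks(model, edge_mask):
        # Flag each message-passing layer and hand it the (trainable) edge
        # mask, so the `if self.__explain__:` branch of propagate() scales
        # every message by its edge's mask value.
        for module in model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = edge_mask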