from __future__ import print_function
from __future__ import division
import torch
# 'file_system' sharing avoids "too many open files" errors when many
# dataloader workers exchange tensors.
torch.multiprocessing.set_sharing_strategy('file_system')

import os

import json
import git
import argparse

import numpy as np
import random
random.seed(0)

import networkx as nx

import matplotlib
matplotlib.use('pdf')  # headless backend; use 'tkagg' instead for interactive plots
import matplotlib.pyplot as plt
import cv2

from network.order_embeddings import OrderEmbedding, OrderEmbeddingLoss, EucConesLoss, Embedder
from tensorboardX import SummaryWriter
import torch.nn as nn


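# ToyGraph builds a synthetic, perfectly balanced tree taxonomy: levels - 1
# layers with branching_factor children per node. Nodes get contiguous global
# indices level by level, and `edges` holds (parent_index, child_index) pairs
# over that indexing.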
class ToyGraph:
    def __init__(self, levels=4, branching_factor=3):
        self.n_levels = levels
        self.branching_factor = branching_factor
        # Number of nodes on each level: b, b^2, ..., b^(levels - 1).
        self.levels = [self.branching_factor**i for i in range(1, self.n_levels)]
        self.level_names = [str(i) for i in range(1, self.n_levels)]

        # Per-level lookup tables mapping a node name '<level>_<index>' to its
        # index within that level.
        for level_id, level_name in enumerate(self.level_names):
            setattr(self, level_name,
                    {'{}_{}'.format(level_name, i): i for i in range(self.levels[level_id])})

        # Map each node to its children on the next level down.
        for level_id, level_name in enumerate(self.level_names[:-1]):
            child_level_name = self.level_names[level_id + 1]
            setattr(self, 'child_of_' + level_name,
                    {'{}_{}'.format(level_name, i):
                         ['{}_{}'.format(child_level_name, self.branching_factor * i + j)
                          for j in range(self.branching_factor)]
                     for i in range(self.levels[level_id])})

        self.n_classes = sum(self.levels)
        self.classes = [key for class_list in [getattr(self, level_name) for level_name in self.level_names]
                        for key in class_list]
        # Global index range [start, stop) occupied by each level.
        self.level_stop, self.level_start = [], []
        for level_id, level_len in enumerate(self.levels):
            if level_id == 0:
                self.level_start.append(0)
                self.level_stop.append(level_len)
            else:
                self.level_start.append(self.level_stop[level_id - 1])
                self.level_stop.append(self.level_stop[level_id - 1] + level_len)

        # Parent -> child edges expressed over global node indices.
        self.edges = set()
        for level_id, level_name in enumerate(self.level_names[:-1]):
            child_of_dict = getattr(self, 'child_of_' + level_name)
            for parent_node in child_of_dict:
                for child_node in child_of_dict[parent_node]:
                    u = getattr(self, level_name)[parent_node] + self.level_start[level_id]
                    v = getattr(self, self.level_names[level_id + 1])[child_node] + self.level_start[level_id + 1]
                    self.edges.add((u, v))
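
# Illustrative example: ToyGraph(levels=3, branching_factor=2) yields
#   levels = [2, 4]
#   classes = ['1_0', '1_1', '2_0', '2_1', '2_2', '2_3']
#   edges = {(0, 2), (0, 3), (1, 4), (1, 5)}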


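# ToyOrderEmbedding wires the toy taxonomy into the generic OrderEmbedding
# trainer: every graph node gets an embedding, trained with either the
# order-embedding loss or the Euclidean-cones loss.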
class ToyOrderEmbedding(OrderEmbedding):
    def __init__(self, labelmap, criterion, lr, batch_size, evaluator, experiment_name, embedding_dim,
                 neg_to_pos_ratio, alpha, proportion_of_nb_edges_in_train, lr_step=[], pick_per_level=False,
                 experiment_dir='../exp/', n_epochs=10, eval_interval=2, feature_extracting=True, load_wt=False,
                 optimizer_method='adam', lr_decay=1.0, random_seed=0):
        torch.manual_seed(random_seed)

        self.epoch = 0
        self.exp_dir = experiment_dir
        self.load_wt = load_wt
        self.pick_per_level = pick_per_level

        self.eval = evaluator
        self.criterion = criterion
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('Using device: {}'.format(self.device))
        if torch.cuda.device_count() > 1:
            print("== Using", torch.cuda.device_count(), "GPUs!")
        self.n_epochs = n_epochs
        self.eval_interval = eval_interval

        self.log_dir = os.path.join(self.exp_dir, experiment_name)
        self.path_to_save_model = os.path.join(self.log_dir, 'weights')
        if not os.path.exists(self.path_to_save_model):
            os.makedirs(self.path_to_save_model)

        self.writer = SummaryWriter(log_dir=os.path.join(self.log_dir, 'tensorboard'))

        self.classes = labelmap.classes
        self.n_classes = labelmap.n_classes
        self.levels = labelmap.levels
        self.n_levels = len(self.levels)
        self.level_names = labelmap.level_names
        self.lr = lr
        self.feature_extracting = feature_extracting
        self.optimizer_method = optimizer_method
        self.lr_step = lr_step

        self.optimal_threshold = 0
        self.embedding_dim = embedding_dim
        self.neg_to_pos_ratio = neg_to_pos_ratio
        self.proportion_of_nb_edges_in_train = proportion_of_nb_edges_in_train
        # EucConesLoss carries a cone-aperture constant K that the embedder needs.
        if isinstance(criterion, EucConesLoss):
            self.model = Embedder(embedding_dim=self.embedding_dim, labelmap=labelmap, K=self.criterion.K)
        else:
            self.model = Embedder(embedding_dim=self.embedding_dim, labelmap=labelmap)
        self.model = self.model.to(self.device)
        self.model = nn.DataParallel(self.model)
        self.labelmap = labelmap

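        # Build the full hierarchy as a DiGraph; its transitive closure gives
        # all ancestor-descendant pairs. create_splits() (inherited from
        # OrderEmbedding) is expected to populate the train/val/test splits and
        # the negative-sampling graph G_train_neg.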
        self.G, self.G_train, self.G_val, self.G_test = nx.DiGraph(), nx.DiGraph(), nx.DiGraph(), nx.DiGraph()
        for edge in self.labelmap.edges:
            u, v = edge
            self.G.add_edge(u, v)

        self.G_tc = nx.transitive_closure(self.G)
        self.create_splits()

        self.criterion.set_negative_graph(self.G_train_neg, self.mapping_ix_to_node, self.mapping_node_to_ix)

        self.lr_decay = lr_decay

        self.check_graph_embedding_neg_graph = None
        self.check_reconstr_every = 1
        self.save_model_every = 5

        self.reconstruction_f1, self.reconstruction_threshold, self.reconstruction_accuracy, \
            self.reconstruction_prec, self.reconstruction_recall = 0.0, 0.0, 0.0, 0.0, 0.0
        self.n_proc = 512 if torch.cuda.device_count() > 0 else 4
        print('Using {} processes!'.format(self.n_proc))

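
# Trains an order embedding on the toy taxonomy, then renders the learned
# embedding space with the best checkpoint.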
def embed_toy_model(arguments):
    if not os.path.exists(os.path.join(arguments.experiment_dir, arguments.experiment_name)):
        os.makedirs(os.path.join(arguments.experiment_dir, arguments.experiment_name))
    # Snapshot the run configuration plus git state for reproducibility.
    args_dict = vars(arguments)
    repo = git.Repo(search_parent_directories=True)
    args_dict['commit_hash'] = repo.head.object.hexsha
    args_dict['branch'] = repo.active_branch.name
    with open(os.path.join(arguments.experiment_dir, arguments.experiment_name, 'config_params.txt'), 'w') as file:
        file.write(json.dumps(args_dict, indent=4))

    print('Config parameters for this run are:\n{}'.format(json.dumps(vars(arguments), indent=4)))

    labelmap = ToyGraph(levels=arguments.tree_levels, branching_factor=arguments.tree_branching)

    batch_size = arguments.batch_size
    n_workers = arguments.n_workers

    eval_type = None

    # Pick the loss; both samplers draw negatives according to the labelmap.
    if arguments.loss == 'order_emb_loss':
        use_criterion = OrderEmbeddingLoss(labelmap=labelmap, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                                           alpha=arguments.alpha)
    elif arguments.loss == 'euc_cones_loss':
        use_criterion = EucConesLoss(labelmap=labelmap, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                                     alpha=arguments.alpha)
    else:
        raise ValueError('Invalid --loss argument: {}'.format(arguments.loss))

    oe = ToyOrderEmbedding(labelmap=labelmap, criterion=use_criterion, lr=arguments.lr,
                           batch_size=batch_size, experiment_name=arguments.experiment_name,
                           embedding_dim=arguments.embedding_dim, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                           alpha=arguments.alpha, pick_per_level=arguments.pick_per_level,
                           proportion_of_nb_edges_in_train=arguments.prop_of_nb_edges, lr_step=arguments.lr_step,
                           experiment_dir=arguments.experiment_dir, n_epochs=arguments.n_epochs,
                           eval_interval=arguments.eval_interval, feature_extracting=False, evaluator=None,
                           load_wt=arguments.resume, optimizer_method=arguments.optimizer_method,
                           lr_decay=arguments.lr_decay, random_seed=arguments.random_seed)
    oe.prepare_model()
    f1, acc = oe.train()

    title = 'L={}, b={} \n F1 score: {:.4f} Accuracy: {:.4f}'.format(arguments.tree_levels - 1,
                                                                     arguments.tree_branching, f1, acc)

    # Plot the learned representation using the best saved checkpoint.
    from network.viz_toy import VizualizeGraphRepresentation
    path_to_best = os.path.join(arguments.experiment_dir, arguments.experiment_name, 'weights', 'best_model.pth')
    viz = VizualizeGraphRepresentation(weights_to_load=path_to_best, title_text=title,
                                       L=arguments.tree_levels, b=arguments.tree_branching)

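# Example invocation (the script name here is illustrative):
#   python embed_toy.py --experiment_name toy_graph --experiment_dir ../exp/embed_toy/ \
#       --n_epochs 10 --lr 0.1 --loss euc_cones_loss --embedding_dim 2 \
#       --tree_levels 6 --tree_branching 2 --pick_per_level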
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", help='Input learning rate.', type=float, default=0.001)
    parser.add_argument("--batch_size", help='Batch size.', type=int, default=8)
    parser.add_argument("--experiment_name", help='Experiment name.', type=str, required=True)
    parser.add_argument("--experiment_dir", help='Experiment directory.', type=str, required=True)
    parser.add_argument("--n_epochs", help='Number of epochs to run training for.', type=int, required=True)
    parser.add_argument("--n_workers", help='Number of workers.', type=int, default=4)
    parser.add_argument("--eval_interval", help='Evaluate model every N intervals.', type=int, default=1)
    parser.add_argument("--embedding_dim", help='Dimensions of learnt embeddings.', type=int, default=10)
    parser.add_argument("--neg_to_pos_ratio", help='Number of negatives to sample for one positive.', type=int,
                        default=5)
    parser.add_argument("--alpha", help='Margin alpha.', type=float, default=0.05)
    parser.add_argument("--prop_of_nb_edges", help='Proportion of non-basic edges to be added to train set.',
                        type=float, default=0.0)
    parser.add_argument("--resume", help='Continue training from last checkpoint.', action='store_true')
    parser.add_argument("--optimizer_method", help='[adam, sgd]', type=str, default='adam')
    parser.add_argument("--loss",
                        help='Loss function to use. [order_emb_loss, euc_cones_loss]',
                        type=str, required=True)
    parser.add_argument("--pick_per_level", help='Pick negatives from each level in the graph.', action='store_true')
    parser.add_argument("--lr_step", help='List of epochs at which to multiply the lr by 0.1.', nargs='*', default=[],
                        type=int)
    parser.add_argument("--lr_decay", help='Decay lr by a factor.', default=1.0, type=float)
    parser.add_argument("--tree_levels", help='Number of levels in the toy tree.', required=True, type=int)
    parser.add_argument("--tree_branching", help='Branching factor of the toy tree.', required=True, type=int)
    parser.add_argument("--random_seed", help='PyTorch random seed.', default=0, type=int)
    # cmd = """--pick_per_level --tree_levels 6 --tree_branching 2 --n_epochs 5 --lr 0.1 --loss euc_cones_loss --embedding_dim 2 --neg_to_pos_ratio 5 --alpha 0.01 --experiment_name toy_graph --batch_size 10 --optimizer_method adam --experiment_dir ../exp/embed_toy/"""

    # args = parser.parse_args(cmd.split(' '))
    args = parser.parse_args()

    embed_toy_model(args)