Skip to content

Commit b91895b

Browse files
authored
Merge pull request #7 from ankitdhall/Butterfly200
Butterfly200
2 parents 1b2d6bd + c21d3bb commit b91895b

16 files changed

+10963
-109
lines changed

data/db.py

+1,131-30
Large diffs are not rendered by default.

data/visualize_graph/viz.html

+21-6
Original file line numberDiff line numberDiff line change
@@ -7365,7 +7365,9 @@
73657365
.id(function(d) { return d.name; });
73667366

73677367
var charge_force = d3.forceManyBody()
7368-
.strength(-1000);
7368+
.distanceMin(10)
7369+
.distanceMax(500)
7370+
.strength(-2000);
73697371

73707372
var center_force = d3.forceCenter(width / 2, height / 2);
73717373

@@ -7390,7 +7392,7 @@
73907392
.data(links_data)
73917393
.enter().append("line")
73927394
.attr("stroke-width", 4)
7393-
.style("stroke", "red");
7395+
.style("stroke", "black");
73947396

73957397
//draw circles for the nodes
73967398
var node = g.append("g")
@@ -7437,22 +7439,35 @@
74377439
//Let's return blue for males and red for females
74387440
function circleColour(d){
74397441
if(d.color == "b"){
7440-
return "blue";
7442+
return "#1A5276";
74417443
}
74427444
else if(d.color == "r"){
7443-
return "red";
7445+
return "#17A589";
74447446
}
74457447
else if(d.color == "y"){
74467448
return "black";
74477449
}
74487450
else{
7449-
return "orange";
7451+
return "#D68910";
74507452
}
74517453

74527454
}
74537455

74547456
function circleRadius(d){
7455-
return Math.max(20*Math.log10(d.count), 5);
7457+
//return 25
7458+
//return Math.max(20*Math.log10(d.count), 5);
7459+
if(d.color == "b"){
7460+
return 50;
7461+
}
7462+
else if(d.color == "r"){
7463+
return 35;
7464+
}
7465+
else if(d.color == "y"){
7466+
return 25;
7467+
}
7468+
else{
7469+
return 15;
7470+
}
74567471

74577472
}
74587473

network/embed_toy.py

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from __future__ import print_function
2+
from __future__ import division
3+
import torch
4+
torch.multiprocessing.set_sharing_strategy('file_system')
5+
6+
import os
7+
8+
import json
9+
import git
10+
import argparse
11+
12+
import numpy as np
13+
import random
14+
random.seed(0)
15+
16+
import networkx as nx
17+
18+
import matplotlib
19+
matplotlib.use('pdf')
20+
#matplotlib.use('tkagg')
21+
import matplotlib.pyplot as plt
22+
import cv2
23+
24+
from network.order_embeddings import OrderEmbedding, OrderEmbeddingLoss, EucConesLoss, Embedder
25+
from tensorboardX import SummaryWriter
26+
import torch.nn as nn
27+
28+
29+
class ToyGraph:
    """A perfectly balanced synthetic tree used as a toy label hierarchy.

    For a tree with ``levels`` levels and branching factor ``b`` the class
    builds, per level, a dict of node names ('<level>_<i>' -> local index),
    parent->children maps stored as attributes ``child_of_<level>``, and the
    set of directed ``(parent_ix, child_ix)`` edges over a global numbering
    of all nodes (level 1 first, then level 2, ...).
    """

    def __init__(self, levels=4, branching_factor=3):
        self.n_levels = levels
        self.branching_factor = branching_factor
        # Nodes per level: b, b^2, ..., b^(levels-1) -- the root is implicit.
        self.levels = [self.branching_factor ** depth for depth in range(1, self.n_levels)]
        self.level_names = [str(depth) for depth in range(1, self.n_levels)]

        # One dict per level: node name '<level>_<i>' -> local index i.
        for depth_ix, name in enumerate(self.level_names):
            level_nodes = {}
            for i in range(self.levels[depth_ix]):
                level_nodes['{}_{}'.format(name, str(i))] = i
            setattr(self, name, level_nodes)

        # Parent -> list of its branching_factor children, as 'child_of_<level>'.
        for depth_ix, name in enumerate(self.level_names[:-1]):
            child_level = self.level_names[depth_ix + 1]
            children_map = {}
            for i in range(self.levels[depth_ix]):
                children_map['{}_{}'.format(name, str(i))] = [
                    '{}_{}'.format(child_level, str(j + (self.branching_factor * i)))
                    for j in range(self.branching_factor)
                ]
            setattr(self, 'child_of_' + name, children_map)

        self.n_classes = sum(self.levels)
        # Flat list of every node name, level by level (dict order = insertion order).
        self.classes = []
        for name in self.level_names:
            self.classes.extend(getattr(self, name))

        # Global index range [start, stop) of each level in the flat numbering.
        self.level_stop, self.level_start = [], []
        offset = 0
        for level_len in self.levels:
            self.level_start.append(offset)
            offset += level_len
            self.level_stop.append(offset)

        # Directed edges (parent_global_ix, child_global_ix).
        self.edges = set()
        for depth_ix, name in enumerate(self.level_names[:-1]):
            parent_nodes = getattr(self, name)
            child_nodes = getattr(self, self.level_names[depth_ix + 1])
            for parent_node, children in getattr(self, 'child_of_' + name).items():
                for child_node in children:
                    u = parent_nodes[parent_node] + self.level_start[depth_ix]
                    v = child_nodes[child_node] + self.level_start[depth_ix + 1]
                    self.edges.add((u, v))
63+
64+
65+
class ToyOrderEmbedding(OrderEmbedding):
    """Order-embedding trainer specialized for the synthetic ToyGraph hierarchy.

    Re-does OrderEmbedding's setup without any image dataset: the graph to
    embed comes straight from ``labelmap.edges``. Training/eval loops are
    inherited from OrderEmbedding.
    """

    def __init__(self, labelmap, criterion, lr, batch_size, evaluator, experiment_name, embedding_dim,
                 neg_to_pos_ratio, alpha, proportion_of_nb_edges_in_train, lr_step=None, pick_per_level=False,
                 experiment_dir='../exp/', n_epochs=10, eval_interval=2, feature_extracting=True, load_wt=False,
                 optimizer_method='adam', lr_decay=1.0, random_seed=0):
        """Set up model, loss, logging and the train/val/test graph splits.

        :param labelmap: ToyGraph instance providing classes, levels and edges.
        :param criterion: loss object (OrderEmbeddingLoss or EucConesLoss).
        :param lr: learning rate.
        :param batch_size: mini-batch size.
        :param evaluator: evaluation helper (may be None for the toy setup).
        :param experiment_name: subdirectory name for logs/weights.
        :param embedding_dim: dimensionality of the learnt embeddings.
        :param neg_to_pos_ratio: negatives sampled per positive edge.
        :param alpha: margin used by the loss.
        :param proportion_of_nb_edges_in_train: fraction of non-basic edges in train.
        :param lr_step: epochs at which to step the lr (default: no steps).
            Was a mutable default ``[]``; ``None`` now maps to a fresh list.
        :param pick_per_level: sample negatives from each level of the graph.
        :param experiment_dir: root directory for experiments.
        :param n_epochs: number of training epochs.
        :param eval_interval: evaluate every N epochs.
        :param feature_extracting: kept for interface parity with OrderEmbedding.
        :param load_wt: resume from the last checkpoint.
        :param optimizer_method: 'adam' or 'sgd'.
        :param lr_decay: multiplicative lr decay factor.
        :param random_seed: seed for torch RNG.
        """
        torch.manual_seed(random_seed)

        self.epoch = 0
        self.exp_dir = experiment_dir
        self.load_wt = load_wt
        self.pick_per_level = pick_per_level

        self.eval = evaluator
        self.criterion = criterion
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print('Using device: {}'.format(self.device))
        if torch.cuda.device_count() > 1:
            print("== Using", torch.cuda.device_count(), "GPUs!")
        self.n_epochs = n_epochs
        self.eval_interval = eval_interval

        self.log_dir = os.path.join(self.exp_dir, '{}').format(experiment_name)
        self.path_to_save_model = os.path.join(self.log_dir, 'weights')
        if not os.path.exists(self.path_to_save_model):
            os.makedirs(self.path_to_save_model)

        self.writer = SummaryWriter(log_dir=os.path.join(self.log_dir, 'tensorboard'))

        # Mirror the labelmap's structure onto the trainer.
        self.classes = labelmap.classes
        self.n_classes = labelmap.n_classes
        self.levels = labelmap.levels
        self.n_levels = len(self.levels)
        self.level_names = labelmap.level_names
        self.lr = lr
        # NOTE: a duplicate `self.batch_size = batch_size` assignment was removed here.
        self.feature_extracting = feature_extracting
        self.optimizer_method = optimizer_method
        self.lr_step = [] if lr_step is None else lr_step

        self.optimal_threshold = 0
        self.embedding_dim = embedding_dim
        self.neg_to_pos_ratio = neg_to_pos_ratio
        self.proportion_of_nb_edges_in_train = proportion_of_nb_edges_in_train
        # EucConesLoss carries a cone-aperture constant K the embedder needs.
        if isinstance(criterion, EucConesLoss):
            self.model = Embedder(embedding_dim=self.embedding_dim, labelmap=labelmap, K=self.criterion.K)
        else:
            self.model = Embedder(embedding_dim=self.embedding_dim, labelmap=labelmap)
        self.model = self.model.to(self.device)
        self.model = nn.DataParallel(self.model)
        self.labelmap = labelmap

        # Build the full DAG from the labelmap's edges; split graphs are
        # filled in by create_splits() (defined in the parent class).
        self.G, self.G_train, self.G_val, self.G_test = nx.DiGraph(), nx.DiGraph(), nx.DiGraph(), nx.DiGraph()
        for edge in self.labelmap.edges:
            u, v = edge
            self.G.add_edge(u, v)

        self.G_tc = nx.transitive_closure(self.G)
        self.create_splits()

        # G_train_neg and the ix<->node mappings are produced by create_splits().
        self.criterion.set_negative_graph(self.G_train_neg, self.mapping_ix_to_node, self.mapping_node_to_ix)

        self.lr_decay = lr_decay

        self.check_graph_embedding_neg_graph = None
        self.check_reconstr_every = 1
        self.save_model_every = 5

        self.reconstruction_f1, self.reconstruction_threshold, self.reconstruction_accuracy, self.reconstruction_prec, self.reconstruction_recall = 0.0, 0.0, 0.0, 0.0, 0.0
        self.n_proc = 512 if torch.cuda.device_count() > 0 else 4
        print('Using {} processess!'.format(self.n_proc))
136+
137+
def embed_toy_model(arguments):
    """Train an embedding on a synthetic ToyGraph and visualize the result.

    Creates the experiment directory, records the run configuration together
    with the current git commit/branch, trains a ToyOrderEmbedding, then
    renders the learnt representation via VizualizeGraphRepresentation.

    :param arguments: parsed ``argparse.Namespace`` with the flags defined
        in the ``__main__`` block below.
    :raises ValueError: if ``arguments.loss`` is not a recognized loss name.
    """
    exp_path = os.path.join(arguments.experiment_dir, arguments.experiment_name)
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    # Persist the config alongside the exact code version for reproducibility.
    args_dict = vars(arguments)
    repo = git.Repo(search_parent_directories=True)
    args_dict['commit_hash'] = repo.head.object.hexsha
    args_dict['branch'] = repo.active_branch.name
    with open(os.path.join(exp_path, 'config_params.txt'), 'w') as file:
        file.write(json.dumps(args_dict, indent=4))

    print('Config parameters for this run are:\n{}'.format(json.dumps(vars(arguments), indent=4)))

    labelmap = ToyGraph(levels=arguments.tree_levels, branching_factor=arguments.tree_branching)

    if arguments.loss == 'order_emb_loss':
        use_criterion = OrderEmbeddingLoss(labelmap=labelmap, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                                           alpha=arguments.alpha)
    elif arguments.loss == 'euc_cones_loss':
        use_criterion = EucConesLoss(labelmap=labelmap, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                                     alpha=arguments.alpha)
    else:
        # Fail fast: the original printed a message and continued with
        # use_criterion=None, crashing much later inside the trainer.
        raise ValueError('== Invalid --loss argument: {}'.format(arguments.loss))

    oe = ToyOrderEmbedding(labelmap=labelmap, criterion=use_criterion, lr=arguments.lr,
                           batch_size=arguments.batch_size, experiment_name=arguments.experiment_name,
                           embedding_dim=arguments.embedding_dim, neg_to_pos_ratio=arguments.neg_to_pos_ratio,
                           alpha=arguments.alpha, pick_per_level=arguments.pick_per_level,
                           proportion_of_nb_edges_in_train=arguments.prop_of_nb_edges, lr_step=arguments.lr_step,
                           experiment_dir=arguments.experiment_dir, n_epochs=arguments.n_epochs,
                           eval_interval=arguments.eval_interval, feature_extracting=False, evaluator=None,
                           load_wt=arguments.resume, optimizer_method=arguments.optimizer_method,
                           lr_decay=arguments.lr_decay, random_seed=arguments.random_seed)
    oe.prepare_model()
    f1, acc = oe.train()

    # NOTE(review): `title` is built but never used -- it looks intended for
    # VizualizeGraphRepresentation's title_text argument; confirm before wiring it in.
    title = 'L={}, b={} \n F1 score: {:.4f} Accuracy: {:.4f}'.format(str(arguments.tree_levels-1),
                                                                     str(arguments.tree_branching), f1, acc)

    # Local import: viz_toy pulls in plotting machinery only needed here.
    from network.viz_toy import VizualizeGraphRepresentation
    path_to_best = os.path.join(exp_path, 'weights', 'best_model.pth')
    # The constructor renders/saves the visualization as a side effect.
    VizualizeGraphRepresentation(weights_to_load=path_to_best, title_text='',
                                 L=arguments.tree_levels, b=arguments.tree_branching)
182+
183+
184+
if __name__ == '__main__':
    # Command-line interface for training an embedding on a synthetic tree.
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", help='Input learning rate.', type=float, default=0.001)
    parser.add_argument("--batch_size", help='Batch size.', type=int, default=8)
    parser.add_argument("--experiment_name", help='Experiment name.', type=str, required=True)
    parser.add_argument("--experiment_dir", help='Experiment directory.', type=str, required=True)
    parser.add_argument("--n_epochs", help='Number of epochs to run training for.', type=int, required=True)
    parser.add_argument("--n_workers", help='Number of workers.', type=int, default=4)
    parser.add_argument("--eval_interval", help='Evaluate model every N intervals.', type=int, default=1)
    parser.add_argument("--embedding_dim", help='Dimensions of learnt embeddings.', type=int, default=10)
    parser.add_argument("--neg_to_pos_ratio", help='Number of negatives to sample for one positive.', type=int,
                        default=5)
    parser.add_argument("--alpha", help='Margin alpha.', type=float, default=0.05)
    parser.add_argument("--prop_of_nb_edges", help='Proportion of non-basic edges to be added to train set.',
                        type=float, default=0.0)
    parser.add_argument("--resume", help='Continue training from last checkpoint.', action='store_true')
    parser.add_argument("--optimizer_method", help='[adam, sgd]', type=str, default='adam')
    # BUGFIX: help text previously advertised 'euc_emb_loss', which the code
    # does not accept -- embed_toy_model() checks for 'euc_cones_loss'.
    parser.add_argument("--loss",
                        help='Loss function to use. [order_emb_loss, euc_cones_loss]',
                        type=str, required=True)
    parser.add_argument("--pick_per_level", help='Pick negatives from each level in the graph.', action='store_true')
    parser.add_argument("--lr_step", help='List of epochs to make multiple lr by 0.1', nargs='*', default=[],
                        type=int)
    parser.add_argument("--lr_decay", help='Decay lr by a factor.', default=1.0, type=float)
    parser.add_argument("--tree_levels", help='tree levels', required=True, type=int)
    parser.add_argument("--tree_branching", help='branching factor', required=True, type=int)
    parser.add_argument("--random_seed", help='pytorch random seed', default=0, type=int)
    # Example invocation (flag corrected from --optimizer to --optimizer_method):
    # cmd = """--pick_per_level --tree_levels 6 --tree_branching 2 --n_epochs 5 --lr 0.1 --loss euc_cones_loss --embedding_dim 2 --neg_to_pos_ratio 5 --alpha 0.01 --experiment_name toy_graph --batch_size 10 --optimizer_method adam --experiment_dir ../exp/embed_toy/"""

    # args = parser.parse_args(cmd.split(' '))
    args = parser.parse_args()

    embed_toy_model(args)

0 commit comments

Comments
 (0)