Skip to content

Commit 5725edd

Browse files
committed
change lr=0.1, single layer feat net
1 parent 453e50f commit 5725edd

File tree

1 file changed

+63
-34
lines changed

1 file changed

+63
-34
lines changed

network/oe.py

+63-34
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ def __init__(self, normalize, input_dim=2048, output_dim=10, K=None):
8989
Constructor to prepare layers for the embedding.
9090
"""
9191
super(FeatNet, self).__init__()
92-
self.fc1 = nn.Linear(input_dim, 4096)
93-
self.fc2 = nn.Linear(4096, 2048)
94-
self.fc3 = nn.Linear(2048, 1024)
95-
self.fc4 = nn.Linear(1024, output_dim)
92+
# self.fc1 = nn.Linear(input_dim, 4096)
93+
# self.fc2 = nn.Linear(4096, 2048)
94+
# self.fc3 = nn.Linear(2048, 1024)
95+
# self.fc4 = nn.Linear(1024, output_dim)
96+
97+
self.fc1 = nn.Linear(input_dim, output_dim)
9698

9799
self.normalize = normalize
98100
self.K = K
@@ -103,10 +105,12 @@ def forward(self, x):
103105
"""
104106
Forward pass through the model.
105107
"""
106-
x = F.relu(self.fc1(x))
107-
x = F.relu(self.fc2(x))
108-
x = F.relu(self.fc3(x))
109-
x = self.fc4(x)
108+
# x = F.relu(self.fc1(x))
109+
# x = F.relu(self.fc2(x))
110+
# x = F.relu(self.fc3(x))
111+
# x = self.fc4(x)
112+
113+
x = self.fc1(x)
110114

111115
if self.normalize == 'unit_norm':
112116
original_shape = x.shape
@@ -318,7 +322,7 @@ def calculate_best(self, threshold):
318322
else:
319323
f1_score = (2 * precision * recall) / (precision + recall)
320324

321-
return f1_score, threshold, accuracy, precision, recall, correct_positives, correct_negatives
325+
return f1_score, threshold, accuracy, precision, recall, correct_positives/self.e_for_u_v_positive.shape[0], correct_negatives/self.e_for_u_v_negative.shape[0]
322326

323327
def calculate_metrics(self):
324328
if self.phase == 'val':
@@ -354,7 +358,7 @@ def calculate_metrics(self):
354358
f1_score = 0.0
355359
else:
356360
f1_score = (2 * precision * recall) / (precision + recall)
357-
return f1_score, self.threshold, accuracy, precision, recall, correct_positives, correct_negatives
361+
return f1_score, self.threshold, accuracy, precision, recall, correct_positives/self.e_for_u_v_positive.shape[0], correct_negatives/self.e_for_u_v_negative.shape[0]
358362

359363

360364
def create_combined_graphs(dataloaders, labelmap):
@@ -417,13 +421,15 @@ def create_combined_graphs(dataloaders, labelmap):
417421
np.save('neg_adjacency.npy', A)
418422

419423
nx.write_gpickle(G, 'G')
424+
nx.write_gpickle(nx.transitive_closure(G), 'G_tc')
420425
nx.write_gpickle(G_train, 'G_train')
421426
nx.write_gpickle(G_val, 'G_val')
422427
nx.write_gpickle(G_test, 'G_test')
423428
nx.write_gpickle(G_train_skeleton_full, 'G_train_skeleton_full')
424429
nx.write_gpickle(G_train_tc, 'G_train_tc')
425430

426431
return {'graph': G, # graph with labels only; edges between labels only
432+
'graph_tc': nx.transitive_closure(G), # tc(graph with labels only; edges between labels only)
427433
'G_train': G_train, 'G_val': G_val, 'G_test': G_test, # graph with labels and images; edges between labels and images only
428434
'G_train_skeleton_full': G_train_skeleton_full, # graph with edge between labels + between labels and images
429435
'G_train_neg': A, # adjacency labels and images but only negative edges
@@ -714,7 +720,7 @@ def sample_negative_edge(self, u=None, v=None, level_id=None):
714720
np.where(np.logical_and(choose_from >= level_start, choose_from < level_stop))].tolist()
715721
else:
716722
level_start, level_stop = self.labelmap.level_stop[-1], None
717-
if type(v) == str:
723+
if type(v) == str or type(u) == str:
718724
choose_from = choose_from[np.where(choose_from < level_start)[0]].tolist()
719725
else:
720726
choose_from = choose_from[np.where(choose_from >= level_start)[0]].tolist()
@@ -739,7 +745,7 @@ def sample_negative_edge(self, u=None, v=None, level_id=None):
739745
choose_from = choose_from[np.where(np.logical_and(choose_from >= level_start, choose_from < level_stop))].tolist()
740746
else:
741747
level_start, level_stop = self.labelmap.level_stop[-1], None
742-
if type(v) == str:
748+
if type(v) == str or type(u) == str:
743749
choose_from = choose_from[np.where(choose_from < level_start)[0]].tolist()
744750
else:
745751
choose_from = choose_from[np.where(choose_from >= level_start)[0]].tolist()
@@ -802,7 +808,7 @@ def forward(self, model, img_feat_net, inputs_from, inputs_to, original_from, o
802808
negative_from[
803809
2 * self.neg_to_pos_ratio * batch_id + pass_ix + self.neg_to_pos_ratio] = self.mapping_from_ix_to_node[corrupted_ix]
804810
negative_to[2 * self.neg_to_pos_ratio * batch_id + pass_ix + self.neg_to_pos_ratio] = sample_inputs_to
805-
811+
806812
# get embeddings for concepts and images
807813
negative_from_embeddings, negative_to_embeddings = self.calculate_from_and_to_emb(model, img_feat_net, negative_from, negative_to)
808814

@@ -1283,7 +1289,7 @@ def __init__(self, graph_dict, imageless_dataloaders, image_dir,
12831289

12841290
self.check_graph_embedding_neg_graph = None
12851291

1286-
self.check_reconstr_every = 10
1292+
self.check_reconstr_every = 1
12871293
self.save_model_every = 10
12881294
self.reconstruction_f1, self.reconstruction_threshold, self.reconstruction_accuracy, self.reconstruction_prec, self.reconstruction_recall = 0.0, 0.0, 0.0, 0.0, 0.0
12891295
self.n_proc = 512 if torch.cuda.device_count() > 0 else 4
@@ -1294,8 +1300,8 @@ def make_dir_if_non_existent(directory):
12941300
os.makedirs(directory)
12951301

12961302
def prepare_model(self):
1297-
self.params_to_update = [{'params': self.model.parameters(), 'lr': 0.001},
1298-
{'params': self.img_feat_net.parameters()}]
1303+
self.params_to_update = [{'params': self.model.parameters(), 'lr': 0.1},
1304+
{'params': self.img_feat_net.parameters(), 'weight_decay': 0.0}]
12991305

13001306
def create_splits(self):
13011307
input_size = 224
@@ -1350,28 +1356,27 @@ def find_existing_weights(self):
13501356

13511357
def run_model(self, optimizer):
13521358
self.optimizer = optimizer
1353-
scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.lr_step, gamma=0.1)
1359+
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.lr_step, gamma=0.1)
13541360

13551361
if self.load_wt:
13561362
self.find_existing_weights()
13571363

13581364
self.best_model_wts = copy.deepcopy(self.model.state_dict())
13591365
self.best_score = 0.0
13601366

1361-
levels_to_hide_for_epoch = {}
1367+
self.levels_to_hide_for_epoch = {}
13621368
if self.hide_levels:
1363-
levels_to_hide_for_epoch = {0: [1, 2, 3], 20: [2, 3], 50: [3], 100: []}
1364-
#levels_to_hide_for_epoch = {0: [1, 2, 3], 100: [2, 3], 200: [3], 500: []}
1369+
self.levels_to_hide_for_epoch = {0: [1, 2, 3], 20: [2, 3], 50: [3], 100: []}
13651370

13661371
if True:
13671372
current_levels = None
1368-
for key in levels_to_hide_for_epoch:
1373+
for key in self.levels_to_hide_for_epoch:
13691374
if self.epoch >= key:
13701375
current_levels = key
13711376
if current_levels is not None:
1372-
print('Set levels to hide to: {}'.format(levels_to_hide_for_epoch[current_levels]))
1373-
self.train_set.set_levels_to_hide(levels_to_hide_for_epoch[current_levels])
1374-
self.criterion.set_levels_to_hide(levels_to_hide_for_epoch[current_levels])
1377+
print('Set levels to hide to: {}'.format(self.levels_to_hide_for_epoch[current_levels]))
1378+
self.train_set.set_levels_to_hide(self.levels_to_hide_for_epoch[current_levels])
1379+
self.criterion.set_levels_to_hide(self.levels_to_hide_for_epoch[current_levels])
13751380
trainloader = torch.utils.data.DataLoader(self.train_set,
13761381
batch_size=self.batch_size,
13771382
num_workers=self.n_workers, collate_fn=my_collate,
@@ -1387,10 +1392,10 @@ def run_model(self, optimizer):
13871392
print('Epoch {}/{}'.format(self.epoch, self.n_epochs - 1))
13881393
print('=' * 10)
13891394

1390-
if self.epoch in levels_to_hide_for_epoch:
1391-
print('Set levels to hide to: {}'.format(levels_to_hide_for_epoch[self.epoch]))
1392-
self.train_set.set_levels_to_hide(levels_to_hide_for_epoch[self.epoch])
1393-
self.criterion.set_levels_to_hide(levels_to_hide_for_epoch[self.epoch])
1395+
if self.epoch in self.levels_to_hide_for_epoch:
1396+
print('Set levels to hide to: {}'.format(self.levels_to_hide_for_epoch[self.epoch]))
1397+
self.train_set.set_levels_to_hide(self.levels_to_hide_for_epoch[self.epoch])
1398+
self.criterion.set_levels_to_hide(self.levels_to_hide_for_epoch[self.epoch])
13941399
trainloader = torch.utils.data.DataLoader(self.train_set,
13951400
batch_size=self.batch_size,
13961401
num_workers=self.n_workers, collate_fn=my_collate,
@@ -1412,7 +1417,7 @@ def run_model(self, optimizer):
14121417
self.pass_samples(phase='test')
14131418
self.writer.add_scalar('epoch_time_test', time.time() - test_start_time, self.epoch)
14141419

1415-
scheduler.step()
1420+
self.scheduler.step()
14161421

14171422
epoch_time = time.time() - epoch_start_time
14181423
self.writer.add_scalar('epoch_time', time.time() - epoch_start_time, self.epoch)
@@ -1469,7 +1474,6 @@ def pass_samples(self, phase, save_to_tensorboard=True):
14691474
# statistics
14701475
running_loss += loss.item()
14711476

1472-
14731477
# metrics = EmbeddingMetrics(e_positive, e_negative, self.optimal_threshold, phase)
14741478
# f1_score, threshold, accuracy = metrics.calculate_metrics()
14751479
classification_metrics = self.calculate_classification_metrics(phase)
@@ -1478,6 +1482,9 @@ def pass_samples(self, phase, save_to_tensorboard=True):
14781482
if save_to_tensorboard:
14791483
self.writer.add_scalar('{}_loss'.format(phase), epoch_loss, self.epoch)
14801484
self.writer.add_scalar('{}_thresh'.format(phase), self.optimal_threshold, self.epoch)
1485+
for pg_ix, param_group in enumerate(self.optimizer.param_groups):
1486+
self.writer.add_scalar('lr_param_group_{}'.format(pg_ix), param_group['lr'], self.epoch)
1487+
14811488
print('train loss: {}'.format(epoch_loss))
14821489

14831490
else:
@@ -1689,13 +1696,16 @@ def calculate_classification_metrics(self, phase, k=[1, 3, 5]):
16891696
img_rep[ix:min(ix + bs, len(images_in_graph) - 1), :] = self.img_feat_net(torch.stack(image_stack).to(self.device)).cpu().detach()
16901697
else:
16911698
img_rep[ix:min(ix+bs, len(images_in_graph)-1), :] = self.img_feat_net(self.criterion.get_img_features(images_in_graph[ix:min(ix+bs, len(images_in_graph)-1)]).to(self.device)).cpu().detach()
1699+
calculated_metrics['median_img_norm'] = torch.median(torch.norm(img_rep, dim=1))
1700+
16921701
img_rep = img_rep.unsqueeze(0)
16931702

16941703
label_rep = torch.zeros((len(labels_in_graph), self.embedding_dim)).to(self.device)
16951704
for ix in range(0, len(labels_in_graph), bs):
16961705
label_rep[ix:min(ix + bs, len(labels_in_graph) - 1), :] = self.model(
16971706
torch.tensor(labels_in_graph[ix:min(ix + bs, len(labels_in_graph) - 1)], dtype=torch.long).to(
16981707
self.device))
1708+
calculated_metrics['median_label_norm'] = torch.median(torch.norm(label_rep, dim=1))
16991709
label_rep = label_rep.cpu().detach().unsqueeze(0)
17001710

17011711
for image_name in images_in_graph:
@@ -1858,20 +1868,37 @@ def calculate_classification_metrics(self, phase, k=[1, 3, 5]):
18581868
return calculated_metrics
18591869

18601870
def check_graph_embedding(self):
1861-
if self.check_graph_embedding_neg_graph is None:
1871+
if self.check_graph_embedding_neg_graph is None or self.epoch in self.levels_to_hide_for_epoch:
1872+
# self.levels_to_hide_for_epoch = {0: [1, 2, 3], 20: [2, 3], 50: [3], 100: []}
18621873
# make negative graph
1874+
1875+
sub_G = nx.DiGraph()
1876+
edge_list = [e for e in self.graph_dict['graph_tc'].edges() if type(e[0]) != str and type(e[1]) != str]
1877+
for u, v in edge_list:
1878+
flag = True
1879+
if self.hide_levels:
1880+
for level_to_hide in self.levels_to_hide_for_epoch[self.epoch]:
1881+
if (type(u) != str and self.labelmap.level_start[level_to_hide] <= u < self.labelmap.level_stop[level_to_hide]) or (type(v) != str and self.labelmap.level_start[level_to_hide] <= v < self.labelmap.level_stop[level_to_hide]):
1882+
flag = False
1883+
break
1884+
# add edge only if it does not have a node belonging to a level to hide
1885+
if flag:
1886+
sub_G.add_edge(u, v)
1887+
else:
1888+
sub_G.add_edge(u, v)
1889+
18631890
start_time = time.time()
1864-
n_nodes = len(list(self.graph_dict['graph'].nodes()))
1891+
n_nodes = len(sub_G.nodes())
18651892

18661893
A = np.ones((n_nodes, n_nodes), dtype=np.bool)
18671894

1868-
for u, v in list(self.graph_dict['graph'].edges()):
1895+
for u, v in list(sub_G.edges()):
18691896
# remove edges that are in G_train_tc
18701897
A[u, v] = 0
18711898
np.fill_diagonal(A, 0)
18721899
self.check_graph_embedding_neg_graph = A
18731900

1874-
self.edges_in_G = self.graph_dict['graph'].edges()
1901+
self.edges_in_G = sub_G.edges()
18751902
self.n_nodes_in_G = n_nodes
18761903
self.nodes_in_G = [i for i in range(self.n_nodes_in_G)]
18771904

@@ -1918,6 +1945,7 @@ def load_combined_graphs(debug):
19181945
path_to_folder = '../database/ETHEC/ETHEC_embeddings/graphs'
19191946

19201947
G = nx.read_gpickle(os.path.join(path_to_folder, 'G'))
1948+
G_tc = nx.read_gpickle(os.path.join(path_to_folder, 'G_tc'))
19211949
G_train = nx.read_gpickle(os.path.join(path_to_folder, 'G_train'))
19221950
G_val = nx.read_gpickle(os.path.join(path_to_folder, 'G_val'))
19231951
G_test = nx.read_gpickle(os.path.join(path_to_folder, 'G_test'))
@@ -1948,6 +1976,7 @@ def load_combined_graphs(debug):
19481976
mapping_node_to_ix = {mapping_ix_to_node[k]: k for k in mapping_ix_to_node}
19491977

19501978
return {'graph': G, # graph with labels only; edges between labels only
1979+
'graph_tc': G_tc, # tc(graph with labels only; edges between labels only)
19511980
'G_train': G_train, 'G_val': G_val, 'G_test': G_test,
19521981
# graph with labels and images; edges between labels and images only
19531982
'G_train_skeleton_full': G_train_skeleton_full,

0 commit comments

Comments
 (0)