33
33
class VizualizeGraphRepresentation :
34
34
def __init__ (self , debug = False ,
35
35
dim = 2 , loss_fn = 'ec' , title_text = '' ,
36
- # weights_to_load='/home/ankit/learning_embeddings /exp/ethec_debug/ec_debug/d10/oe10d_debug /weights/best_model.pth'):
37
- weights_to_load = '/home/ankit/Desktop/emb_weights/joint_2xlr/best_model_model.pth' ):
36
+ weights_to_load = '/cluster/scratch/adhall /exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k /weights/best_model.pth' ):
37
+ # weights_to_load='/home/ankit/Desktop/emb_weights/joint_2xlr/best_model_model.pth'):
38
38
torch .manual_seed (0 )
39
39
40
40
self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -165,10 +165,11 @@ class VizualizeGraphRepresentationWithImages:
165
165
def __init__ (self , debug = False ,
166
166
dim = 2 ,
167
167
loss_fn = 'ec' ,
168
- weights_to_load = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4/ weights/160_model .pth' ,
169
- img_weights_to_load = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4/ weights/160_img_feat_net .pth' ):
168
+ weights_to_load = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/ weights/250_model .pth' ,
169
+ img_weights_to_load = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/ weights/250_img_feat_net .pth' , filename = 'combined_plot ' ):
170
170
torch .manual_seed (0 )
171
171
self .load_split = 'test'
172
+ self .filename = filename
172
173
173
174
self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
174
175
@@ -247,7 +248,7 @@ def create_loader(self):
247
248
])
248
249
test_set = ETHECHierarchyWithImages (self .graph_dict ['G_{}' .format (self .load_split )],
249
250
imageless_dataloaders = None ,
250
- transform = val_test_data_transforms )
251
+ transform = val_test_data_transforms , labelmap = self . labelmap )
251
252
252
253
testloader = torch .utils .data .DataLoader (test_set , collate_fn = my_collate ,
253
254
batch_size = 10 ,
@@ -319,8 +320,8 @@ def vizualize(self):
319
320
for key in self .img_to_emb :
320
321
emb = self .img_to_emb [key ]
321
322
e_norm = math .sqrt (emb [0 ]** 2 + emb [1 ]** 2 )
322
- if e_norm > 500 :
323
- emb = emb / e_norm * 500 .0
323
+ if e_norm > 10000 :
324
+ emb = emb / e_norm * 10000 .0
324
325
ax .scatter (emb [0 ], emb [1 ], c = level_color , alpha = 0.1 )
325
326
326
327
from_ix = max ([u for u , v in list (self .graph_dict ['G_{}' .format (self .load_split )].in_edges (key ))])
@@ -332,27 +333,28 @@ def vizualize(self):
332
333
333
334
ax .axis ('equal' )
334
335
if True :
335
- filename = 'combined_plot'
336
336
fig .set_size_inches (8 , 7 )
337
- fig .savefig (os .path .join (os .path .dirname (self .weights_to_load ), '..' , '{}.pdf' .format (filename )), dpi = 200 )
338
- fig .savefig (os .path .join (os .path .dirname (self .weights_to_load ), '..' , '{}.png' .format (filename )), dpi = 200 )
337
+ fig .savefig (os .path .join (os .path .dirname (self .weights_to_load ), '..' , '{}.pdf' .format (self . filename )), dpi = 200 )
338
+ fig .savefig (os .path .join (os .path .dirname (self .weights_to_load ), '..' , '{}.png' .format (self . filename )), dpi = 200 )
339
339
340
340
341
341
def create_images ():
342
- path_to_weights = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4_noinit /weights'
342
+ path_to_weights = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k /weights'
343
343
loss_fn = 'ec'
344
344
files = os .listdir (path_to_weights )
345
345
files .sort ()
346
346
for filename in files :
347
347
if 'best_model' in filename or 'img' in filename :
348
348
continue
349
- viz = VizualizeGraphRepresentation (debug = False , dim = 2 , loss_fn = loss_fn , title_text = '' ,
350
- weights_to_load = os .path .join (path_to_weights , filename ))
351
- viz .vizualize (save_to_disk = True , filename = '{0:04d}' .format (int (filename [:- 10 ]) if 'model' in filename else int (filename [:- 4 ])))
349
+ #viz = VizualizeGraphRepresentation(debug=False, dim=2, loss_fn=loss_fn, title_text='',
350
+ # weights_to_load=os.path.join(path_to_weights, filename))
351
+ viz = VizualizeGraphRepresentationWithImages (debug = False , dim = 2 , loss_fn = loss_fn , filename = '{0:04d}' .format (int (filename .split ('_' )[0 ])),
352
+ weights_to_load = os .path .join (path_to_weights , filename ), img_weights_to_load = os .path .join (path_to_weights , '{}_img_feat_net.pth' .format (filename .split ('_' )[0 ])))
353
+ #viz.vizualize(save_to_disk=True, filename='{0:04d}'.format(int(filename[:-10]) if 'model' in filename else int(filename[:-4])))
352
354
plt .close ('all' )
353
355
354
356
355
357
if __name__ == '__main__' :
356
358
# obj = VizualizeGraphRepresentation(debug=False)
357
- obj = VizualizeGraphRepresentationWithImages (debug = False )
358
- # create_images()
359
+ # obj = VizualizeGraphRepresentationWithImages(debug=False)
360
+ create_images ()
0 commit comments