Skip to content

Commit dae8499

Browse files
committed
modify viz hypernym to get plots with images for all weights
1 parent 3678557 commit dae8499

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

network/viz_hypernymy.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
class VizualizeGraphRepresentation:
3434
def __init__(self, debug=False,
3535
dim=2, loss_fn='ec', title_text='',
36-
# weights_to_load='/home/ankit/learning_embeddings/exp/ethec_debug/ec_debug/d10/oe10d_debug/weights/best_model.pth'):
37-
weights_to_load='/home/ankit/Desktop/emb_weights/joint_2xlr/best_model_model.pth'):
36+
weights_to_load='/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/weights/best_model.pth'):
37+
# weights_to_load='/home/ankit/Desktop/emb_weights/joint_2xlr/best_model_model.pth'):
3838
torch.manual_seed(0)
3939

4040
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -165,10 +165,11 @@ class VizualizeGraphRepresentationWithImages:
165165
def __init__(self, debug=False,
166166
dim=2,
167167
loss_fn='ec',
168-
weights_to_load='/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4/weights/160_model.pth',
169-
img_weights_to_load='/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4/weights/160_img_feat_net.pth'):
168+
weights_to_load='/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/weights/250_model.pth',
169+
img_weights_to_load='/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/weights/250_img_feat_net.pth', filename='combined_plot'):
170170
torch.manual_seed(0)
171171
self.load_split = 'test'
172+
self.filename = filename
172173

173174
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
174175

@@ -247,7 +248,7 @@ def create_loader(self):
247248
])
248249
test_set = ETHECHierarchyWithImages(self.graph_dict['G_{}'.format(self.load_split)],
249250
imageless_dataloaders=None,
250-
transform=val_test_data_transforms)
251+
transform=val_test_data_transforms, labelmap=self.labelmap)
251252

252253
testloader = torch.utils.data.DataLoader(test_set, collate_fn=my_collate,
253254
batch_size=10,
@@ -319,8 +320,8 @@ def vizualize(self):
319320
for key in self.img_to_emb:
320321
emb = self.img_to_emb[key]
321322
e_norm = math.sqrt(emb[0]**2 + emb[1]**2)
322-
if e_norm > 500:
323-
emb = emb/e_norm*500.0
323+
if e_norm > 10000:
324+
emb = emb/e_norm*10000.0
324325
ax.scatter(emb[0], emb[1], c=level_color, alpha=0.1)
325326

326327
from_ix = max([u for u, v in list(self.graph_dict['G_{}'.format(self.load_split)].in_edges(key))])
@@ -332,27 +333,28 @@ def vizualize(self):
332333

333334
ax.axis('equal')
334335
if True:
335-
filename = 'combined_plot'
336336
fig.set_size_inches(8, 7)
337-
fig.savefig(os.path.join(os.path.dirname(self.weights_to_load), '..', '{}.pdf'.format(filename)), dpi=200)
338-
fig.savefig(os.path.join(os.path.dirname(self.weights_to_load), '..', '{}.png'.format(filename)), dpi=200)
337+
fig.savefig(os.path.join(os.path.dirname(self.weights_to_load), '..', '{}.pdf'.format(self.filename)), dpi=200)
338+
fig.savefig(os.path.join(os.path.dirname(self.weights_to_load), '..', '{}.png'.format(self.filename)), dpi=200)
339339

340340

341341
def create_images():
342-
path_to_weights = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/ec_2d_2xlr_feat4_noinit/weights'
342+
path_to_weights = '/cluster/scratch/adhall/exp/ethec/final_ec_full/load_emb_5k/50-50_hide_levels/ec_2d_2xlr_init_5k/weights'
343343
loss_fn = 'ec'
344344
files = os.listdir(path_to_weights)
345345
files.sort()
346346
for filename in files:
347347
if 'best_model' in filename or 'img' in filename:
348348
continue
349-
viz = VizualizeGraphRepresentation(debug=False, dim=2, loss_fn=loss_fn, title_text='',
350-
weights_to_load=os.path.join(path_to_weights, filename))
351-
viz.vizualize(save_to_disk=True, filename='{0:04d}'.format(int(filename[:-10]) if 'model' in filename else int(filename[:-4])))
349+
#viz = VizualizeGraphRepresentation(debug=False, dim=2, loss_fn=loss_fn, title_text='',
350+
# weights_to_load=os.path.join(path_to_weights, filename))
351+
viz = VizualizeGraphRepresentationWithImages(debug=False, dim=2, loss_fn=loss_fn, filename='{0:04d}'.format(int(filename.split('_')[0])),
352+
weights_to_load=os.path.join(path_to_weights, filename), img_weights_to_load=os.path.join(path_to_weights, '{}_img_feat_net.pth'.format(filename.split('_')[0])))
353+
#viz.vizualize(save_to_disk=True, filename='{0:04d}'.format(int(filename[:-10]) if 'model' in filename else int(filename[:-4])))
352354
plt.close('all')
353355

354356

355357
if __name__ == '__main__':
356358
# obj = VizualizeGraphRepresentation(debug=False)
357-
obj = VizualizeGraphRepresentationWithImages(debug=False)
358-
# create_images()
359+
#obj = VizualizeGraphRepresentationWithImages(debug=False)
360+
create_images()

0 commit comments

Comments
 (0)