-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizer.py
125 lines (98 loc) · 4.75 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import matplotlib.pyplot as plt
import torch
import numpy as np
def plotLoss(loss_list):
plt.plot(loss_list, color='orange')
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.title("Logistic regression")
#plt.scatter(loss_list, epochs, c=colors, alpha=0.5)
plt.show()
def visualizeAttention(in_sentences, in_wordweights, in_sentenceweights, sentenceLength = 6, listed_sentences = 3):
print(in_sentences[0:sentenceLength], " with wordweights_sentence1: ", in_wordweights[0])
print("This sentence's weight is: ", in_sentenceweights[0,0])
print(in_sentences[sentenceLength:sentenceLength * 2], " with wordweights_sentence2: ", in_wordweights[1])
print("This sentence's weight is: ", in_sentenceweights[0,1])
print(in_sentences[sentenceLength*2: sentenceLength*3], " with wordweights_sentence3: ", in_wordweights[2])
print("This sentence's weight is: ", in_sentenceweights[0,2])
highSentences, highSentenceIndxs = torch.topk(in_sentenceweights, listed_sentences, 1)
words_with_attention = list()
#numb_Sentences = len(highSentences[0])
for ind in highSentenceIndxs[0]:
for w in range(sentenceLength):
# print("words: ", in_sentences[(ind*sentenceLength)+w])
words_with_attention.append(in_sentences[(ind*sentenceLength)+w])
print(words_with_attention)
# print(highSentences, " are highest weighted sentences", "|| length is: ", len(highSentences[0]))
# Do heatmap viz
# Plot it out
fig, ax = plt.subplots()
fig.set_size_inches(15, 5)
heatmap = ax.pcolor(in_wordweights,cmap='RdBu')
plt.colorbar(heatmap)
titlestring = "Predicted Class: " #+ idx2genre[predmax] + ", True Class: " + example["genre"]
plt.title(titlestring)
#plt.colorbar()
# Format
fig = plt.gcf()
# turn off the frame
ax.set_frame_on(False)
# put the major ticks at the middle of each cell
ax.set_yticks(np.arange(in_wordweights.shape[0]) + 0.5, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
# Set the labels
# note I could have used nba_sort.columns but made "labels" instead
# ax.set_yticklabels(np.round(in_sentenceweights,3), minor=False)
#adjust here the number of sentences that should be visualized
ax.set_yticklabels(('%.5f' % highSentences[0,0].item(), '%.5f' % highSentences[0,1].item(), '%.5f' % highSentences[0,2].item()))
# rotate the
plt.xticks(rotation=90)
ax.grid(False)
# Turn off all the ticks
ax = plt.gca()
for t in ax.xaxis.get_major_ticks():
t.tick2line.set_visible = True
t.tick2line.set_visible = True
for t in ax.yaxis.get_major_ticks():
t.tick2line.set_visible = False
t.tick2line.set_visible = False
print(words_with_attention, "|| listen sentences amount is : ", listed_sentences)
for y in range(listed_sentences):#(sentenceLength):
no_words = min(sentenceLength, len(words_with_attention[y]))#min(sentenceLength, len(words_with_attention[y]))#len(sentences[y]))
# print("amount no_words", no_words)
# print("y ",y)
# print("no words ",no_words)
for x in range(no_words):
if y == 0:
textInd = x
currSent = 0
if y == 1:
textInd = x+sentenceLength
currSent = 1
if y == 2:
textInd = x+sentenceLength+sentenceLength
currSent = 2
print("x is: ", x, " while y is:", y, " and highsentence is: ", words_with_attention[textInd])
thisWWeights = in_wordweights[currSent]
thisWordsW = '%.5f' % thisWWeights[x].item() #repr(thisWWeights[x].item())
# print(thisWordsW)
#sth = str(thisWWeights[textInd].item())
#print(sth)
thisText = words_with_attention[textInd] + '\n' + thisWordsW
plt.text(x + 0.5, y + 0.5, thisText, horizontalalignment='center',
verticalalignment='center',)
"""plt.text(x + 0.5, y + 0.5, sentences[0][len(sentences[1])],#.decode('utf-8'),
horizontalalignment='center',
verticalalignment='center',
)
"""
#plt.savefig('rock_example.pdf')
plt.show()
"""example:
sentences = ["1_word1","1_word2","1_word3","1_word4","1_word5","1_word6","2_word1","2_word2","2_word3","2_word4","2_word5","2_word6","3_word1","3_word2","3_word3","3_word4","3_word5","3_word6","4_word1","4_word2","4_word3","4_word4","4_word5","4_word6","5_word1","5_word2","5_word3","5_word4","5_word5","5_word6","6_word1","6_word2","6_word3","6_word4","6_word5","6_word6"]
wordweights = torch.randn([3,6])
sentenceweights = torch.randn([1,6])
#print(len(sentences))
visualizeAttention(sentences, wordweights, sentenceweights)
"""