Skip to content

Commit c753f1c

Browse files
committed
attention plot
1 parent ea5ffd6 commit c753f1c

File tree

6 files changed

+233
-1
lines changed

6 files changed

+233
-1
lines changed

README.md

+13
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ the predict model is the gpt_poetry3000_iters1999_1_loss_3259.634242.pkl
6363
```
6464
python gpt_predict_poetrythree.py
6565
```
66+
67+
#### show attention
68+
These attention maps are extracted from the first attention block. The first row shows the attention weights after softmax, $\text{Softmax}\left(\frac{Q\cdot K^T}{\sqrt{d_i}}\right)$; the second row shows the raw scores $\frac{Q\cdot K^T}{\sqrt{d_i}}$ before softmax.
69+
70+
床前明月光
71+
<img src="./dataset/cqmyg.png" width="66%"/>
72+
73+
满树桃花映日开
74+
<img src="./dataset/msthyrk.png" width="66%"/>
75+
76+
山高江水深
77+
<img src="./dataset/sgjss.png" width="66%"/>
78+
6679
##### blogs
6780
[https://zhuanlan.zhihu.com/p/659018819 numpy实现GPT的decoder来产生旧诗词的](https://zhuanlan.zhihu.com/p/659018819)
6881

dataset/cqmyg.png

29.5 KB
Loading

dataset/msthyrk.png

37.1 KB
Loading

dataset/sgjss.png

18.5 KB
Loading

gpt/attdecoderblock.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from net.activation import Softmax, GELU, ReLU
1212

1313
class attdecoderblock_layer():
14-
def __init__(self, embed_dim, num_h, adam=False, float32 = False):
14+
def __init__(self, embed_dim, num_h, adam=False, float32 = False, return_attention = False):
1515
self.embed_dim = embed_dim
1616
self.num_h = num_h
1717
self.len_single = embed_dim // num_h
@@ -26,6 +26,7 @@ def __init__(self, embed_dim, num_h, adam=False, float32 = False):
2626
self.softmax = Softmax()
2727
self.relu = ReLU()
2828
self.adam = adam
29+
self.return_attention = return_attention
2930

3031
def forward(self, inputs, masks = []):
3132
self.masks = masks
@@ -47,13 +48,16 @@ def forward(self, inputs, masks = []):
4748
self.block = block
4849
self.qkv = qkv
4950
self.atg__ = [[[] for j in range(self.num_h)] for i in range(batch)]
51+
self.att_col = [[[] for j in range(self.num_h)] for i in range(batch)]
5052
for n in range(batch):
5153
tmp = []
5254
for i in range(self.num_h):
5355
niq = qkv[n, :, 0, i]
5456
nik = qkv[n, :, 1, i]
5557
niv = qkv[n, :, 2, i]
5658
att = np.matmul(niq, nik.T) / np.sqrt(self.len_single)
59+
if self.return_attention:
60+
self.att_col[n][i] = att
5761
if len(masks) > 0:
5862
att = att + masks
5963
atg__ = self.softmax.forward(att, axis=-1)
@@ -66,6 +70,8 @@ def forward(self, inputs, masks = []):
6670

6771
self.out1 = self.fc_out.forward(self.rkk)
6872
input1 = self.out6 + self.out1
73+
if self.return_attention:
74+
return self.atg__, self.att_col
6975
return input1
7076

7177
def backward(self, delta):

gpt/show_attention.py

+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import os
4+
abspath = os.path.abspath(__file__)
5+
filename = os.sep.join(abspath.split(os.sep)[-2:])
6+
abspath = abspath.replace(filename, "")
7+
import sys
8+
sys.path.append(abspath)
9+
10+
from net.loss import cross_entropy_loss
11+
import numpy as np
12+
import pickle
13+
from net.layernorm import layer_norm
14+
from PatchEmbed import Position_Embedding
15+
from attention import attention_layer
16+
from attdecoderblock import attdecoderblock_layer
17+
from gpt.gpt_linear import gpt_linear_layer
18+
from gpt.gpt_train_potry3000 import getdata, create_masks_future
19+
from net.layernorm import layer_norm
20+
from net.fullconnect import fclayer
21+
from classify import classify_layer
22+
23+
from copy import deepcopy
24+
import json
25+
from matplotlib import colors
26+
from mpl_toolkits.axes_grid1 import make_axes_locatable
27+
28+
def predict(inputs, pretrained_model=None):
    """Return the attention maps of the first attention-returning decoder block.

    The GPT stack (position embedding, 12 decoder blocks, layer norm, output
    projection) is rebuilt, restored from ``pretrained_model`` when that file
    exists, and ``inputs`` is pushed through it until the first decoder block
    created with ``return_attention=True`` is reached.

    Args:
        inputs: Text to analyse; every character must be a key of the
            ``char2id`` mapping returned by ``getdata()`` (a ``KeyError`` is
            raised otherwise).
        pretrained_model: Path of the pickled checkpoint. Defaults to the
            path that was previously hard-coded in this function.

    Returns:
        Whatever ``attdecoderblock_layer.forward`` returns in attention mode.
        The caller in this file unpacks it as
        ``(after_softmax, before_softmax)``, each indexed ``[batch][head]``.
    """
    import warnings

    if pretrained_model is None:
        pretrained_model = r'C:\Users\10696\Desktop\Numpy\numpy_transformer\gpt\model\gpt_poetry3000_iters1999_1_loss_3259.634242.pkl'
    vocab_size, _id2char, char2id, _input_texts = getdata()

    embed_dim = 192
    num_layer = 10 + 1 + 1
    num_h = [3] * num_layer
    context_length = 100

    ADAM = False
    float32 = True

    # Decoder blocks that hand back attention instead of activations.
    # Only the first one is ever reached, because the forward pass below
    # returns as soon as it hits such a block.
    attention_blocks = (0, 3)

    layers = [Position_Embedding(context_length, vocab_size, embed_dim, adam=ADAM)]
    layers += [
        attdecoderblock_layer(
            embed_dim,
            num_h[idx],
            adam=ADAM,
            float32=float32,
            return_attention=idx in attention_blocks,
        )
        for idx in range(num_layer)
    ]
    layers += [
        layer_norm(embed_dim, adam=ADAM),
        fclayer(embed_dim, vocab_size, True, adam=ADAM, float32=float32),
    ]

    if os.path.exists(pretrained_model):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point this at checkpoints you produced yourself.
        with open(pretrained_model, 'rb') as obj:
            models = pickle.load(obj)
        cnt = 0
        for layer in layers:
            # Saved states are stored in layer order, but only for layers
            # that implement both save_model and restore_model.
            if hasattr(layer, 'restore_model') and hasattr(layer, 'save_model'):
                layer.restore_model(models[cnt])
                cnt += 1
        del models
    else:
        # Previously this case was skipped silently, so the plot was produced
        # from layers that never had a checkpoint restored into them.
        warnings.warn(
            "pretrained model not found at %r; layers keep their initial "
            "(unrestored) weights" % (pretrained_model,)
        )

    token_ids = [char2id[ci] for ci in inputs]
    hidden = np.array([token_ids])

    # The future mask is built only for its shape and then zeroed, so adding
    # it inside the attention block leaves the scores unchanged.
    input_mask_fut = create_masks_future(hidden)
    input_mask_fut[...] = 0

    for layer in layers:
        if isinstance(layer, attdecoderblock_layer):
            hidden = layer.forward(hidden, input_mask_fut)
            if layer.return_attention:
                return hidden
        else:
            hidden = layer.forward(hidden)

    # Not reachable with the layer list built above (block 0 returns attention).
    raise RuntimeError("no decoder block was configured with return_attention=True")
99+
class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalization that pins ``vcenter`` to 0.5.

    ``vmin`` maps to 0, ``vcenter`` to 0.5 and ``vmax`` to 1, with linear
    interpolation in between.
    """

    def __init__(self, vmin=None, vmax=None, vcenter=None, clip=False):
        """Store ``vcenter`` and defer the rest to ``colors.Normalize``."""
        self.vcenter = vcenter
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)

    def __call__(self, value, clip=None):
        """Map data values to the normalized range; ``clip`` is ignored.

        Masked values and other edge cases are deliberately not handled.
        Inputs below ``vmin`` / above ``vmax`` come back as -inf / +inf.
        """
        data_points = [self.vmin, self.vcenter, self.vmax]
        unit_points = [0, 0.5, 1.]
        mapped = np.interp(value, data_points, unit_points,
                           left=-np.inf, right=np.inf)
        return np.ma.masked_array(mapped)

    def inverse(self, value):
        """Map normalized values back to data values.

        Inputs below 0 / above 1 come back as -inf / +inf.
        """
        data_points = [self.vmin, self.vcenter, self.vmax]
        unit_points = [0, 0.5, 1]
        return np.interp(value, unit_points, data_points,
                         left=-np.inf, right=np.inf)
115+
116+
def _draw_attention_row(axes_row, attention_maps, tick_text, cmap):
    """Draw one heat map per head into ``axes_row`` and return the images.

    Args:
        axes_row: Matplotlib axes, one per head.
        attention_maps: Per-head 2-D score matrices.
        tick_text: Characters used as tick labels on both axes.
        cmap: Colormap name passed to ``imshow``.

    Returns:
        List of the image objects created by ``imshow``, in head order.
    """
    tick_labels = list(tick_text)
    tick_positions = np.arange(len(tick_labels))
    images = []
    for axis, head_map in zip(axes_row, attention_maps):
        # Work on a copy so the arrays returned by predict() stay untouched.
        att = np.array(head_map)
        # Negative scores and the diagonal (self-attention) are zeroed
        # before plotting.
        att[att < 0] = 0
        np.fill_diagonal(att, 0)
        images.append(axis.imshow(att, cmap=cmap))
        axis.label_outer()
        # Show all ticks and label them with the input characters.
        axis.set_xticks(tick_positions, labels=tick_labels)
        axis.set_yticks(tick_positions, labels=tick_labels)
        axis.tick_params(top=True, bottom=False,
                         labeltop=True, labelbottom=False)
        axis.spines[:].set_visible(False)
        axis.grid(which="minor", color="w", linestyle='-', linewidth=3)
        axis.tick_params(which="minor", bottom=False, left=False)
    return images


def plotattention():
    """Plot the attention of the first decoder block for a sample verse.

    Top row: attention after softmax, one panel per head.
    Bottom row: scaled scores before softmax, one panel per head.
    A colorbar for the last bottom-row panel is attached on the right.
    """
    # A font with CJK glyphs is needed for the tick labels
    # (https://matplotlib.org/stable/users/explain/text/fonts.html).
    plt.rcParams['font.family'] = ['DengXian']

    # Other sample verses: 满树桃花映日开, 山高江水深, 独客长安觉月遥
    inputs = r'床前明月光'
    after_softmax, before_softmax = predict(inputs)
    # Only the first (and only) batch entry is plotted.
    after_softmax = after_softmax[0]
    before_softmax = before_softmax[0]

    # One column per head instead of a fixed 2x3 grid; squeeze=False keeps
    # ``ax`` two-dimensional even for a single head.
    num_heads = len(after_softmax)
    fig, ax = plt.subplots(2, num_heads, squeeze=False)
    cmap = "viridis"

    _draw_attention_row(ax[0], after_softmax, inputs, cmap)
    # Fix: the original called label_outer() on the top-row axes again while
    # drawing the bottom row, so the bottom row never had it applied.
    ax_col = _draw_attention_row(ax[1], before_softmax, inputs, cmap)

    # Attach the colorbar to the last panel (the current axes after subplots()).
    # NOTE(review): each panel has its own colour scale; this bar describes
    # only the last bottom-row image.
    divider = make_axes_locatable(ax[-1, -1])
    cax = divider.append_axes("right", "6%", pad="3%")
    plt.colorbar(ax_col[-1], ax=ax, cax=cax, orientation='vertical', fraction=.09, location="right")
    fig.tight_layout()
    plt.show()
    plt.close()
190+
191+
# k = np.array([[-0.01635286, -0.11992685, 0.55947667, 0.92379665, -0.51663721,
192+
# -0.15404126, -0.00184499],
193+
# [-0.2325137 , 0.07962042, -0.21066462, 0.26676813, 0.07328191,
194+
# 0.05249183, -0.0999634 ],
195+
# [-0.56667554, -0.40644667, -0.02308455, 0.33910039, -0.07462835,
196+
# 0.32599962, -0.07650095],
197+
# [-0.17541669, -0.11392358, 0.09849485, 0.29542559, -0.46458262,
198+
# -0.17263578, 0.33820251],
199+
# [-0.31038493, 0.1939393 , -0.33804718, 0.22234213, -0.46945375,
200+
# 0.07869679, 0.31734848],
201+
# [ 0.04049709, -0.38160264, -0.0219599 , -0.28630418, -0.98243105,
202+
# -0.64073545, -0.28534561],
203+
# [-0.6606887 , 0.15611888, -0.48505753, -0.24753423, -0.10505721,
204+
# -0.26085418, 0.90769702]])
205+
# k[k<0] = 0
206+
# for i in range(len(k)):
207+
# k[i, i] = 0
208+
# plt.imshow(k)
209+
# plt.show()
210+
# plt.close()
211+
212+
# Script entry point: render the attention figure when run directly.
if __name__ == "__main__":
    plotattention()

0 commit comments

Comments
 (0)