+ import os
+ import sys
+ import pickle
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from copy import deepcopy
+ from matplotlib import colors
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+ # Make the repository root importable when this file is run directly.
+ abspath = os.path.abspath(__file__)
+ filename = os.sep.join(abspath.split(os.sep)[-2:])
+ abspath = abspath.replace(filename, "")
+ sys.path.append(abspath)
+
+ from net.layernorm import layer_norm
+ from net.fullconnect import fclayer
+ from PatchEmbed import Position_Embedding
+ from attdecoderblock import attdecoderblock_layer
+ from gpt.gpt_train_potry3000 import getdata, create_masks_future
+ from classify import classify_layer  # only needed for the commented-out classify head below
+
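+ # Visualize the self-attention of the NumPy GPT trained on the poetry-3000
+ # corpus: rebuild the decoder stack, restore the pickled weights, run a
+ # forward pass on a short prompt, and plot the attention matrices returned
+ # by the first block constructed with return_attention=True.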
+ def predict(inputs):
+     pretrained_model = r'C:\Users\10696\Desktop\Numpy\numpy_transformer\gpt\model\gpt_poetry3000_iters1999_1_loss_3259.634242.pkl'
+     vocab_size, id2char, char2id, input_texts = getdata()
+
+     # Training-time values, kept for reference; only the architecture
+     # settings below must match the checkpoint being restored.
+     all_steps = 2000
+     batchsize = 64
+     learning_rate = 0.003
+     embed_dim = 192    # vocab_size if vocab_size%3==0 else (vocab_size//3) * 3 + 3
+     num_layer = 12
+     num_h = [3] * num_layer    # 3 attention heads per decoder block
+     context_length = 100
+
+     ADAM = False
+     cls_token = True
+     float32 = True
+
+     patchemb = Position_Embedding(context_length, vocab_size, embed_dim, adam=ADAM)
+     layers = [patchemb]
+
+     # at0 and at3 are built with return_attention=True; their forward()
+     # hands back attention maps instead of the block output.
+     at0 = attdecoderblock_layer(embed_dim, num_h[0], adam=ADAM, float32=float32, return_attention=True)
+     at1 = attdecoderblock_layer(embed_dim, num_h[1], adam=ADAM, float32=float32)
+     at2 = attdecoderblock_layer(embed_dim, num_h[2], adam=ADAM, float32=float32)
+     at3 = attdecoderblock_layer(embed_dim, num_h[3], adam=ADAM, float32=float32, return_attention=True)
+     at4 = attdecoderblock_layer(embed_dim, num_h[4], adam=ADAM, float32=float32)
+     at5 = attdecoderblock_layer(embed_dim, num_h[5], adam=ADAM, float32=float32)
+     at6 = attdecoderblock_layer(embed_dim, num_h[6], adam=ADAM, float32=float32)
+     at7 = attdecoderblock_layer(embed_dim, num_h[7], adam=ADAM, float32=float32)
+     at8 = attdecoderblock_layer(embed_dim, num_h[8], adam=ADAM, float32=float32)
+     at9 = attdecoderblock_layer(embed_dim, num_h[9], adam=ADAM, float32=float32)
+     at10 = attdecoderblock_layer(embed_dim, num_h[10], adam=ADAM, float32=float32)
+     at11 = attdecoderblock_layer(embed_dim, num_h[11], adam=ADAM, float32=float32)
+
+     layers += [at0, at1, at2, at3, at4, at5, at6, at7, at8, at9, at10, at11]
+
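+     # Full stack: token/position embedding -> 12 decoder blocks -> layer
+     # norm -> linear projection to vocabulary logits.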
+     norm = layer_norm(embed_dim, adam=ADAM)
+     # Without a cls_token, a classify_layer head could be used instead:
+     # cll = classify_layer(embed_dim, batchsize, 1, vocab_size, cls_token, adam=ADAM, relu=False, float32=float32)
+     cll = fclayer(embed_dim, vocab_size, True, adam=ADAM, float32=float32)
+     layers += [norm, cll]
+
+     # The pickle holds one state entry per trainable layer, in layer order.
+     if os.path.exists(pretrained_model):
+         with open(pretrained_model, 'rb') as obj:
+             models = pickle.load(obj)
+         cnt = 0
+         for l in layers:
+             k = dir(l)
+             if 'restore_model' in k and 'save_model' in k:
+                 l.restore_model(models[cnt])
+                 cnt += 1
+         del models
+
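+     # Tokenize the prompt and run the stack. The loop is written like the
+     # autoregressive sampler in gpt_train_potry3000, but here it exits on
+     # the first pass: at0 has return_attention=True, so its forward() is
+     # assumed to return the (after-softmax, before-softmax) attention maps
+     # that plotattention() unpacks below.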
+     inputs = [char2id[ci] for ci in inputs]
+     inputs = np.array([inputs])
+     output = deepcopy(inputs)
+     for ij in range(context_length - 1):
+         text = deepcopy(inputs)
+         input_mask_fut = create_masks_future(inputs)
+         input_mask_fut[...] = 0    # disable the causal mask so every token attends to every other
+         for l in range(len(layers)):
+             if isinstance(layers[l], attdecoderblock_layer):
+                 inputs = layers[l].forward(inputs, input_mask_fut)
+                 if layers[l].return_attention:
+                     return inputs
+             else:
+                 inputs = layers[l].forward(inputs)
+ class MidpointNormalize(colors.Normalize):
+     def __init__(self, vmin=None, vmax=None, vcenter=None, clip=False):
+         self.vcenter = vcenter
+         super().__init__(vmin, vmax, clip)
+
+     def __call__(self, value, clip=None):
+         # Masked values and edge cases are ignored to keep the example
+         # simple; note that we must extrapolate beyond vmin/vmax.
+         x, y = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1.]
+         return np.ma.masked_array(np.interp(value, x, y,
+                                             left=-np.inf, right=np.inf))
+
+     def inverse(self, value):
+         y, x = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1]
+         return np.interp(value, x, y, left=-np.inf, right=np.inf)
+
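+ # MidpointNormalize follows the example in matplotlib's colormap-norms
+ # tutorial; it is kept here for experiments but is not applied to the
+ # heatmaps below.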
+ def plotattention():
+     # https://matplotlib.org/stable/users/explain/text/fonts.html
+     # The tick labels are Chinese characters, so pick a font with CJK
+     # glyphs; DengXian ships with Windows, SimHei is a common alternative.
+     plt.rcParams['font.family'] = ['DengXian']
+     # plt.rcParams['font.family'] = ['SimHei']
+     # Heatmap styling follows:
+     # https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py
+     inputs = r'床前明月光'    # "moonlight before my bed" (Li Bai)
+     # other prompts to try: 满树桃花映日开  山高江水深  独客长安觉月遥
+     after_softmax, before_softmax = predict(inputs)
+     after_softmax = after_softmax[0]      # drop the batch dimension
+     before_softmax = before_softmax[0]
+     fig, ax = plt.subplots(2, 3)
+     cmap = "viridis"    # "PuOr"
+     ax_col = []
+     # Top row: the attention maps after softmax, one panel per head.
+     for i in range(len(after_softmax)):
+         att = np.array(after_softmax[i])
+         att[att < 0] = 0                  # clip any negative scores
+         for j in range(len(att)):
+             att[j, j] = 0                 # zero the diagonal so off-diagonal structure stands out
+         ax[i // 3, i % 3].imshow(att, cmap=cmap)
+         ax[i // 3, i % 3].label_outer()
+         # Show all ticks and label them with the prompt characters.
+         ax[i // 3, i % 3].set_xticks(np.arange(len(inputs)), labels=inputs)
+         ax[i // 3, i % 3].set_yticks(np.arange(len(inputs)), labels=inputs)
+         ax[i // 3, i % 3].tick_params(top=True, bottom=False,
+                                       labeltop=True, labelbottom=False)
+         ax[i // 3, i % 3].spines[:].set_visible(False)
+         # ax[i // 3, i % 3].set_xticks(np.arange(len(inputs) + 1) - .5, minor=True)
+         # ax[i // 3, i % 3].set_yticks(np.arange(len(inputs) + 1) - .5, minor=True)
+         ax[i // 3, i % 3].grid(which="minor", color="w", linestyle='-', linewidth=3)
+         ax[i // 3, i % 3].tick_params(which="minor", bottom=False, left=False)
+
+     # Bottom row: the same heads before softmax.
+     for i in range(len(before_softmax)):
+         att = np.array(before_softmax[i])
+         att[att < 0] = 0
+         for j in range(len(att)):
+             att[j, j] = 0
+         ax_col.append(ax[(i + 3) // 3, (i + 3) % 3].imshow(att, cmap=cmap))
+         ax[(i + 3) // 3, (i + 3) % 3].label_outer()
+         # Show all ticks and label them with the prompt characters.
+         ax[(i + 3) // 3, (i + 3) % 3].set_xticks(np.arange(len(inputs)), labels=inputs)
+         ax[(i + 3) // 3, (i + 3) % 3].set_yticks(np.arange(len(inputs)), labels=inputs)
+         ax[(i + 3) // 3, (i + 3) % 3].tick_params(top=True, bottom=False,
+                                                   labeltop=True, labelbottom=False)
+         ax[(i + 3) // 3, (i + 3) % 3].spines[:].set_visible(False)
+         # ax[(i + 3) // 3, (i + 3) % 3].set_xticks(np.arange(len(inputs) + 1) - .5, minor=True)
+         # ax[(i + 3) // 3, (i + 3) % 3].set_yticks(np.arange(len(inputs) + 1) - .5, minor=True)
+         ax[(i + 3) // 3, (i + 3) % 3].grid(which="minor", color="w", linestyle='-', linewidth=3)
+         ax[(i + 3) // 3, (i + 3) % 3].tick_params(which="minor", bottom=False, left=False)
+
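+     # All six panels share the single colorbar added below, taken from the
+     # last image; they are not renormalized to a common scale (the
+     # commented block sketches how that could be done).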
+     # To put every panel on one color scale, normalize the images first:
+     # vmin = min(image.get_array().min() for image in images)
+     # vmax = max(image.get_array().max() for image in images)
+     # norm = colors.Normalize(vmin=vmin, vmax=vmax)
+     # for im in images:
+     #     im.set_norm(norm)
+     divider = make_axes_locatable(plt.gca())
+     cax = divider.append_axes("right", "6%", pad="3%")
+     plt.colorbar(ax_col[-1], cax=cax, orientation='vertical')
+     fig.tight_layout()
+     plt.show()
+     plt.close()
+
+ # Quick check of the heatmap styling with a hand-made 7x7 score matrix:
+ # k = np.random.randn(7, 7)
+ # k[k < 0] = 0
+ # for i in range(len(k)):
+ #     k[i, i] = 0
+ # plt.imshow(k)
+ # plt.show()
+ # plt.close()
+
+ if __name__ == "__main__":
+     plotattention()
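+ # Running this file directly pops up the 2x3 grid of attention heatmaps
+ # with the prompt characters on both axes. If the checkpoint path at the
+ # top of predict() does not exist, the plots show untrained weights.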