3
3
import sys
4
4
import os
5
5
import numpy as np
6
+ import math
6
7
7
8
# the next line can be removed after installation
8
9
sys .path .insert (0 , os .path .dirname (os .path .dirname (
12
13
import veriloggen .thread as vthread
13
14
import veriloggen .types .axi as axi
14
15
15
# Width in bits of one RAM word (passed as the data width of the RAMs below).
mem_datawidth = 8
# Width in bits of one matrix element; presumably a multiple of
# mem_datawidth -- TODO confirm for non-default settings.
datawidth = 16
addrwidth = 8

matrix_size = 10

# Number of RAM words that make up one matrix element.
num_pack = math.ceil(datawidth / mem_datawidth)
# Number of RAM words that make up one size/offset parameter: the address
# width is extended by ceil(log2(datawidth / mem_datawidth)) bits, and the
# total is rounded up to whole RAM words.
addr_pack = math.ceil((addrwidth + math.ceil(np.log2(datawidth / mem_datawidth)))
                      / mem_datawidth)

# RAM addresses where the matrix size and the A/B/C offsets are stored.
matrix_size_addr = 0
a_offset_addr = 4
b_offset_addr = 8
c_offset_addr = 12

# Matrix base addresses: each matrix occupies
# matrix_size * matrix_size * num_pack RAM words, laid out back to back.
a_offset = 16
b_offset = a_offset + matrix_size * matrix_size * num_pack
c_offset = b_offset + matrix_size * matrix_size * num_pack
22
33
23
34
24
35
def mkLed ():
@@ -28,7 +39,8 @@ def mkLed():
28
39
start = m .Input ('start' )
29
40
busy = m .OutputReg ('busy' , initval = 0 )
30
41
31
- ram = vthread .ExtRAM (m , 'ram' , clk , rst , datawidth , addrwidth )
42
+ ram = vthread .ExtRAM (m , 'ram' , clk , rst , mem_datawidth ,
43
+ addrwidth + math .ceil (np .log2 (datawidth / mem_datawidth )))
32
44
33
45
def matmul ():
34
46
while True :
@@ -46,27 +58,31 @@ def wait():
46
58
busy .value = 1
47
59
48
60
def read_matrix_size():
    # Assemble the matrix size from addr_pack consecutive RAM words starting
    # at matrix_size_addr. Word i is masked to mem_datawidth bits and placed
    # at bit position mem_datawidth * i (least-significant word first).
    size = 0
    for i in range(addr_pack):
        word = ram.read(matrix_size_addr + i) & ((1 << mem_datawidth) - 1)
        size |= word << (mem_datawidth * i)
    return size
53
66
54
67
def read_matrix_a_offset():
    # Assemble the offset of matrix A from addr_pack consecutive RAM words
    # starting at a_offset_addr. Word i is masked to mem_datawidth bits and
    # placed at bit position mem_datawidth * i (least-significant word first).
    offset = 0
    for i in range(addr_pack):
        word = ram.read(a_offset_addr + i) & ((1 << mem_datawidth) - 1)
        offset |= word << (mem_datawidth * i)
    return offset
59
73
60
74
def read_matrix_b_offset():
    # Assemble the offset of matrix B from addr_pack consecutive RAM words
    # starting at b_offset_addr. Word i is masked to mem_datawidth bits and
    # placed at bit position mem_datawidth * i (least-significant word first).
    offset = 0
    for i in range(addr_pack):
        word = ram.read(b_offset_addr + i) & ((1 << mem_datawidth) - 1)
        offset |= word << (mem_datawidth * i)
    return offset
65
80
66
81
def read_matrix_c_offset():
    # Assemble the offset of matrix C from addr_pack consecutive RAM words
    # starting at c_offset_addr. Word i is masked to mem_datawidth bits and
    # placed at bit position mem_datawidth * i (least-significant word first).
    offset = 0
    for i in range(addr_pack):
        word = ram.read(c_offset_addr + i) & ((1 << mem_datawidth) - 1)
        offset |= word << (mem_datawidth * i)
    return offset
71
87
72
88
def comp (matrix_size , a_offset , b_offset , c_offset ):
@@ -77,15 +93,24 @@ def comp(matrix_size, a_offset, b_offset, c_offset):
77
93
for j in range (matrix_size ):
78
94
sum = 0
79
95
for k in range(matrix_size):
    # Rebuild one operand from each input row: every element spans num_pack
    # RAM words, least-significant word first, each masked to mem_datawidth
    # bits before being shifted into place.
    # Both accumulators start from a plain 0; int(0, base=2) is rejected by
    # Python because an explicit base requires a string argument.
    x = 0
    y = 0
    for l in range(num_pack):
        x |= ((ram.read(a_addr + k * num_pack + l)
               & ((1 << mem_datawidth) - 1))
              << (mem_datawidth * l))
        y |= ((ram.read(b_addr + k * num_pack + l)
               & ((1 << mem_datawidth) - 1))
              << (mem_datawidth * l))
    sum += x * y
83
- ram .write (c_addr + j , sum )
106
+ for l in range (num_pack ):
107
+ ram .write (c_addr + j * num_pack + l ,
108
+ (sum >> (mem_datawidth * l )) & (1 << mem_datawidth )- 1 )
84
109
85
- b_addr += matrix_size * ( datawidth // 8 )
110
+ b_addr += matrix_size * num_pack
86
111
87
- a_addr += matrix_size * ( datawidth // 8 )
88
- c_addr += matrix_size * ( datawidth // 8 )
112
+ a_addr += matrix_size * num_pack
113
+ c_addr += matrix_size * num_pack
89
114
90
115
def done ():
91
116
busy .value = 0
@@ -128,13 +153,11 @@ def mkTest(memimg_name=None):
128
153
b [y ][x ] = 0
129
154
130
155
a_addr = a_offset
131
- size_a = n_a * datawidth // 8
132
156
b_addr = b_offset
133
- size_b = n_b * datawidth // 8
134
157
135
- mem = np .zeros ([2 ** addrwidth * ( 8 // datawidth ) ], dtype = np .int64 )
136
- axi .set_memory (mem , a , datawidth , datawidth , a_addr )
137
- axi .set_memory (mem , b , datawidth , datawidth , b_addr )
158
+ mem = np .zeros ([( 2 ** addrwidth ) * num_pack ], dtype = np .int64 )
159
+ axi .set_memory (mem , a , mem_datawidth , datawidth , a_addr )
160
+ axi .set_memory (mem , b , mem_datawidth , datawidth , b_addr )
138
161
139
162
led = mkLed ()
140
163
@@ -149,7 +172,8 @@ def mkTest(memimg_name=None):
149
172
150
173
start .initval = 0
151
174
152
- memory = vthread .RAM (m , 'memory' , clk , rst , datawidth , addrwidth ,
175
+ memory = vthread .RAM (m , 'memory' , clk , rst , mem_datawidth ,
176
+ addrwidth + math .ceil (np .log2 (datawidth / mem_datawidth )),
153
177
numports = 2 , initvals = mem .tolist ())
154
178
memory .connect_rtl (0 , ports ['ram_0_addr' ], ports ['ram_0_wdata' ],
155
179
ports ['ram_0_wenable' ], ports ['ram_0_rdata' ],
@@ -166,45 +190,33 @@ def ctrl():
166
190
for i in range (100 ):
167
191
pass
168
192
169
- awaddr = 0
170
- v = (matrix_size & 0xff )
171
- print ('# matrix_size[7:0] = %d' % v )
172
- memory .write (awaddr , v , port = 1 )
173
-
174
- awaddr = 1
175
- v = ((matrix_size >> 8 ) & 0xff )
176
- print ('# matrix_size[15:8] = %d' % v )
177
- memory .write (awaddr , v , port = 1 )
178
-
179
- awaddr = 4
180
- v = (a_offset & 0xff )
181
- print ('# a_offset[7:0] = %d' % v )
182
- memory .write (awaddr , v , port = 1 )
183
-
184
- awaddr = 5
185
- v = ((a_offset >> 8 ) & 0xff )
186
- print ('# a_offset[15:8] = %d' % v )
187
- memory .write (awaddr , v , port = 1 )
188
-
189
- awaddr = 8
190
- v = (b_offset & 0xff )
191
- print ('# b_offset[7:0] = %d' % v )
192
- memory .write (awaddr , v , port = 1 )
193
-
194
- awaddr = 9
195
- v = ((b_offset >> 8 ) & 0xff )
196
- print ('# b_offset[15:8] = %d' % v )
197
- memory .write (awaddr , v , port = 1 )
198
-
199
- awaddr = 12
200
- v = (c_offset & 0xff )
201
- print ('# c_offset[7:0] = %d' % v )
202
- memory .write (awaddr , v , port = 1 )
203
-
204
- awaddr = 13
205
- v = ((c_offset >> 8 ) & 0xff )
206
- print ('# c_offset[15:8] = %d' % v )
207
- memory .write (awaddr , v , port = 1 )
193
+ for i in range (addr_pack ):
194
+ awaddr = matrix_size_addr + i
195
+ v = (matrix_size >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
196
+ print ('# matrix_size[%d:%d] = %d' %
197
+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
198
+ memory .write (awaddr , v , port = 1 )
199
+
200
+ for i in range (addr_pack ):
201
+ awaddr = a_offset_addr + i
202
+ v = (a_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
203
+ print ('# a_offset[%d:%d] = %d' %
204
+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
205
+ memory .write (awaddr , v , port = 1 )
206
+
207
+ for i in range (addr_pack ):
208
+ awaddr = b_offset_addr + i
209
+ v = (b_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
210
+ print ('# b_offset[%d:%d] = %d' %
211
+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
212
+ memory .write (awaddr , v , port = 1 )
213
+
214
+ for i in range (addr_pack ):
215
+ awaddr = c_offset_addr + i
216
+ v = (c_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
217
+ print ('# c_offset[%d:%d] = %d' %
218
+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
219
+ memory .write (awaddr , v , port = 1 )
208
220
209
221
start_time = counter
210
222
print ('# start time = %d' % start_time )
@@ -227,14 +239,19 @@ def ctrl():
227
239
all_ok = True
228
240
for y in range(matrix_size):
    for x in range(matrix_size):
        # Reassemble element (y, x) of the result from its num_pack RAM
        # words (least-significant word first). Each word is read once and
        # masked to mem_datawidth bits before being shifted into place.
        v = 0
        v_addr = c_offset + (y * matrix_size + x) * num_pack
        for l in range(num_pack):
            v |= ((memory.read(v_addr + l, port=1)
                   & ((1 << mem_datawidth) - 1))
                  << (mem_datawidth * l))
        # Diagonal entries are expected to equal (y + 1) * 2, all others 0.
        if y == x and vthread.verilog.NotEql(v, (y + 1) * 2):
            all_ok = False
            print("NG [%d,%d] = %d (expected: %d) " % (y, x, v, (y + 1) * 2))
        if y != x and vthread.verilog.NotEql(v, 0):
            all_ok = False
            print("NG [%d,%d] = %d (expected: %d) " % (y, x, v, 0))
238
255
239
256
if all_ok :
240
257
print ('# verify: PASSED' )
0 commit comments