+// 2 mma + pipeline + simplify
+
 // A100 PCIE 80GB
+// Test performance using shape M=5376, N=5376, K=2048
+// Running cost of CUDA kernel is 1.21901ms
+// TFLOPS: 97.1117

 // 3090
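The TFLOPS figure is the usual 2*M*N*K GEMM operation count divided by the measured kernel time. A minimal host-side check (plain C++, independent of the kernel):

#include <cstdio>

int main()
{
    // A GEMM performs 2*M*N*K floating-point operations (one multiply + one add per MAC).
    const double M = 5376, N = 5376, K = 2048;
    const double seconds = 1.21901e-3; // measured A100 kernel time from the comment above
    printf("TFLOPS: %.4f\n", 2.0 * M * N * K / seconds / 1e12); // prints ~97.11
    return 0;
}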
@@ -114,19 +119,17 @@ __device__ void loadFragA(unsigned int *frag, half *smem, int ki)
     // load 64x16
     int tx = threadIdx.x;
     int tz = threadIdx.z;
+    int row = tz * 64 + tx / 4;
+    int col = ki * KII + tx % 4 * 2;
+    half *ptr = smem + row / 16 * (2 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16;
     for (int i = 0; i < 4; ++i)
     {
-        int row = tz * 64 + i * 16 + tx / 16 * 8 + tx % 8;
-        int col = ki * KII + tx / 8 % 2 * 8;
-        void *ptr = (void *)(smem + row / 16 * (2 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16);
-        uint32_t smem_ptr;
-        asm(
-            "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
-            : "=r"(smem_ptr)
-            : "l"(ptr));
-        asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
-                     : "=r"(frag[i * 4 + 0]), "=r"(frag[i * 4 + 1]), "=r"(frag[i * 4 + 2]), "=r"(frag[i * 4 + 3])
-                     : "r"(smem_ptr));
+        frag[i * 4 + 0] = *(reinterpret_cast<unsigned int *>(ptr));
+        frag[i * 4 + 1] = *(reinterpret_cast<unsigned int *>(ptr + 8));
+
+        frag[i * 4 + 2] = *(reinterpret_cast<unsigned int *>(ptr + 8 * 16));
+        frag[i * 4 + 3] = *(reinterpret_cast<unsigned int *>(ptr + 8 * 16 + 8));
+        ptr += 16 * 16 * 2;
     }
 }

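The rewritten loadFragA drops the ldmatrix inline PTX and reads each thread's fragment with four plain 32-bit loads, relying on the shared-memory layout already encoded above (row-major 16x16 sub-tiles, 2*16*16 halves per row of tiles); loadFragB below uses the same mapping with ty in place of tz. A small host-side sketch of the lane-to-element mapping within one 16x16 tile, mirroring the offsets ptr, ptr + 8, ptr + 8*16 and ptr + 8*16 + 8:

#include <cstdio>

// For one 16x16 half tile stored row-major (16*16 contiguous halves),
// print which elements each lane's four 32-bit loads cover, following the
// offsets used in loadFragA: ptr, ptr + 8, ptr + 8*16, ptr + 8*16 + 8.
int main()
{
    for (int lane = 0; lane < 32; ++lane)
    {
        int row = lane / 4;     // base row within the 16x16 tile
        int col = lane % 4 * 2; // base column (each 32-bit load grabs 2 halves)
        printf("lane %2d: (%2d,%2d) (%2d,%2d) (%2d,%2d) (%2d,%2d)\n",
               lane,
               row,     col,      // frag[i*4 + 0]
               row,     col + 8,  // frag[i*4 + 1]
               row + 8, col,      // frag[i*4 + 2]
               row + 8, col + 8); // frag[i*4 + 3]
    }
    return 0;
}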
@@ -136,19 +139,17 @@ __device__ void loadFragB(unsigned int *frag, half *smem, int ki)
     // load 64x16
     int tx = threadIdx.x;
     int ty = threadIdx.y;
+    int row = ty * 64 + tx / 4;
+    int col = ki * KII + tx % 4 * 2;
+    half *ptr = smem + row / 16 * (2 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16;
     for (int i = 0; i < 4; ++i)
     {
-        int row = ty * 64 + i * 16 + tx / 16 * 8 + tx % 8;
-        int col = ki * KII + tx / 8 % 2 * 8;
-        void *ptr = (void *)(smem + row / 16 * (2 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16);
-        uint32_t smem_ptr;
-        asm(
-            "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
-            : "=r"(smem_ptr)
-            : "l"(ptr));
-        asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
-                     : "=r"(frag[i * 4 + 0]), "=r"(frag[i * 4 + 1]), "=r"(frag[i * 4 + 2]), "=r"(frag[i * 4 + 3])
-                     : "r"(smem_ptr));
+        frag[i * 4 + 0] = *(reinterpret_cast<unsigned int *>(ptr));
+        frag[i * 4 + 1] = *(reinterpret_cast<unsigned int *>(ptr + 8));
+
+        frag[i * 4 + 2] = *(reinterpret_cast<unsigned int *>(ptr + 8 * 16));
+        frag[i * 4 + 3] = *(reinterpret_cast<unsigned int *>(ptr + 8 * 16 + 8));
+        ptr += 16 * 16 * 2;
     }
 }

@@ -159,22 +160,27 @@ __device__ void storeAccum(float *ptr, float *frag)
     int tx = threadIdx.x;
     int ty = threadIdx.y;
     int tz = threadIdx.z;
+    int row = tz * 64 + tx / 4;
+    int col = ty * 64 + tx % 4 * 2;
+    float *dst = ptr + row / 16 * (8 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16;
     for (int i = 0; i < 4; ++i)
     {
-        for (int j = 0; j < 4; ++j)
-        {
-            for (int r = 0; r < 2; ++r)
-            {
-                for (int c = 0; c < 2; ++c)
-                {
-                    int row = tz * 64 + i * 16 + r * 8 + tx / 4;
-                    int col = ty * 64 + j * 16 + c * 8 + tx % 4 * 2;
-                    float *dst = ptr + row / 16 * (8 * 16 * 16) + col / 16 * (16 * 16) + row % 16 * 16 + col % 16;
-                    dst[0] = frag[i * 32 + j * 8 + r * 4 + c * 2];
-                    dst[1] = frag[i * 32 + j * 8 + r * 4 + c * 2 + 1];
-                }
-            }
+        for (int j = 0; j < 4; ++j) {
+            dst[0] = frag[i * 32 + j * 8 + 0 * 4 + 0 * 2];
+            dst[1] = frag[i * 32 + j * 8 + 0 * 4 + 0 * 2 + 1];
+
+            dst[0 + 8] = frag[i * 32 + j * 8 + 0 * 4 + 1 * 2];
+            dst[1 + 8] = frag[i * 32 + j * 8 + 0 * 4 + 1 * 2 + 1];
+
+            dst[0 + 8 * 16] = frag[i * 32 + j * 8 + 1 * 4 + 0 * 2];
+            dst[1 + 8 * 16] = frag[i * 32 + j * 8 + 1 * 4 + 0 * 2 + 1];
+
+            dst[0 + 8 * 16 + 8] = frag[i * 32 + j * 8 + 1 * 4 + 1 * 2];
+            dst[1 + 8 * 16 + 8] = frag[i * 32 + j * 8 + 1 * 4 + 1 * 2 + 1];
+
+            dst += 16 * 16;
         }
+        dst += 4 * 16 * 16;
     }
 }

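storeAccum now walks a single dst pointer instead of recomputing the address per (i, j, r, c): within each 16x16 tile a thread's eight floats sit at the four 2x2 sub-tile corners (offsets 0, 8, 8*16, 8*16 + 8, two consecutive floats each), dst advances one tile per j step (16*16) and on to the next row of tiles per i step (the remaining 4*16*16, given the 8-tile-wide C buffer implied by the 8*16*16 row stride). A quick host-side check, with example warp coordinates ty = tz = 1, that the running pointer visits the same tile bases as the old per-iteration formula:

#include <cassert>
#include <cstdio>

int main()
{
    const int ty = 1, tz = 1; // example warp coordinates
    long running = (tz * 4) * (8 * 16 * 16) + (ty * 4) * (16 * 16);
    for (int i = 0; i < 4; ++i)
    {
        for (int j = 0; j < 4; ++j)
        {
            // Old code recomputed the tile base from (i, j) each time.
            long direct = (tz * 4 + i) * (8 * 16 * 16) + (ty * 4 + j) * (16 * 16);
            assert(running == direct);
            running += 16 * 16;      // advance one tile to the right
        }
        running += 4 * 16 * 16;      // skip to the next row of tiles
    }
    printf("tile walk matches\n");
    return 0;
}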
@@ -221,9 +227,9 @@ __global__ void matmul(half *A, half *B, half *C, int M, int N, int K)
     half *SB4 = SB3 + NI * KI;
     float *SC = reinterpret_cast<float *>(shared_storage);

-    unsigned int FragA[16];
-    unsigned int FragB[16];
-    float Accum[128] = {0.0};
+    unsigned int FragA[4 * 4];      // [4, 4]
+    unsigned int FragB[4 * 4];      // [4, 4]
+    float Accum[4 * 4 * 8] = {0.0}; // [4, 4, 8]

     // prologue
     loadSmemA(SA1, A, M, K, 0);
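The rewritten sizes only spell out the existing register budget: per warp, the 64x16 A and B slabs are held as 4 tiles of 16x16 halves (4 x 32-bit registers, i.e. 8 halves, per thread per tile) and the 64x64 accumulator as 4x4 tiles of 8 floats per thread. A trivial consistency check:

#include <cstdio>

int main()
{
    const int warp_size = 32;
    // One 16x16 half tile: 256 halves / 32 threads = 8 halves = 4 x 32-bit registers per thread.
    const int regs_per_half_tile = 16 * 16 / warp_size / 2;
    // One 16x16 float tile: 256 floats / 32 threads = 8 floats per thread.
    const int floats_per_accum_tile = 16 * 16 / warp_size;

    printf("FragA / FragB: 4 tiles x %d regs = %d unsigned ints\n",
           regs_per_half_tile, 4 * regs_per_half_tile);           // 4 * 4 = 16
    printf("Accum: 4 x 4 tiles x %d floats = %d floats\n",
           floats_per_accum_tile, 4 * 4 * floats_per_accum_tile); // 128
    return 0;
}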