@@ -200,22 +200,72 @@ __global__ void __launch_bounds__(128) gemm_forward_4bit_cuda_m128n64k32(int spl
200200
201201 for (int i_0_3 = 0 ; i_0_3 < 4 ; ++i_0_3) {
202202 for (int j_0_4 = 0 ; j_0_4 < 2 ; ++j_0_4) {
203+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
204+ {
205+ __asm__ __volatile__ (
206+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
207+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
208+ : " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
209+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]),
210+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]),
211+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
212+ );
213+ }
214+
215+ {
216+ __asm__ __volatile__ (
217+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
218+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
219+ : " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
220+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
221+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]),
222+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
223+ );
224+ }
203225
226+ {
227+ __asm__ __volatile__ (
228+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
229+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
230+ : " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
231+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]),
232+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]),
233+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
234+ );
235+ }
236+ {
237+ __asm__ __volatile__ (
238+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
239+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
240+ : " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
241+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
242+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]),
243+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
244+ );
245+ }
246+ #else
204247 {
205248 __asm__ __volatile__ (
206249 " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
207250 " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n "
208251 : " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
209- : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ]));
252+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
253+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]),
254+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
255+ );
210256 }
211257
212258 {
213259 __asm__ __volatile__ (
214260 " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
215261 " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n "
216262 : " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
217- : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ]));
263+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
264+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]),
265+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
266+ );
218267 }
268+ #endif
219269 }
220270 }
221271 }
0 commit comments