@@ -105,6 +105,7 @@ __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0
105105 : " r" (addr));
106106}
107107
108+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
108109__inline__ __device__ void cp_async_cg_A (uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
109110{
110111 const int cp_size = 16 ;
@@ -117,14 +118,37 @@ __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__r
117118 " l" (src),
118119 " n" (cp_size));
119120}
121+ #endif
120122
// Warp-cooperative tensor-core MMA tile: C(16x8, f32) += A(16x16, f16) * B(16x8, f16).
// Fragments are the standard PTX per-lane register layouts:
//   C_warp        -> 4 x f32 accumulator registers
//   A_shared_warp -> 4 x b32 registers (8 half values)
//   B_shared_warp -> 2 x b32 registers (4 half values)
// On SM80+ this issues a single mma.m16n8k16. Before SM80 (Turing, SM75) that
// shape is unavailable, so the same product is formed from two mma.m16n8k8 ops
// split along k: A regs {0,1} with B reg 0 cover k = 0..7, and A regs {2,3}
// with B reg 1 cover k = 8..15, both accumulating into the same C fragment.
__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    // Turing fallback: first k-half (k = 0..7).
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]),
          "r"(((unsigned *)B_shared_warp)[0]),
          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
    // Turing fallback: second k-half (k = 8..15), accumulated on top of the first.
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
#else
    // Ampere+ path: single m16n8k16 instruction (requires SM80).
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
#endif
}
129153
130154template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
@@ -148,12 +172,14 @@ __device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int
148172 int ld_col_swizzled = (ld_col ^ (ld_row) & 7 ) * PACK_SIZE;
149173 void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
150174 uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
175+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
151176 if constexpr (STAGES > 1 )
152177 {
153178 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
154179 cp_async_cg_A (addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
155180 }
156181 else
182+ #endif
157183 {
158184 if (local_mask & (ld_row + cta_offset_m < global_nrows))
159185 *(uint4 *)dst_ptr = *src_ptr;
@@ -183,12 +209,14 @@ __device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int
183209 int ld_col_swizzled = ld_col ^ (ld_row % 2 ) & 7 ;
184210 void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
185211 uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
212+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
186213 if constexpr (STAGES > 1 )
187214 {
188215 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
189216 cp_async_cg_A (addr, src_ptr, local_mask);
190217 }
191218 else
219+ #endif
192220 {
193221 if (local_mask)
194222 *(uint4 *)dst_ptr = *src_ptr;
@@ -212,6 +240,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
212240 uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx .x / threads_per_row) * global_ncols + (threadIdx .x % threads_per_row) * PACK_SIZE);
213241 void *dst_ptr_z = (void *)(dst_z + (threadIdx .x / threads_per_row) * kSmemCol + (threadIdx .x % threads_per_row) * PACK_SIZE);
214242 uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx .x / threads_per_row) * global_ncols + (threadIdx .x % threads_per_row) * PACK_SIZE);
243+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
215244 if (STAGES > 1 )
216245 {
217246 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
@@ -220,6 +249,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
220249 cp_async_cg_A (addr_z, src_ptr_z, local_mask);
221250 }
222251 else
252+ #endif
223253 {
224254 if (local_mask)
225255 {
@@ -606,12 +636,14 @@ __device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst,
606636 int ld_col_swizzled = (ld_col ^ (ld_row) & 7 ) * PACK_SIZE;
607637 void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
608638 uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
639+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
609640 if constexpr (STAGES > 1 )
610641 {
611642 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
612643 cp_async_cg_A (addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
613644 }
614645 else
646+ #endif
615647 {
616648 if (local_mask & (ld_row + cta_offset_m < global_nrows))
617649 *(uint4 *)dst_ptr = *src_ptr;
@@ -641,12 +673,14 @@ __device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst,
641673 int ld_col_swizzled = ld_col ^ (ld_row % 2 ) & 7 ;
642674 void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
643675 uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
676+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
644677 if constexpr (STAGES > 1 )
645678 {
646679 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
647680 cp_async_cg_A (addr, src_ptr, local_mask);
648681 }
649682 else
683+ #endif
650684 {
651685 if (local_mask)
652686 *(uint4 *)dst_ptr = *src_ptr;
@@ -669,6 +703,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
669703 uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx .x % threads_per_row) * PACK_SIZE);
670704 void *dst_ptr_z = (void *)(dst_z + (threadIdx .x % threads_per_row) * PACK_SIZE);
671705 uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx .x % threads_per_row) * PACK_SIZE);
706+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
672707 if (STAGES > 1 )
673708 {
674709 uint32_t addr = cast_smem_ptr_to_uint (dst_ptr);
@@ -677,6 +712,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
677712 cp_async_cg_A (addr_z, src_ptr_z, local_mask);
678713 }
679714 else
715+ #endif
680716 {
681717 if (local_mask)
682718 {
0 commit comments