@@ -1933,7 +1933,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19331933// rowStats [rows]
19341934// out [rows, cols]
19351935template <typename T, int THREADS, int SPARSE_DECOMP>
1936- __launch_bounds__ (1024 , BNB_MAX_THREADS_PER_SM / 1024 )
1936+ __launch_bounds__ (1024 , BNB_MAX_THREADS_PER_CU / 1024 )
19371937__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t * out, float * rowStats, float threshold, int rows, int cols) {
19381938
19391939 // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
@@ -1997,7 +1997,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
19971997}
19981998
19991999template <typename T, int THREADS, int SPARSE_DECOMP>
2000- __launch_bounds__ (1024 , BNB_MAX_THREADS_PER_SM / 1024 )
2000+ __launch_bounds__ (1024 , BNB_MAX_THREADS_PER_CU / 1024 )
20012001__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
20022002 using BlockReduceT = hipcub::BlockReduce<float , THREADS>;
20032003
@@ -2109,7 +2109,6 @@ __global__ void kdequant_mm_int32_fp16(
21092109#define DENORM 1 .0f /127 .0f
21102110#define MAX_SPARSE_COUNT 32
21112111#define SMEM_SIZE 8 *256
2112- #define WARP_SIZE warpSize
21132112template <typename T, int SPMM_ITEMS, int BITS>
21142113__global__ void kspmm_coo_very_sparse_naive (int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
21152114{
@@ -2130,9 +2129,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
21302129 const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1 ];
21312130 const int local_row_idx = rowidx[offset];
21322131
2133- const int warp_id = threadIdx.x / WARP_SIZE ;
2134- const int warp_idx = threadIdx.x % WARP_SIZE ;
2135- const int warp_offset = (warp_id*WARP_SIZE )*SPMM_ITEMS;
2132+ const int warp_id = threadIdx.x / BNB_WARP_SIZE ;
2133+ const int warp_idx = threadIdx.x % BNB_WARP_SIZE ;
2134+ const int warp_offset = (warp_id*BNB_WARP_SIZE )*SPMM_ITEMS;
21362135 const int num_items = BITS == 8 ? 8 : 8 ;
21372136 int idx_col_B = warp_offset;
21382137 int local_idx_col_B_offset = 0 ;
@@ -2152,7 +2151,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
21522151 }
21532152
21542153 // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
2155- // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
2154+ // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
21562155 // we have a total of 128 bytes for the bank with a bank size of 4 bytes
21572156 // added 3 bytes = 6 values between warps should reduce bank conflicts
21582157 __shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -2705,16 +2704,16 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27052704{
27062705
27072706 // per threadblock:
2708- // load step-by-step in chunks of [warp_size ,warps]: 1xwarp_size * [warp_size ,warps] -> [1,warps]
2707+ // load step-by-step in chunks of [BNB_WARP_SIZE ,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE ,warps] -> [1,warps]
27092708 // 4 warps -> 4 loads per iter
2710- // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
2711- typedef hipcub::WarpReduce<float , warpSize > WarpReduce;
2712- __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize ];
2713-
2714- const int warp_idx = threadIdx.x / warpSize ;
2715- const int warp_lane = threadIdx.x % warpSize ;
2716- const int row_B = (THREADS/warpSize )*blockIdx.x + warp_idx;
2717- const int offset_B = ldb* row_B;
2709+ // 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block
2710+ typedef hipcub::WarpReduce<float , BNB_WARP_SIZE > WarpReduce;
2711+ __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE ];
2712+
2713+ const int warp_idx = threadIdx.x / BNB_WARP_SIZE ;
2714+ const int warp_lane = threadIdx.x % BNB_WARP_SIZE ;
2715+ const int row_B = (THREADS/BNB_WARP_SIZE )*blockIdx.x + warp_idx;
2716+ const int offset_B = ldb * row_B;
27182717 const int num_values_8bit = num_values_4bit/2 ;
27192718 float local_C = 0 .0f ;
27202719
@@ -2732,7 +2731,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27322731
27332732 // A: [1, K]
27342733 // B: [M, K]
2735- for (int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize *num_values_4bit)
2734+ for (int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE *num_values_4bit)
27362735 {
27372736 const int inner_idx_halved = inner_idx/2 ;
27382737
@@ -3044,7 +3043,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
30443043MAKE_kQuantizeBlockwise(half, 512 , 2 , 0 , General8bit)
30453044MAKE_kQuantizeBlockwise(half, 256 , 2 , 0 , General8bit)
30463045MAKE_kQuantizeBlockwise(half, 128 , 2 , 0 , General8bit)
3047- #if WARP_SIZE == 32
3046+ #if BNB_WARP_SIZE == 32
30483047 MAKE_kQuantizeBlockwise (half, 64 , 2 , 0 , General8bit)
30493048#endif
30503049
@@ -3054,7 +3053,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
30543053MAKE_kQuantizeBlockwise(half, 512 , 2 , 0 , FP4)
30553054MAKE_kQuantizeBlockwise(half, 256 , 2 , 0 , FP4)
30563055MAKE_kQuantizeBlockwise(half, 128 , 2 , 0 , FP4)
3057- #if WARP_SIZE == 32
3056+ #if BNB_WARP_SIZE == 32
30583057 MAKE_kQuantizeBlockwise (half, 64 , 2 , 0 , FP4)
30593058#endif
30603059
@@ -3064,7 +3063,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
30643063MAKE_kQuantizeBlockwise(half, 512 , 2 , 0 , NF4)
30653064MAKE_kQuantizeBlockwise(half, 256 , 2 , 0 , NF4)
30663065MAKE_kQuantizeBlockwise(half, 128 , 2 , 0 , NF4)
3067- #if WARP_SIZE == 32
3066+ #if BNB_WARP_SIZE == 32
30683067 MAKE_kQuantizeBlockwise (half, 64 , 2 , 0 , NF4)
30693068#endif
30703069
@@ -3075,7 +3074,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
30753074MAKE_kQuantizeBlockwise(float , 512 , 2 , 0 , General8bit)
30763075MAKE_kQuantizeBlockwise(float , 256 , 2 , 0 , General8bit)
30773076MAKE_kQuantizeBlockwise(float , 128 , 2 , 0 , General8bit)
3078- #if WARP_SIZE == 32
3077+ #if BNB_WARP_SIZE == 32
30793078 MAKE_kQuantizeBlockwise (float , 64 , 2 , 0 , General8bit)
30803079#endif
30813080
@@ -3085,7 +3084,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
30853084MAKE_kQuantizeBlockwise(float , 512 , 2 , 0 , FP4)
30863085MAKE_kQuantizeBlockwise(float , 256 , 2 , 0 , FP4)
30873086MAKE_kQuantizeBlockwise(float , 128 , 2 , 0 , FP4)
3088- #if WARP_SIZE == 32
3087+ #if BNB_WARP_SIZE == 32
30893088 MAKE_kQuantizeBlockwise (float , 64 , 2 , 0 , FP4)
30903089#endif
30913090
@@ -3095,7 +3094,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
30953094MAKE_kQuantizeBlockwise(float , 512 , 2 , 0 , NF4)
30963095MAKE_kQuantizeBlockwise(float , 256 , 2 , 0 , NF4)
30973096MAKE_kQuantizeBlockwise(float , 128 , 2 , 0 , NF4)
3098- #if WARP_SIZE == 32
3097+ #if BNB_WARP_SIZE == 32
30993098 MAKE_kQuantizeBlockwise (float , 64 , 2 , 0 , NF4)
31003099#endif
31013100
@@ -3106,7 +3105,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
31063105MAKE_kQuantizeBlockwise(hip_bfloat16, 512 , 2 , 0 , General8bit)
31073106MAKE_kQuantizeBlockwise(hip_bfloat16, 256 , 2 , 0 , General8bit)
31083107MAKE_kQuantizeBlockwise(hip_bfloat16, 128 , 2 , 0 , General8bit)
3109- #if WARP_SIZE == 32
3108+ #if BNB_WARP_SIZE == 32
31103109 MAKE_kQuantizeBlockwise (hip_bfloat16, 64 , 2 , 0 , General8bit)
31113110#endif
31123111
@@ -3116,7 +3115,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
31163115MAKE_kQuantizeBlockwise(hip_bfloat16, 512 , 2 , 0 , FP4)
31173116MAKE_kQuantizeBlockwise(hip_bfloat16, 256 , 2 , 0 , FP4)
31183117MAKE_kQuantizeBlockwise(hip_bfloat16, 128 , 2 , 0 , FP4)
3119- #if WARP_SIZE == 32
3118+ #if BNB_WARP_SIZE == 32
31203119 MAKE_kQuantizeBlockwise (hip_bfloat16, 64 , 2 , 0 , FP4)
31213120#endif
31223121
@@ -3126,7 +3125,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
31263125MAKE_kQuantizeBlockwise(hip_bfloat16, 512 , 2 , 0 , NF4)
31273126MAKE_kQuantizeBlockwise(hip_bfloat16, 256 , 2 , 0 , NF4)
31283127MAKE_kQuantizeBlockwise(hip_bfloat16, 128 , 2 , 0 , NF4)
3129- #if WARP_SIZE == 32
3128+ #if BNB_WARP_SIZE == 32
31303129 MAKE_kQuantizeBlockwise (hip_bfloat16, 64 , 2 , 0 , NF4)
31313130#endif
31323131
0 commit comments