Skip to content

Commit 7dd7b88

Browse files
committed
Reuse BNB_WARP_SIZE macro
1 parent 8c24b4d commit 7dd7b88

File tree

3 files changed

+32
-29
lines changed

3 files changed

+32
-29
lines changed

csrc/common_hip.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
#pragma once
22

3-
#define BNB_WARP_SIZE warpSize
3+
#ifdef __GFX9__
4+
#define BNB_WARP_SIZE 64
5+
#else
6+
#define BNB_WARP_SIZE 32
7+
#endif
48

59
// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
6-
#define BNB_MAX_THREADS_PER_SM 2048
10+
#define BNB_MAX_THREADS_PER_CU 2048
711
#define BNB_BF16_AVAILABLE true

csrc/kernels.hip

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,7 +1933,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19331933
// rowStats [rows]
19341934
// out [rows, cols]
19351935
template<typename T, int THREADS, int SPARSE_DECOMP>
1936-
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
1936+
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
19371937
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
19381938

19391939
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
@@ -1997,7 +1997,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
19971997
}
19981998

19991999
template<typename T, int THREADS, int SPARSE_DECOMP>
2000-
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
2000+
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
20012001
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
20022002
using BlockReduceT = hipcub::BlockReduce<float, THREADS>;
20032003

@@ -2109,7 +2109,6 @@ __global__ void kdequant_mm_int32_fp16(
21092109
#define DENORM 1.0f/127.0f
21102110
#define MAX_SPARSE_COUNT 32
21112111
#define SMEM_SIZE 8*256
2112-
#define WARP_SIZE warpSize
21132112
template <typename T, int SPMM_ITEMS, int BITS>
21142113
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
21152114
{
@@ -2130,9 +2129,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
21302129
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
21312130
const int local_row_idx = rowidx[offset];
21322131

2133-
const int warp_id = threadIdx.x / WARP_SIZE;
2134-
const int warp_idx = threadIdx.x % WARP_SIZE;
2135-
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
2132+
const int warp_id = threadIdx.x / BNB_WARP_SIZE;
2133+
const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
2134+
const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
21362135
const int num_items = BITS == 8 ? 8 : 8;
21372136
int idx_col_B = warp_offset;
21382137
int local_idx_col_B_offset = 0;
@@ -2152,7 +2151,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
21522151
}
21532152

21542153
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
2155-
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
2154+
// we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
21562155
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
21572156
// added 3 bytes = 6 values between warps should reduce bank conflicts
21582157
__shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -2705,16 +2704,16 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27052704
{
27062705

27072706
// per threadblock:
2708-
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
2707+
// load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
27092708
// 4 warps -> 4 loads per iter
2710-
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
2711-
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
2712-
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
2713-
2714-
const int warp_idx = threadIdx.x / warpSize;
2715-
const int warp_lane = threadIdx.x % warpSize;
2716-
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
2717-
const int offset_B = ldb*row_B;
2709+
// 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block
2710+
typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
2711+
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];
2712+
2713+
const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
2714+
const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
2715+
const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx;
2716+
const int offset_B = ldb * row_B;
27182717
const int num_values_8bit = num_values_4bit/2;
27192718
float local_C = 0.0f;
27202719

@@ -2732,7 +2731,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27322731

27332732
// A: [1, K]
27342733
// B: [M, K]
2735-
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
2734+
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit)
27362735
{
27372736
const int inner_idx_halved = inner_idx/2;
27382737

@@ -3044,7 +3043,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
30443043
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
30453044
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
30463045
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
3047-
#if WARP_SIZE == 32
3046+
#if BNB_WARP_SIZE == 32
30483047
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
30493048
#endif
30503049

@@ -3054,7 +3053,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
30543053
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
30553054
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
30563055
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
3057-
#if WARP_SIZE == 32
3056+
#if BNB_WARP_SIZE == 32
30583057
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
30593058
#endif
30603059

@@ -3064,7 +3063,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
30643063
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
30653064
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
30663065
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
3067-
#if WARP_SIZE == 32
3066+
#if BNB_WARP_SIZE == 32
30683067
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
30693068
#endif
30703069

@@ -3075,7 +3074,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
30753074
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
30763075
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
30773076
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
3078-
#if WARP_SIZE == 32
3077+
#if BNB_WARP_SIZE == 32
30793078
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
30803079
#endif
30813080

@@ -3085,7 +3084,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
30853084
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
30863085
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
30873086
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
3088-
#if WARP_SIZE == 32
3087+
#if BNB_WARP_SIZE == 32
30893088
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
30903089
#endif
30913090

@@ -3095,7 +3094,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
30953094
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
30963095
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
30973096
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
3098-
#if WARP_SIZE == 32
3097+
#if BNB_WARP_SIZE == 32
30993098
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
31003099
#endif
31013100

@@ -3106,7 +3105,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
31063105
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
31073106
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
31083107
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
3109-
#if WARP_SIZE == 32
3108+
#if BNB_WARP_SIZE == 32
31103109
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
31113110
#endif
31123111

@@ -3116,7 +3115,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
31163115
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
31173116
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
31183117
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
3119-
#if WARP_SIZE == 32
3118+
#if BNB_WARP_SIZE == 32
31203119
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
31213120
#endif
31223121

@@ -3126,7 +3125,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
31263125
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
31273126
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
31283127
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
3129-
#if WARP_SIZE == 32
3128+
#if BNB_WARP_SIZE == 32
31303129
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
31313130
#endif
31323131

csrc/ops.hip

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
693693
//warpsize - 32
694694
int num_blocks = (m+3)/4;
695695
//warpsize - 64
696-
if (warpSize == 64) {
696+
if (BNB_WARP_SIZE == 64) {
697697
num_blocks = (m+1)/2;
698698
}
699699

0 commit comments

Comments
 (0)