From 6be1412307c517243675c2e2e5245d87bccb6735 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:02:33 +0000 Subject: [PATCH 01/47] add template to support more dtypes Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 66 ++++++++++++++++++++++++++++++++++++---- csrc/cpu_ops.h | 11 ++++++- csrc/pythonInterface.cpp | 14 ++++++++- 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 5c2bc6332..0aadf596e 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,15 +1,26 @@ #include #include +#include #include using namespace BinSearch; -void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) { - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; - long long block_end = block_idx + valid_items; - for (long long i = block_idx; i < block_end; i++) - out[i] = code[A[i]] * absmax[block_idx / blocksize]; +template +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { + switch (DATA_TYPE) { + case General8bit: + #pragma omp parallel for + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = static_cast(code[A[i]] * absmax[block_idx / blocksize]); + } + case NF4: + return; + case FP4: + return; + break; } } @@ -59,3 +70,46 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long threads[i].join(); } } + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); + +// template void gemv_4bit_inference( +// int m, int n, int k, at::Half* A, unsigned char* B, float* absmax, float* datatype, at::Half* out, +// int lda, int ldb, int ldc, int blocksize); + +// template void gemv_4bit_inference( +// int m, int n, int k, at::BFloat16* A, unsigned char* B, float* absmax, float* datatype, at::BFloat16* out, +// int lda, int ldb, int ldc, int blocksize); + +// template void gemv_4bit_inference( +// int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, +// int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 3c10e6d13..72f759497 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,8 +3,17 @@ #include #include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); -void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +template +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 28121240f..8bf32417f 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,6 +845,18 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n ) { - dequantize_cpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_bf16( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp16( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } } From 252ac0f84af0f0c425f7b1075c2ee63b60b1ed6e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:04:45 +0000 Subject: [PATCH 02/47] update cmake list Signed-off-by: jiqing-feng --- CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c133e09f..952be8a04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -243,6 +243,7 @@ elseif(BUILD_XPU) set(CMAKE_CXX_COMPILER icx) endif() else() + find_package(Torch REQUIRED) string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) endif() @@ -317,7 +318,9 @@ if(BUILD_XPU) set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) - +else() + target_link_options(bitsandbytes PRIVATE ${TORCH_LIBRARIES}) + include_directories(${TORCH_INCLUDE_DIRS}) endif() if(WIN32) From f98c9e5d98ffc2f6332665971bbd6cf08c8a3003 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 15:14:48 +0000 Subject: [PATCH 03/47] fix typo Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 8bf32417f..eaaa953f8 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -849,13 +849,13 @@ void cdequantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From 902bf359f51705ffa4d7f2ba527caaf5be77782f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 28 Oct 2025 18:17:23 +0000 Subject: [PATCH 04/47] fix compile cpu Signed-off-by: jiqing-feng --- CMakeLists.txt | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 952be8a04..808ade86f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,9 +78,16 @@ else() set(BUILD_HIP OFF) set(BUILD_MPS OFF) set(BUILD_XPU OFF) + set(BUILD_CPU ON) endif() +if (BUILD_CPU) + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + find_package(Torch REQUIRED) +endif() + if(BUILD_CUDA) # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+. # Workaround: use --allow-unsupported-compiler @@ -242,10 +249,13 @@ elseif(BUILD_XPU) if(WIN32) set(CMAKE_CXX_COMPILER icx) endif() -else() +elseif(BUILD_CPU) find_package(Torch REQUIRED) string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) endif() @@ -263,6 +273,9 @@ add_library(bitsandbytes SHARED ${SRC_FILES}) target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc include) +if (BUILD_CPU) + target_link_libraries(bitsandbytes "${TORCH_LIBRARIES}") +endif() if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) @@ -318,9 +331,7 @@ if(BUILD_XPU) set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) -else() - target_link_options(bitsandbytes PRIVATE ${TORCH_LIBRARIES}) - include_directories(${TORCH_INCLUDE_DIRS}) + endif() if(WIN32) From fef8459f52c21b38f50ad7250f9ef9f4013b7e31 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:36:04 +0000 Subject: [PATCH 05/47] make different dtype works Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index e295cc2a3..a69d89f3a 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -76,10 +76,8 @@ def _( torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - # Only FP32 has c++ kernrl + out = torch.empty_like(A, dtype=dtype) if dtype == torch.float32: - out = torch.empty_like(A, dtype=dtype) - lib.cdequantize_blockwise_cpu_fp32( get_ptr(code), get_ptr(A), @@ -88,6 +86,24 @@ def _( ct.c_longlong(blocksize), ct.c_longlong(A.numel()), ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) else: out = code[A.reshape(-1).int()] blocks = out.shape[-1] // blocksize From 55cbaa0d0809711b1df1b5bb44da49312bc535d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:46:18 +0000 Subject: [PATCH 06/47] use bf16 on CPU Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index ece18caa3..1cc24bb46 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -432,6 +432,9 @@ def matmul_4bit( bias: Optional[torch.Tensor] = None, ): assert quant_state is not None + # Change dtype to bfloat16 on CPU + if A.device.type == "cpu" and quant_state.dtype == torch.float32: + quant_state.dtype = torch.bfloat16 if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: From bbef95b3bab9168879e99a891ceecc924c140a14 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 29 Oct 2025 09:52:52 +0000 Subject: [PATCH 07/47] fix state2 dtype Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 1cc24bb46..0aba814c1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -433,8 +433,11 @@ def matmul_4bit( ): assert quant_state is not None # Change dtype to bfloat16 on CPU - if A.device.type == "cpu" and quant_state.dtype == torch.float32: - quant_state.dtype = torch.bfloat16 + if A.device.type == "cpu": + if quant_state.dtype == torch.float32: + quant_state.dtype = torch.bfloat16 + if hasattr(quant_state, "state2") and quant_state.state2.dtype == torch.float32: + quant_state.state2.dtype = torch.bfloat16 if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: From e8425135de33f01d75cf1479ea14a387449fd268 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:27:34 +0000 Subject: [PATCH 08/47] remove torch Signed-off-by: jiqing-feng --- CMakeLists.txt | 8 ++---- csrc/cpu_ops.cpp | 63 ++++++++++++++++++++++++++++-------------------- csrc/cpu_ops.h | 10 ++++++++ 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 808ade86f..c5abfca78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,7 @@ endif() if (BUILD_CPU) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) - find_package(Torch REQUIRED) + find_package(OpenMP) endif() if(BUILD_CUDA) @@ -249,10 +249,6 @@ elseif(BUILD_XPU) if(WIN32) set(CMAKE_CXX_COMPILER icx) endif() -elseif(BUILD_CPU) - find_package(Torch REQUIRED) - string(APPEND BNB_OUTPUT_NAME "_cpu") - set(GPU_SOURCES) else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -274,7 +270,7 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) - target_link_libraries(bitsandbytes "${TORCH_LIBRARIES}") + target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) endif() if(BUILD_CUDA) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 0aadf596e..ec07a593b 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,25 +5,36 @@ using namespace BinSearch; + template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - switch (DATA_TYPE) { - case General8bit: +void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, + long long blocksize, long long n) { + if (DATA_TYPE > 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; long long block_end = block_idx + valid_items; - for (long long i = block_idx; i < block_end; i++) - out[i] = static_cast(code[A[i]] * absmax[block_idx / blocksize]); + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i < block_end; i++) { + float v = code[A[i]] * scale; + if constexpr (std::is_same::value) { + out[i] = float_to_bf16(v); + } else { + out[i] = static_cast(v); + } + } } - case NF4: - return; - case FP4: - return; - break; + } else { + // 4bit path + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, n); } } +template +void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { + return; +} + void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -84,30 +95,30 @@ template void dequantizeBlockwiseCpu( template void dequantizeBlockwiseCpu( float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n); +template void dequantizeBlockwiseCpu( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); -// template void gemv_4bit_inference( -// int m, int n, int k, at::Half* A, unsigned char* B, float* absmax, float* datatype, at::Half* out, +// template void gemv_4bit_inference( +// int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, // int lda, int ldb, int ldc, int blocksize); -// template void gemv_4bit_inference( -// int m, int n, int k, at::BFloat16* A, unsigned char* B, float* absmax, float* datatype, at::BFloat16* out, +// template void gemv_4bit_inference( +// int m, int n, int k, bf16_t* A, unsigned char* B, float* absmax, float* datatype, bf16_t* out, // int lda, int ldb, int ldc, int blocksize); // template void gemv_4bit_inference( diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 72f759497..37026939a 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -4,6 +4,7 @@ #include #include #include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -13,7 +14,16 @@ typedef enum DataType_t { NF4 = 2, } DataType_t; +using fp16_t = _Float16; + +struct bf16_t { + uint16_t v; +}; + template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); +template +void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) + #endif From d4473fa9314dfbb11135f7566b7a4a09acf21750 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:30:19 +0000 Subject: [PATCH 09/47] rm torch Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 2 -- csrc/pythonInterface.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 37026939a..19e4cf909 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,8 +3,6 @@ #include #include -#include -#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index eaaa953f8..5056ccf0c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -849,14 +849,14 @@ void cdequantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, at::BFloat16* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, at::Half* out, long long blocksize, long long n + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } } From dea8dd6377d788aaccabccc350d343c3a41d114f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:31:55 +0000 Subject: [PATCH 10/47] enable float to bf16 Signed-off-by: jiqing-feng --- csrc/cpu_ops.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 19e4cf909..85c13a334 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -3,6 +3,9 @@ #include #include +#include +#include +#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -18,6 +21,13 @@ struct bf16_t { uint16_t v; }; +static inline bf16_t float_to_bf16(float x) { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t r = bits + 0x7FFF + ((bits >> 16) & 1); + return bf16_t{static_cast(r >> 16)}; +} + template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); From e9bb4fe15ae0dbb73241d9e49a80ddacab1ecdd1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:33:39 +0000 Subject: [PATCH 11/47] rm dequantizeBlockwise4bitCpu Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 9 ++------- csrc/cpu_ops.h | 3 --- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ec07a593b..205b6824d 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -25,16 +25,11 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } } } else { - // 4bit path - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, n); + // TODO: enable nf4 and fp4 + return; } } -template -void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - return; -} - void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 85c13a334..77791b3e6 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -31,7 +31,4 @@ static inline bf16_t float_to_bf16(float x) { template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); -template -void dequantizeBlockwise4bitCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) - #endif From cdc8d5e02606bb740da660f309be00d543455eaa Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 15:43:03 +0000 Subject: [PATCH 12/47] fix check Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 205b6824d..dafc4c91f 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,7 +9,7 @@ using namespace BinSearch; template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n) { - if (DATA_TYPE > 0) { + if (DATA_TYPE == 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; From baacfac22604061533d500625c3698b85d8cf5b1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:10:47 +0000 Subject: [PATCH 13/47] enable dequant 4bit kernel Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 106 +++++++++++++++++++++++++++++++ csrc/cpu_ops.cpp | 25 +++++++- csrc/cpu_ops.h | 88 +++++++++++++++++++++++++ csrc/pythonInterface.cpp | 34 ++++++++++ 4 files changed, 251 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a69d89f3a..3c4399873 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -115,3 +115,109 @@ def _( out = out.reshape(A.shape) return out + +@register_kernel("bitsandbytes::dequantize_4bit", "cpu") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if quant_type == "fp4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_fp4_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_fp4_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp4_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif quant_type == "nf4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_nf4_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_nf4_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_nf4_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + else: + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index dafc4c91f..83fa9db42 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -25,8 +25,29 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } } } else { - // TODO: enable nf4 and fp4 - return; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + if (DATA_TYPE == 1) { + float up = dDequantizeFP4(A[i] >> 4) * scale; + float low = dDequantizeFP4(A[i] & 0x0F) * scale; + } elif (DATA_TYPE == 1) { + float up = dDequantizeNF4(A[i] >> 4) * scale; + float low = dDequantizeNF4(A[i] & 0x0F) * scale; + } + + if constexpr (std::is_same::value) { + out[i*2] = float_to_bf16(up); + out[i*2+1] = float_to_bf16(low); + } else { + out[i*2] = static_cast(up); + out[i*2+1] = static_cast(low); + } + } + } } } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 77791b3e6..9fd111719 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -28,6 +28,94 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +inline float dDequantizeFP4(unsigned char val) { + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; + else + return 0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; + else if ((val & 0b0001) == 1) + return 5.208333333e-03f; + else + return 0.00000000f; +} + +inline float dDequantizeNF4(unsigned char val) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 + else + return 0.07958029955625534f; //*1000 + + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 + else + return -1.0f; //*0000 +} + template void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 5056ccf0c..b69679cb7 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -859,4 +859,38 @@ void cdequantize_blockwise_cpu_fp16( ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } +void cdequantize_blockwise_cpu_fp4_fp32( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp4_bf16( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp4_fp16( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} +void cdequantize_blockwise_cpu_nf4_fp32( + float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_nf4_bf16( + float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_nf4_fp16( + float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n +) { + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); +} } From eec35212b86201b97c48af45dadca2514ceaeb49 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:11:53 +0000 Subject: [PATCH 14/47] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 83fa9db42..2ccb4b3f9 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -34,7 +34,7 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out if (DATA_TYPE == 1) { float up = dDequantizeFP4(A[i] >> 4) * scale; float low = dDequantizeFP4(A[i] & 0x0F) * scale; - } elif (DATA_TYPE == 1) { + } else if (DATA_TYPE == 1) { float up = dDequantizeNF4(A[i] >> 4) * scale; float low = dDequantizeNF4(A[i] & 0x0F) * scale; } From d7cc1c5e6bf3f4afcfe9b615f7990e200640616c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 16:21:36 +0000 Subject: [PATCH 15/47] fix typo Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 13 +++++++------ csrc/cpu_ops.cpp | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 3c4399873..76ffec650 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import ctypes as ct import logging @@ -139,7 +140,7 @@ def _( if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -148,7 +149,7 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -157,7 +158,7 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -167,7 +168,7 @@ def _( elif quant_type == "nf4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_nf4_fp32( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -176,7 +177,7 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -185,7 +186,7 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( - get_ptr(code), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2ccb4b3f9..ec7317b7d 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -31,10 +31,11 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + float up, low; if (DATA_TYPE == 1) { float up = dDequantizeFP4(A[i] >> 4) * scale; float low = dDequantizeFP4(A[i] & 0x0F) * scale; - } else if (DATA_TYPE == 1) { + } else { float up = dDequantizeNF4(A[i] >> 4) * scale; float low = dDequantizeNF4(A[i] & 0x0F) * scale; } From 124b754e85f6df425c29fdf00ce0f7e4f94452c4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 19:30:26 +0000 Subject: [PATCH 16/47] fix dequantize Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 45 +++++++++++++++++++++++++++----- csrc/cpu_ops.cpp | 22 ++++++++-------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 76ffec650..d43d13dcd 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -145,7 +145,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( @@ -154,7 +154,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( @@ -163,7 +163,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif quant_type == "nf4": if dtype == torch.float32: @@ -173,7 +173,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( @@ -182,7 +182,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( @@ -191,7 +191,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(out.numel()), ) else: A = A.reshape(-1) @@ -221,4 +221,37 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) + return out + +def dequant_nf4_x(A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype,): + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) return out \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ec7317b7d..cff55a3bf 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -27,25 +27,25 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out } else { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; - long long block_end = block_idx + valid_items; + long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i * 2 + 1 < block_end; i+=2) { + for (long long i = 0; i < valid_items; i+=2) { float up, low; + long long index = (i + block_idx) / 2; if (DATA_TYPE == 1) { - float up = dDequantizeFP4(A[i] >> 4) * scale; - float low = dDequantizeFP4(A[i] & 0x0F) * scale; + up = dDequantizeFP4(A[index] >> 4) * scale; + low = dDequantizeFP4(A[index] & 0x0F) * scale; } else { - float up = dDequantizeNF4(A[i] >> 4) * scale; - float low = dDequantizeNF4(A[i] & 0x0F) * scale; + up = dDequantizeNF4(A[index] >> 4) * scale; + low = dDequantizeNF4(A[index] & 0x0F) * scale; } if constexpr (std::is_same::value) { - out[i*2] = float_to_bf16(up); - out[i*2+1] = float_to_bf16(low); + out[i + block_idx] = float_to_bf16(up); + out[i+1 + block_idx] = float_to_bf16(low); } else { - out[i*2] = static_cast(up); - out[i*2+1] = static_cast(low); + out[i + block_idx] = static_cast(up); + out[i+1 + block_idx] = static_cast(low); } } } From 0f918c72cca2d60ee3e5f7272ccf74eb3fe31faf Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:12:53 +0000 Subject: [PATCH 17/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 7 +++--- bitsandbytes/backends/cpu/ops.py | 37 ++--------------------------- csrc/cpu_ops.cpp | 37 +++++++++++++++-------------- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 18 +++++++------- 5 files changed, 34 insertions(+), 67 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 0aba814c1..158088c97 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -434,10 +434,9 @@ def matmul_4bit( assert quant_state is not None # Change dtype to bfloat16 on CPU if A.device.type == "cpu": - if quant_state.dtype == torch.float32: - quant_state.dtype = torch.bfloat16 - if hasattr(quant_state, "state2") and quant_state.state2.dtype == torch.float32: - quant_state.state2.dtype = torch.bfloat16 + quant_state.dtype = A.dtype + if hasattr(quant_state, "state2"): + quant_state.state2.dtype = A.dtype if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d43d13dcd..e92c9b3f4 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -135,7 +135,8 @@ def _( # Enable non uint8 dtype if A.dtype != torch.uint8: A = A.view(torch.uint8) - + + A = A.reshape(-1) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": if dtype == torch.float32: @@ -194,7 +195,6 @@ def _( ct.c_longlong(out.numel()), ) else: - A = A.reshape(-1) # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) n = out_dq.numel() @@ -222,36 +222,3 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) return out - -def dequant_nf4_x(A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype,): - out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) - A = A.reshape(-1) - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - return out \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index cff55a3bf..f2c6f0690 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -7,15 +7,15 @@ using namespace BinSearch; template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n) { if (DATA_TYPE == 0) { #pragma omp parallel for for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; i++) { + for (long long i = block_idx; i < block_end; ++i) { float v = code[A[i]] * scale; if constexpr (std::is_same::value) { out[i] = float_to_bf16(v); @@ -29,23 +29,24 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); float scale = absmax[block_idx / blocksize]; - for (long long i = 0; i < valid_items; i+=2) { - float up, low; - long long index = (i + block_idx) / 2; - if (DATA_TYPE == 1) { - up = dDequantizeFP4(A[index] >> 4) * scale; - low = dDequantizeFP4(A[index] & 0x0F) * scale; - } else { - up = dDequantizeNF4(A[index] >> 4) * scale; - low = dDequantizeNF4(A[index] & 0x0F) * scale; - } - + for (long long i = 0; i < valid_items; i += 2) { + long long byte_index = (block_idx + i) >> 1; + unsigned char byte = A[byte_index]; + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) + : dDequantizeNF4(byte >> 4)) * scale; if constexpr (std::is_same::value) { - out[i + block_idx] = float_to_bf16(up); - out[i+1 + block_idx] = float_to_bf16(low); + out[block_idx + i] = float_to_bf16(v0); } else { - out[i + block_idx] = static_cast(up); - out[i+1 + block_idx] = static_cast(low); + out[block_idx + i] = static_cast(v0); + } + if (i + 1 < valid_items) { + if constexpr (std::is_same::value) { + out[block_idx + i + 1] = float_to_bf16(v1); + } else { + out[block_idx + i + 1] = static_cast(v1); + } } } } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 9fd111719..3ad7b3ac2 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -117,6 +117,6 @@ inline float dDequantizeNF4(unsigned char val) { } template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index b69679cb7..33fcb6041 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,53 +843,53 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From e1a8b20d262eab013489a6f3b31b8e48a4a3a760 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:15:19 +0000 Subject: [PATCH 18/47] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f2c6f0690..46e238386 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -105,31 +105,31 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, From eab45c8565f68224b9f896376a257d561b06f5fc Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:20:42 +0000 Subject: [PATCH 19/47] test Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index e92c9b3f4..e49543867 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -185,6 +185,11 @@ def _( ct.c_longlong(blocksize), ct.c_longlong(out.numel()), ) + out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) + out = out.reshape(shape) + out_2 = out_2.reshape(shape) + if torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( None, @@ -222,3 +227,37 @@ def _( out = out.reshape(-1, *shape[1:]).to(dtype) return out + +def dequantize_nf4_test( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +): + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + return out From d9f5dd8e215c16e7b12d0bf994b8ba1530bd52b4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:22:49 +0000 Subject: [PATCH 20/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index e49543867..6b131be90 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,6 +6,7 @@ from bitsandbytes.functional import get_ptr +from ..util import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib From 070f8a082b623bdccd8755f8ea2152118223e8d8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:23:40 +0000 Subject: [PATCH 21/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 6b131be90..ecc744e2d 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,7 +6,7 @@ from bitsandbytes.functional import get_ptr -from ..util import CODE +from ..utils import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib From a84addfe5d6410f5752faba4f12f56c5b0b1e2ee Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:25:48 +0000 Subject: [PATCH 22/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index ecc744e2d..d6398b06b 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -189,7 +189,7 @@ def _( out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) out_2 = out_2.reshape(shape) - if torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( From c4bb6607767668be20eb00d43aa91297055a4ca9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:33:13 +0000 Subject: [PATCH 23/47] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 46e238386..03c0af795 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -32,10 +32,10 @@ void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, for (long long i = 0; i < valid_items; i += 2) { long long byte_index = (block_idx + i) >> 1; unsigned char byte = A[byte_index]; - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale; + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); } else { From 4ba13fd37f4d741648712abaab35edeab7039dd8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 30 Oct 2025 20:40:55 +0000 Subject: [PATCH 24/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d6398b06b..99bc21ca0 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -137,6 +137,10 @@ def _( if A.dtype != torch.uint8: A = A.view(torch.uint8) + # TODO: support half precision absmax + if absmax.dtype != torch.float32: + absmax = absmax.float() + A = A.reshape(-1) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": From c0d05ec1e03c24c8717de979fe49ca76d0d18733 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:22:02 +0000 Subject: [PATCH 25/47] change input param Signed-off-by: jiqing-feng --- bitsandbytes/autograd/_functions.py | 2 - bitsandbytes/backends/cpu/ops.py | 39 +++++--- csrc/cpu_ops.cpp | 145 +++++++++++++++++++--------- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 18 ++-- 5 files changed, 137 insertions(+), 69 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 158088c97..061b4d1b8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -435,8 +435,6 @@ def matmul_4bit( # Change dtype to bfloat16 on CPU if A.device.type == "cpu": quant_state.dtype = A.dtype - if hasattr(quant_state, "state2"): - quant_state.state2.dtype = A.dtype if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 99bc21ca0..c38c22583 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -86,7 +86,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_bf16( @@ -95,7 +96,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp16( @@ -104,7 +106,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) else: out = code[A.reshape(-1).int()] @@ -141,7 +144,7 @@ def _( if absmax.dtype != torch.float32: absmax = absmax.float() - A = A.reshape(-1) + A = A.reshape(shape[0], shape[1] // 2) out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if quant_type == "fp4": if dtype == torch.float32: @@ -151,7 +154,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( @@ -160,7 +164,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( @@ -169,7 +174,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif quant_type == "nf4": if dtype == torch.float32: @@ -179,7 +185,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( @@ -188,7 +195,8 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) @@ -202,10 +210,12 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(out.numel()), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), ) else: # Map nf4 to [-1, 1] + A = A.reshape(-1) out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) n = out_dq.numel() out_dq[1::2] = A & 0xF @@ -229,7 +239,7 @@ def _( else: out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - out = out.reshape(-1, *shape[1:]).to(dtype) + out = out.reshape(-1, *shape[1:]).to(dtype) return out @@ -266,3 +276,10 @@ def dequantize_nf4_test( out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) return out + + +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 03c0af795..ef5a27729 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -6,53 +6,97 @@ using namespace BinSearch; +// 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. +// DATA_TYPE: 1 = FP4, 2 = NF4 template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, - long long blocksize, long long n) { - if (DATA_TYPE == 0) { - #pragma omp parallel for - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); - long long block_end = block_idx + valid_items; - float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; ++i) { - float v = code[A[i]] * scale; +inline void dequantizeBlockwise4bitCpu(float* code, + unsigned char* A, + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { + static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, + "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); + if (blocksize <= 0 || n <= 0) return; + +#if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) + // AVX512 optimized branch (placeholder) + // DATA_TYPE: 1 = FP4, 2 = NF4 + if (1 == 0) {return;} +#else + // Scalar fallback branch + long long total = m * n; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + float scale = absmax[block_idx / blocksize]; + for (long long i = 0; i < valid_items; i += 2) { + long long byte_index = (block_idx + i) >> 1; + unsigned char byte = A[byte_index]; + + // High nibble first (matches previous code logic) + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) + : dDequantizeNF4(byte >> 4)) * scale; + // Low nibble second + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) + : dDequantizeNF4(byte & 0x0F)) * scale; + + if constexpr (std::is_same::value) { + out[block_idx + i] = float_to_bf16(v0); + } else { + out[block_idx + i] = static_cast(v0); + } + + if (i + 1 < valid_items) { if constexpr (std::is_same::value) { - out[i] = float_to_bf16(v); + out[block_idx + i + 1] = float_to_bf16(v1); } else { - out[i] = static_cast(v); + out[block_idx + i + 1] = static_cast(v1); } } } - } else { + } +#endif +} + + +template +void dequantizeBlockwiseCpu(float* code, + unsigned char* A, + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { + static_assert(DATA_TYPE == 0 || DATA_TYPE == 1 || DATA_TYPE == 2, + "dequantizeBlockwiseCpu: invalid DATA_TYPE"); + if (blocksize <= 0 || m <= 0 || n <= 0) return; + + if constexpr (DATA_TYPE == 0) { + // 8-bit path + long long total = (m * n) >> 1; #pragma omp parallel for - for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { - long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; - for (long long i = 0; i < valid_items; i += 2) { - long long byte_index = (block_idx + i) >> 1; - unsigned char byte = A[byte_index]; - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) - : dDequantizeNF4(byte >> 4)) * scale; - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; + for (long long i = block_idx; i < block_end; ++i) { + float v = code[A[i]] * scale; if constexpr (std::is_same::value) { - out[block_idx + i] = float_to_bf16(v0); + out[i] = float_to_bf16(v); } else { - out[block_idx + i] = static_cast(v0); - } - if (i + 1 < valid_items) { - if constexpr (std::is_same::value) { - out[block_idx + i + 1] = float_to_bf16(v1); - } else { - out[block_idx + i + 1] = static_cast(v1); - } + out[i] = static_cast(v); } } } + } else { + // 4-bit helper (FP4 / NF4) + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } } + void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -105,31 +149,40 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); - + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); +template void dequantizeBlockwise4bitCpu( + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 3ad7b3ac2..0ea071e2d 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -117,6 +117,6 @@ inline float dDequantizeNF4(unsigned char val) { } template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 33fcb6041..2ab6920da 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,53 +843,53 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); } From 62a16a6e8fb4611508d94c11fb4429a4322e8ea0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:27:19 +0000 Subject: [PATCH 26/47] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- csrc/pythonInterface.cpp | 26 +++++++++++--------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ef5a27729..8ad2626d2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,8 +9,7 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 2 = NF4 template -inline void dequantizeBlockwise4bitCpu(float* code, - unsigned char* A, +inline void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, @@ -92,7 +91,7 @@ void dequantizeBlockwiseCpu(float* code, } } else { // 4-bit helper (FP4 / NF4) - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 2ab6920da..7cd74b844 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,52 +845,48 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } - void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, n); + dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); } } From d9ad828244b0a30b5640d42cb4a5e9076e17d8b6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:30:41 +0000 Subject: [PATCH 27/47] fix input param Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 8ad2626d2..e7d80677f 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -169,19 +169,19 @@ template void dequantizeBlockwiseCpu( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, From 09ed6cbf455de0445cbf59bbb77100c7405c6d60 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:38:30 +0000 Subject: [PATCH 28/47] spliut 8bit and 4bit Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 9 ++--- csrc/cpu_ops.cpp | 65 +++++++++++--------------------- csrc/cpu_ops.h | 8 ++-- csrc/pythonInterface.cpp | 18 ++++----- 4 files changed, 38 insertions(+), 62 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c38c22583..33060718f 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -86,8 +86,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_bf16( @@ -96,8 +95,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp16( @@ -106,8 +104,7 @@ def _( get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), + ct.c_longlong(A.numel()), ) else: out = code[A.reshape(-1).int()] diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e7d80677f..091925fca 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -9,13 +9,13 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 2 = NF4 template -inline void dequantizeBlockwise4bitCpu(unsigned char* A, +void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n) { - static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, + static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || n <= 0) return; @@ -60,38 +60,29 @@ inline void dequantizeBlockwise4bitCpu(unsigned char* A, } -template -void dequantizeBlockwiseCpu(float* code, +template +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, - long long m, long long n) { - static_assert(DATA_TYPE == 0 || DATA_TYPE == 1 || DATA_TYPE == 2, - "dequantizeBlockwiseCpu: invalid DATA_TYPE"); - if (blocksize <= 0 || m <= 0 || n <= 0) return; - - if constexpr (DATA_TYPE == 0) { - // 8-bit path - long long total = (m * n) >> 1; - #pragma omp parallel for - for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { - long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); - long long block_end = block_idx + valid_items; - float scale = absmax[block_idx / blocksize]; - for (long long i = block_idx; i < block_end; ++i) { - float v = code[A[i]] * scale; - if constexpr (std::is_same::value) { - out[i] = float_to_bf16(v); - } else { - out[i] = static_cast(v); - } + if (blocksize <= 0 || n <= 0) return; + // 8-bit path + long long total = (m * n) >> 1; + #pragma omp parallel for + for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { + long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + long long block_end = block_idx + valid_items; + float scale = absmax[block_idx / blocksize]; + for (long long i = block_idx; i < block_end; ++i) { + float v = code[A[i]] * scale; + if constexpr (std::is_same::value) { + out[i] = float_to_bf16(v); + } else { + out[i] = static_cast(v); } } - } else { - // 4-bit helper (FP4 / NF4) - dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } @@ -147,25 +138,11 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long // TEMPLATE DEFINITIONS //============================================================== -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); - -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); - -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); -template void dequantizeBlockwiseCpu( +template void dequantizeBlockwise8bitCpu( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); template void dequantizeBlockwise4bitCpu( diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 0ea071e2d..092261c4f 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -10,9 +10,8 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); typedef enum DataType_t { - General8bit = 0, + NF4 = 0, FP4 = 1, - NF4 = 2, } DataType_t; using fp16_t = _Float16; @@ -116,7 +115,10 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } +template +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); + template -void dequantizeBlockwiseCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 7cd74b844..127d147a0 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,48 +845,48 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwiseCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); } } From a3f7b61128bf051e824b9979dbb3f166dfdced56 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:39:36 +0000 Subject: [PATCH 29/47] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 +++--- csrc/cpu_ops.h | 2 +- csrc/pythonInterface.cpp | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 091925fca..2bc380a91 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -139,11 +139,11 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 092261c4f..047466f7a 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -116,7 +116,7 @@ inline float dDequantizeNF4(unsigned char val) { } template -void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); template void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 127d147a0..894432ede 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -845,17 +845,17 @@ void cquantize_blockwise_cpu_fp32( void cdequantize_blockwise_cpu_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp4_fp32( From 47084701d5a736435e9ceaf2f74033b5229cdf2c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:40:36 +0000 Subject: [PATCH 30/47] fix typo Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2bc380a91..55d129ceb 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -69,10 +69,9 @@ void dequantizeBlockwise8bitCpu(float* code, long long n) { if (blocksize <= 0 || n <= 0) return; // 8-bit path - long long total = (m * n) >> 1; #pragma omp parallel for - for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { - long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i < block_end; ++i) { From 1dfe9f71648079fa33532c0c72e8eb4766cf896c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:42:12 +0000 Subject: [PATCH 31/47] fix input params Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 33060718f..c4475eef1 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -146,7 +146,6 @@ def _( if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -156,7 +155,6 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_fp4_bf16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -166,7 +164,6 @@ def _( ) elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_fp4_fp16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -177,7 +174,6 @@ def _( elif quant_type == "nf4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_nf4_fp32( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -187,7 +183,6 @@ def _( ) elif dtype == torch.bfloat16: lib.cdequantize_blockwise_cpu_nf4_bf16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -202,7 +197,6 @@ def _( import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( - None, get_ptr(A), get_ptr(absmax), get_ptr(out), From 00289c429dc28a552894bf086e0b1349d07af74f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 12:43:43 +0000 Subject: [PATCH 32/47] fix input params Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 10 +++++----- csrc/pythonInterface.cpp | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 55d129ceb..e9c477893 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -10,11 +10,11 @@ using namespace BinSearch; // DATA_TYPE: 1 = FP4, 2 = NF4 template void dequantizeBlockwise4bitCpu(unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long m, - long long n) { + const float* absmax, + T* out, + long long blocksize, + long long m, + long long n) { static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || n <= 0) return; diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 894432ede..62d3bf826 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -861,32 +861,32 @@ void cdequantize_blockwise_cpu_fp16( void cdequantize_blockwise_cpu_fp4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { - dequantizeBlockwise4bitCpu(code, A, absmax, out, blocksize, m, n); + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } } From a2578baabaaaf178f331e5357fd2f3ac3ed6654b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 13:04:05 +0000 Subject: [PATCH 33/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 +++- csrc/cpu_ops.cpp | 4 ++-- csrc/pythonInterface.cpp | 12 ++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c4475eef1..acf4caa34 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -190,7 +190,7 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - out_2 = dequantize_nf4_test(A, absmax, blocksize, quant_type, shape, dtype) + out_2 = dequantize_nf4_test(A.reshape(-1), absmax, blocksize, quant_type, shape, dtype) out = out.reshape(shape) out_2 = out_2.reshape(shape) if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): @@ -266,6 +266,8 @@ def dequantize_nf4_test( else: out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + out = out.reshape(-1, *shape[1:]).to(dtype) + return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e9c477893..a0c9bd50c 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -7,7 +7,7 @@ using namespace BinSearch; // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. -// DATA_TYPE: 1 = FP4, 2 = NF4 +// DATA_TYPE: 1 = FP4, 0 = NF4 template void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, @@ -17,7 +17,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, long long n) { static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); - if (blocksize <= 0 || n <= 0) return; + if (blocksize <= 0 || m < 0 || n <= 0) return; #if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) // AVX512 optimized branch (placeholder) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 62d3bf826..fd89a626e 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -859,33 +859,33 @@ void cdequantize_blockwise_cpu_fp16( } void cdequantize_blockwise_cpu_fp4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_fp4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } void cdequantize_blockwise_cpu_nf4_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } From 72033dc1a39f803c40a42b0f31f0ae6bddc7b2af Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 13:19:58 +0000 Subject: [PATCH 34/47] fix typo Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- csrc/pythonInterface.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index acf4caa34..2a6014940 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -142,7 +142,7 @@ def _( absmax = absmax.float() A = A.reshape(shape[0], shape[1] // 2) - out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + out = torch.empty(shape, dtype=dtype, device=A.device) if quant_type == "fp4": if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp4_fp32( diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index fd89a626e..d9914951f 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -843,17 +843,17 @@ void cquantize_blockwise_cpu_fp32( } void cdequantize_blockwise_cpu_fp32( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_bf16( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp16( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } From 1c20ae831e4371f62e668ff9df0216e492edb48b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 15:34:41 +0000 Subject: [PATCH 35/47] enable dequant4bit Signed-off-by: jiqing-feng --- CMakeLists.txt | 12 ++ bitsandbytes/backends/cpu/ops.py | 212 ++++++++++++++----------------- csrc/cpu_ops.cpp | 127 +++++++++++++++++- csrc/cpu_ops.h | 1 + 4 files changed, 229 insertions(+), 123 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5abfca78..8d4a492c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,6 +271,18 @@ target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) + include(CheckCXXCompilerFlag) + + check_cxx_compiler_flag(-mavx512f HAS_AVX512F) + check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16) + + if(HAS_AVX512F) + target_compile_options(bitsandbytes PRIVATE -mavx512f) + endif() + + if(HAS_AVX512BF16) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() endif() if(BUILD_CUDA) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 2a6014940..57cd830c2 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -27,6 +27,12 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): @register_kernel("bitsandbytes::quantize_blockwise", "cpu") @@ -118,121 +124,95 @@ def _( return out -@register_kernel("bitsandbytes::dequantize_4bit", "cpu") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - # Enable non uint8 dtype - if A.dtype != torch.uint8: - A = A.view(torch.uint8) - - # TODO: support half precision absmax - if absmax.dtype != torch.float32: - absmax = absmax.float() - - A = A.reshape(shape[0], shape[1] // 2) - out = torch.empty(shape, dtype=dtype, device=A.device) - if quant_type == "fp4": - if dtype == torch.float32: - lib.cdequantize_blockwise_cpu_fp4_fp32( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_cpu_fp4_bf16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.float16: - lib.cdequantize_blockwise_cpu_fp4_fp16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif quant_type == "nf4": - if dtype == torch.float32: - lib.cdequantize_blockwise_cpu_nf4_fp32( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_cpu_nf4_bf16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - out_2 = dequantize_nf4_test(A.reshape(-1), absmax, blocksize, quant_type, shape, dtype) - out = out.reshape(shape) - out_2 = out_2.reshape(shape) - if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - import pdb; pdb.set_trace() - elif dtype == torch.float16: - lib.cdequantize_blockwise_cpu_nf4_fp16( - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(shape[0]), - ct.c_longlong(shape[1]), - ) - else: - # Map nf4 to [-1, 1] - A = A.reshape(-1) - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] + @register_kernel("bitsandbytes::dequantize_4bit", "cpu") + def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + + # TODO: support half precision absmax + if absmax.dtype != torch.float32: + absmax = absmax.float() + + A = _reverse_4bit_compress_format(A) + A = A.reshape(shape[0], shape[1] // 2) + out = torch.empty(shape, dtype=dtype, device=A.device) + if quant_type == "fp4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_fp4_fp32( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_fp4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_fp4_fp16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif quant_type == "nf4": + if dtype == torch.float32: + lib.cdequantize_blockwise_cpu_nf4_fp32( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_cpu_nf4_bf16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) + out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) + if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + import pdb; pdb.set_trace() + elif dtype == torch.float16: + lib.cdequantize_blockwise_cpu_nf4_fp16( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(shape[0]), + ct.c_longlong(shape[1]), + ) else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - - out = out.reshape(-1, *shape[1:]).to(dtype) + raise ValueError - return out + return out def dequantize_nf4_test( A: torch.Tensor, @@ -270,9 +250,3 @@ def dequantize_nf4_test( return out - -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a0c9bd50c..5beeaf3d4 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,6 +5,76 @@ using namespace BinSearch; +// #if defined(__AVX512F__) +#if 1 +#include + +inline __m256i cvt_fp32_to_fp16(const __m512 src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + +inline __m256i cvt_fp32_to_bf16(const __m512 src) { + #if defined(__AVX512BF16__) + return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); + #else + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); + #endif +} + +static inline __m512 set_nf4_lut() { + return _mm512_set_ps( + 1.0f, + 0.7229568362236023, + 0.5626170039176941, + 0.44070982933044434, + 0.33791524171829224, + 0.24611230194568634, + 0.16093020141124725, + 0.07958029955625534, + 0.0f, + -0.09105003625154495, + -0.18477343022823334, + -0.28444138169288635, + -0.39491748809814453, + -0.5250730514526367, + -0.6961928009986877, + -1.0f); +} +static inline __m512 set_fp4_lut() { + return _mm512_set_ps( + 0.0000f, + 5.208333333e-03f, + 0.66666667f, + 1.0000f, + 0.33333333f, + 0.5000f, + 0.16666667f, + 0.2500f, + 0.0000f, + -5.208333333e-03f, + -0.66666667f, + -1.0000f, + -0.33333333f, + -0.5000f, + -0.16666667f, + -0.2500f); +} +#endif // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 0 = NF4 @@ -19,10 +89,59 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -#if defined(__AVX512F__) && defined(__AVX512BW__) && defined(TEST_BUG) - // AVX512 optimized branch (placeholder) - // DATA_TYPE: 1 = FP4, 2 = NF4 - if (1 == 0) {return;} +// #if defined(__AVX512F__) && defined(TEST_BUG) +# if 1 + auto dim_0 = m; + auto dim_1 = n; + auto input_dim_1 = dim_1 >> 1; + using Tcomp = float; + constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 + if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { + __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut(); + constexpr auto k_step = VEC_LEN / 2; // 8 + // auto dequant_loop = ThreadedLoop<2>({{dim_0}}, /* loop_scheme */ "A"); + // dequant_loop( + // [&](int* idx) { + // int block_idx = idx[0]; + #pragma omp parallel for + for (int block_idx = 0; block_idx < dim_0; ++block_idx) { + for (int k = 0; k < input_dim_1; k += k_step) { + // Load 64 bits of nf4 data and a single scale data + // auto p = A[block_idx * input_dim_1 + k]; + uint8_t* p = &A[block_idx * input_dim_1 + k]; + uint64_t packed; + std::memcpy(&packed, p, sizeof(uint64_t)); + auto scale_idx = k * 2 / blocksize; + auto vscales = _mm512_set1_ps((float)absmax[block_idx * blocksize + scale_idx]); + // uint64_t packed = reinterpret_cast(p)[0]; + // unpack nf4 data to 32-bit integers + uint64_t high = 0; + uint64_t low = 0; + for (int i = 0; i < 8; ++i) { + low |= ((packed >> (i * 4)) & 0xf) << (i * 8); + high |= ((packed >> (i * 4 + 32)) & 0xf) << (i * 8); + } + __m128i packed_128 = _mm_set_epi64x(high, low); + __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); + // Table look-up + __m512 vout = _mm512_permutexvar_ps(vint32, lut); + // Apply scale + vout = _mm512_mul_ps(vout, vscales); + // Store results + // auto pout = out[block_idx * dim_1 + k * 2]; + T* pout = &out[block_idx * dim_1 + k * 2]; // out[block_idx][k/k_step] + if constexpr (std::is_same()) { + _mm512_storeu_ps(pout, vout); + } else if constexpr (std::is_same()) { + _mm256_storeu_si256( + (__m256i*)pout, cvt_fp32_to_bf16(vout)); + } else if constexpr (std::is_same()) { + _mm256_storeu_si256( + (__m256i*)pout, cvt_fp32_to_fp16(vout)); + } + } + } + } #else // Scalar fallback branch long long total = m * n; diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 047466f7a..6be5a864c 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -115,6 +115,7 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } + template void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); From 7552fe22e94ef1cdf472df84587c048615fad3d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 16:05:42 +0000 Subject: [PATCH 36/47] fix Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 6 +++--- csrc/cpu_ops.cpp | 24 ++++++++---------------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 57cd830c2..c5a45c914 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -197,9 +197,9 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) - if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - import pdb; pdb.set_trace() + # out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) + # if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): + # import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( get_ptr(A), diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 5beeaf3d4..f9c7a5364 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,8 +5,7 @@ using namespace BinSearch; -// #if defined(__AVX512F__) -#if 1 +#if defined(__AVX512F__) #include inline __m256i cvt_fp32_to_fp16(const __m512 src) { @@ -89,31 +88,25 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -// #if defined(__AVX512F__) && defined(TEST_BUG) -# if 1 - auto dim_0 = m; - auto dim_1 = n; - auto input_dim_1 = dim_1 >> 1; +#if defined(__AVX512F__) && defined(TEST_BUG) + long long dim_0 = m; + long long dim_1 = n; + long long input_dim_1 = dim_1 >> 1; + long long absmax_dim_1 = dim_1 / blocksize using Tcomp = float; constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut(); constexpr auto k_step = VEC_LEN / 2; // 8 - // auto dequant_loop = ThreadedLoop<2>({{dim_0}}, /* loop_scheme */ "A"); - // dequant_loop( - // [&](int* idx) { - // int block_idx = idx[0]; #pragma omp parallel for for (int block_idx = 0; block_idx < dim_0; ++block_idx) { for (int k = 0; k < input_dim_1; k += k_step) { // Load 64 bits of nf4 data and a single scale data - // auto p = A[block_idx * input_dim_1 + k]; uint8_t* p = &A[block_idx * input_dim_1 + k]; uint64_t packed; std::memcpy(&packed, p, sizeof(uint64_t)); auto scale_idx = k * 2 / blocksize; - auto vscales = _mm512_set1_ps((float)absmax[block_idx * blocksize + scale_idx]); - // uint64_t packed = reinterpret_cast(p)[0]; + auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]); // unpack nf4 data to 32-bit integers uint64_t high = 0; uint64_t low = 0; @@ -128,8 +121,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // Apply scale vout = _mm512_mul_ps(vout, vscales); // Store results - // auto pout = out[block_idx * dim_1 + k * 2]; - T* pout = &out[block_idx * dim_1 + k * 2]; // out[block_idx][k/k_step] + T* pout = &out[block_idx * dim_1 + k * 2]; if constexpr (std::is_same()) { _mm512_storeu_ps(pout, vout); } else if constexpr (std::is_same()) { From 8b32a39c34464c44120128d7d0982371087f9d70 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 16:09:33 +0000 Subject: [PATCH 37/47] fix Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f9c7a5364..ddfd64aa7 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -88,11 +88,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); if (blocksize <= 0 || m < 0 || n <= 0) return; -#if defined(__AVX512F__) && defined(TEST_BUG) +#if defined(__AVX512F__) long long dim_0 = m; long long dim_1 = n; long long input_dim_1 = dim_1 >> 1; - long long absmax_dim_1 = dim_1 / blocksize + long long absmax_dim_1 = dim_1 / blocksize; using Tcomp = float; constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16 if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { From 8f1cc3699be96062564b8296aa2382c3356e71d3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 31 Oct 2025 17:47:23 +0000 Subject: [PATCH 38/47] fix reverse Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 7 ------- csrc/cpu_ops.cpp | 10 +++++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index c5a45c914..a716c7580 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -27,12 +27,6 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): @register_kernel("bitsandbytes::quantize_blockwise", "cpu") @@ -147,7 +141,6 @@ def _( if absmax.dtype != torch.float32: absmax = absmax.float() - A = _reverse_4bit_compress_format(A) A = A.reshape(shape[0], shape[1] // 2) out = torch.empty(shape, dtype=dtype, device=A.device) if quant_type == "fp4": diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index ddfd64aa7..a46799c58 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,6 +5,8 @@ using namespace BinSearch; +#define __AVX512F__ + #if defined(__AVX512F__) #include @@ -110,9 +112,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // unpack nf4 data to 32-bit integers uint64_t high = 0; uint64_t low = 0; - for (int i = 0; i < 8; ++i) { - low |= ((packed >> (i * 4)) & 0xf) << (i * 8); - high |= ((packed >> (i * 4 + 32)) & 0xf) << (i * 8); + for (int i = 0; i < 4; ++i) { + low |= ((packed >> (2*i * 4)) & 0xf) << ((2*i+1) * 8); + low |= ((packed >> ((2*i+1) * 4)) & 0xf) << (2*i * 8); + high |= ((packed >> (2*i * 4 + 32)) & 0xf) << ((2*i+1) * 8); + high |= ((packed >> ((2*i+1) * 4 + 32)) & 0xf) << (2*i * 8); } __m128i packed_128 = _mm_set_epi64x(high, low); __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); From 49d242a82751c45bb3ad04aae6eb740d62eecc40 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 3 Nov 2025 09:15:24 +0000 Subject: [PATCH 39/47] fix dequant 4bit fallback path Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a46799c58..a797b0ab2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -5,7 +5,6 @@ using namespace BinSearch; -#define __AVX512F__ #if defined(__AVX512F__) #include @@ -137,8 +136,9 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } } + return; } -#else +#endif // Scalar fallback branch long long total = m * n; #pragma omp parallel for @@ -171,7 +171,6 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } } -#endif } From 4a9a6dc1817bc38110ac94656bd606024a4b953b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 3 Nov 2025 09:42:55 +0000 Subject: [PATCH 40/47] fix fp4 dequant Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index a797b0ab2..f8082fb7a 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -57,22 +57,22 @@ static inline __m512 set_nf4_lut() { } static inline __m512 set_fp4_lut() { return _mm512_set_ps( + -0.2500f, + -0.16666667f, + -0.5000f, + -0.33333333f, + -1.0000f, + -0.66666667f, + -5.208333333e-03f, 0.0000f, - 5.208333333e-03f, - 0.66666667f, - 1.0000f, - 0.33333333f, - 0.5000f, - 0.16666667f, 0.2500f, - 0.0000f, - -5.208333333e-03f, - -0.66666667f, - -1.0000f, - -0.33333333f, - -0.5000f, - -0.16666667f, - -0.2500f); + 0.16666667f, + 0.5000f, + 0.33333333f, + 1.0000f, + 0.66666667f, + 5.208333333e-03f, + 0.0000f); } #endif From d7e981d920c8ac79c208a3fb56f8493203cdd386 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 5 Nov 2025 12:52:55 +0000 Subject: [PATCH 41/47] rm _Float16 Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 6 ++++++ csrc/cpu_ops.h | 49 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f8082fb7a..f590bc6ab 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -158,6 +158,8 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); + } else if constexpr (std::is_same::value) { + out[block_idx + i] = float_to_fp16(v0); } else { out[block_idx + i] = static_cast(v0); } @@ -165,6 +167,8 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if (i + 1 < valid_items) { if constexpr (std::is_same::value) { out[block_idx + i + 1] = float_to_bf16(v1); + } else if constexpr (std::is_same::value) { + out[block_idx + i + 1] = float_to_fp16(v1); } else { out[block_idx + i + 1] = static_cast(v1); } @@ -192,6 +196,8 @@ void dequantizeBlockwise8bitCpu(float* code, float v = code[A[i]] * scale; if constexpr (std::is_same::value) { out[i] = float_to_bf16(v); + } else if constexpr (std::is_same::value) { + out[i] = float_to_fp16(v); } else { out[i] = static_cast(v); } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 6be5a864c..fea894d79 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,11 +1,8 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H -#include -#include #include #include -#include void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -14,7 +11,9 @@ typedef enum DataType_t { FP4 = 1, } DataType_t; -using fp16_t = _Float16; +struct fp16_t { + uint16_t v; +}; struct bf16_t { uint16_t v; @@ -27,6 +26,48 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +static inline fp16_t float_to_fp16(float x) { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t sign = (bits >> 31) & 0x1; + uint32_t exp = (bits >> 23) & 0xFF; + uint32_t mant = bits & 0x7FFFFF; + + uint16_t h; + if (exp == 0xFF) { // Inf / NaN + uint16_t mant16 = mant ? 0x200 : 0; // quiet NaN: set MSB of mantissa + h = (sign << 15) | (0x1F << 10) | mant16; + } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) + h = (sign << 15) | (0x1F << 10); // Inf + } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) + if (exp < 0x67) { // too small -> zero (exp_f < 103) + h = (sign << 15); + } else { + // subnormal: implicit leading 1 + uint32_t shift = 0x71 - exp; + uint32_t mant_with_hidden = mant | 0x800000; + // add rounding bias before shifting (23-10 =13 bits to drop + shift) + uint32_t rounded = (mant_with_hidden + (1u << (shift + 12))) >> (shift + 13); + h = (sign << 15) | (uint16_t)rounded; + } + } else { + // normalized + uint32_t exp_h = exp - 127 + 15; + // round mantissa: add 2^(23-10-1) = 0x1000 + uint32_t mant_rounded = mant + 0x00001000; + if (mant_rounded & 0x00800000) { // mantissa overflow after rounding + mant_rounded = 0; + ++exp_h; + if (exp_h >= 0x1F) { // overflow to Inf + h = (sign << 15) | (0x1F << 10); + return fp16_t{h}; + } + } + h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13)); + } + return fp16_t{h}; +} + inline float dDequantizeFP4(unsigned char val) { if ((val & 0b1000) == 8) if ((val & 0b0100) == 4) From d8cbc681e3f4512fef3f2c30c761e93dfc4a9dac Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 6 Nov 2025 09:39:06 +0000 Subject: [PATCH 42/47] fix cmake check Signed-off-by: jiqing-feng --- CMakeLists.txt | 49 +++++++++++++++++++++++++++++++++++++++--------- csrc/cpu_ops.cpp | 13 ++++++++++--- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d4a492c8..f1c1efffa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ endif() if (BUILD_CPU) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) + string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH) find_package(OpenMP) endif() @@ -270,18 +271,48 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc include) if (BUILD_CPU) - target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) - include(CheckCXXCompilerFlag) - - check_cxx_compiler_flag(-mavx512f HAS_AVX512F) - check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16) + if (OpenMP_CXX_FOUND) + target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) + add_definitions(-DHAS_OPENMP) + else() + add_definitions(-DNO_OPENMP) + endif() - if(HAS_AVX512F) - target_compile_options(bitsandbytes PRIVATE -mavx512f) + if (HOST_ARCH MATCHES "x86_64|amd64") + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG) + check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) + if (HAS_AVX512F_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512f) + add_definitions(-DHAS_AVX512F) + endif() + if (HAS_AVX512BF16_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + add_definitions(-DHAS_AVX512BF16) + else() + add_definitions(-DNO_AVX512BF16) + endif() endif() +endif() - if(HAS_AVX512BF16) - target_compile_options(bitsandbytes PRIVATE -mavx512bf16) +# --- Windows MSVC specific AVX512BF16 probe (after add_library) --- +if (MSVC AND BUILD_CPU) + include(CheckCXXSourceCompiles) + set(_AVX512BF16_TEST " + #include + int main(){ + __m512bh a{}, b{}; + auto c = _mm512_dpbf16_ps(_mm512_setzero_ps(), a, b); + (void)c; + return 0; + }") + check_cxx_source_compiles("${_AVX512BF16_TEST}" MSVC_HAS_AVX512BF16) + if (MSVC_HAS_AVX512BF16) + # /arch:AVX512; + target_compile_options(bitsandbytes PRIVATE /arch:AVX512) + add_definitions(-DHAS_AVX512BF16) + else() + add_definitions(-DNO_AVX512BF16) endif() endif() diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index f590bc6ab..18aa9e596 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -3,6 +3,13 @@ #include #include +#ifdef HAS_OPENMP +#include +#define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for") +#else +#define BNB_OMP_PARALLEL_FOR +#endif + using namespace BinSearch; @@ -99,7 +106,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) { __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut(); constexpr auto k_step = VEC_LEN / 2; // 8 - #pragma omp parallel for + BNB_OMP_PARALLEL_FOR for (int block_idx = 0; block_idx < dim_0; ++block_idx) { for (int k = 0; k < input_dim_1; k += k_step) { // Load 64 bits of nf4 data and a single scale data @@ -141,7 +148,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, #endif // Scalar fallback branch long long total = m * n; - #pragma omp parallel for + BNB_OMP_PARALLEL_FOR for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx); float scale = absmax[block_idx / blocksize]; @@ -187,7 +194,7 @@ void dequantizeBlockwise8bitCpu(float* code, long long n) { if (blocksize <= 0 || n <= 0) return; // 8-bit path - #pragma omp parallel for + BNB_OMP_PARALLEL_FOR for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); long long block_end = block_idx + valid_items; From a0389c81d6f5a97b15b742d94471d24c4f3dfe1c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 09:02:41 +0000 Subject: [PATCH 43/47] fix lint Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 41 -------- csrc/cpu_ops.cpp | 165 +++++++++++++------------------ csrc/cpu_ops.h | 22 +++-- csrc/pythonInterface.cpp | 6 ++ 4 files changed, 88 insertions(+), 146 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a716c7580..fe25a7f70 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -6,7 +6,6 @@ from bitsandbytes.functional import get_ptr -from ..utils import CODE from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib @@ -190,9 +189,6 @@ def _( ct.c_longlong(shape[0]), ct.c_longlong(shape[1]), ) - # out_2 = dequantize_nf4_test(_reverse_4bit_compress_format(A.reshape(-1)), absmax, blocksize, quant_type, shape, dtype) - # if not torch.allclose(out, out_2, rtol=1e-2, atol=5e-2): - # import pdb; pdb.set_trace() elif dtype == torch.float16: lib.cdequantize_blockwise_cpu_nf4_fp16( get_ptr(A), @@ -206,40 +202,3 @@ def _( raise ValueError return out - -def dequantize_nf4_test( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -): - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - - out = out.reshape(-1, *shape[1:]).to(dtype) - - return out - diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 18aa9e596..6ec19db88 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -12,89 +12,61 @@ using namespace BinSearch; - #if defined(__AVX512F__) #include inline __m256i cvt_fp32_to_fp16(const __m512 src) { return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - } +} inline __m256i cvt_fp32_to_bf16(const __m512 src) { - #if defined(__AVX512BF16__) - return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); - #else - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); - #endif +#if defined(__AVX512BF16__) + return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); +#else + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +#endif } static inline __m512 set_nf4_lut() { return _mm512_set_ps( - 1.0f, - 0.7229568362236023, - 0.5626170039176941, - 0.44070982933044434, - 0.33791524171829224, - 0.24611230194568634, - 0.16093020141124725, - 0.07958029955625534, - 0.0f, - -0.09105003625154495, - -0.18477343022823334, - -0.28444138169288635, - -0.39491748809814453, - -0.5250730514526367, - -0.6961928009986877, - -1.0f); + 1.0f, 0.7229568362236023, 0.5626170039176941, 0.44070982933044434, 0.33791524171829224, 0.24611230194568634, + 0.16093020141124725, 0.07958029955625534, 0.0f, -0.09105003625154495, -0.18477343022823334, + -0.28444138169288635, -0.39491748809814453, -0.5250730514526367, -0.6961928009986877, -1.0f + ); } + static inline __m512 set_fp4_lut() { return _mm512_set_ps( - -0.2500f, - -0.16666667f, - -0.5000f, - -0.33333333f, - -1.0000f, - -0.66666667f, - -5.208333333e-03f, - 0.0000f, - 0.2500f, - 0.16666667f, - 0.5000f, - 0.33333333f, - 1.0000f, - 0.66666667f, - 5.208333333e-03f, - 0.0000f); + -0.2500f, -0.16666667f, -0.5000f, -0.33333333f, -1.0000f, -0.66666667f, -5.208333333e-03f, 0.0000f, 0.2500f, + 0.16666667f, 0.5000f, 0.33333333f, 1.0000f, 0.66666667f, 5.208333333e-03f, 0.0000f + ); } #endif // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. -// DATA_TYPE: 1 = FP4, 0 = NF4 +// DATA_TYPE: 1 = FP4, 2 = NF4 template -void dequantizeBlockwise4bitCpu(unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long m, - long long n) { - static_assert(DATA_TYPE == 0 || DATA_TYPE == 1, - "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); - if (blocksize <= 0 || m < 0 || n <= 0) return; +void dequantizeBlockwise4bitCpu( + unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n +) { + static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE"); + if (blocksize <= 0 || m < 0 || n <= 0) + return; #if defined(__AVX512F__) long long dim_0 = m; @@ -119,10 +91,10 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, uint64_t high = 0; uint64_t low = 0; for (int i = 0; i < 4; ++i) { - low |= ((packed >> (2*i * 4)) & 0xf) << ((2*i+1) * 8); - low |= ((packed >> ((2*i+1) * 4)) & 0xf) << (2*i * 8); - high |= ((packed >> (2*i * 4 + 32)) & 0xf) << ((2*i+1) * 8); - high |= ((packed >> ((2*i+1) * 4 + 32)) & 0xf) << (2*i * 8); + low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8); + low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8); + high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8); + high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8); } __m128i packed_128 = _mm_set_epi64x(high, low); __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); @@ -133,13 +105,11 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, // Store results T* pout = &out[block_idx * dim_1 + k * 2]; if constexpr (std::is_same()) { - _mm512_storeu_ps(pout, vout); + _mm512_storeu_ps(pout, vout); } else if constexpr (std::is_same()) { - _mm256_storeu_si256( - (__m256i*)pout, cvt_fp32_to_bf16(vout)); + _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout)); } else if constexpr (std::is_same()) { - _mm256_storeu_si256( - (__m256i*)pout, cvt_fp32_to_fp16(vout)); + _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout)); } } } @@ -157,11 +127,9 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, unsigned char byte = A[byte_index]; // High nibble first (matches previous code logic) - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) - : dDequantizeNF4(byte >> 4)) * scale; + float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale; // Low nibble second - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) - : dDequantizeNF4(byte & 0x0F)) * scale; + float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) : dDequantizeNF4(byte & 0x0F)) * scale; if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); @@ -184,20 +152,17 @@ void dequantizeBlockwise4bitCpu(unsigned char* A, } } - template -void dequantizeBlockwise8bitCpu(float* code, - unsigned char* A, - const float* absmax, - T* out, - long long blocksize, - long long n) { - if (blocksize <= 0 || n <= 0) return; +void dequantizeBlockwise8bitCpu( + float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n +) { + if (blocksize <= 0 || n <= 0) + return; // 8-bit path BNB_OMP_PARALLEL_FOR for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); - long long block_end = block_idx + valid_items; + long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; for (long long i = block_idx; i < block_end; ++i) { float v = code[A[i]] * scale; @@ -212,7 +177,6 @@ void dequantizeBlockwise8bitCpu(float* code, } } - void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below @@ -265,26 +229,35 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long //============================================================== template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n +); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n +); template void dequantizeBlockwise8bitCpu( - float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n); + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +); template void dequantizeBlockwise4bitCpu( - unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n); + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +); // template void gemv_4bit_inference( // int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index fea894d79..84971f177 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -7,8 +7,9 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); typedef enum DataType_t { - NF4 = 0, + General8bit = 0, FP4 = 1, + NF4 = 2, } DataType_t; struct fp16_t { @@ -30,17 +31,17 @@ static inline fp16_t float_to_fp16(float x) { uint32_t bits; std::memcpy(&bits, &x, 4); uint32_t sign = (bits >> 31) & 0x1; - uint32_t exp = (bits >> 23) & 0xFF; + uint32_t exp = (bits >> 23) & 0xFF; uint32_t mant = bits & 0x7FFFFF; uint16_t h; - if (exp == 0xFF) { // Inf / NaN + if (exp == 0xFF) { // Inf / NaN uint16_t mant16 = mant ? 0x200 : 0; // quiet NaN: set MSB of mantissa h = (sign << 15) | (0x1F << 10) | mant16; - } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) + } else if (exp > 0x70 + 0x1E) { // overflow: exp_f -127 +15 > 30 (exp_f > 142) h = (sign << 15) | (0x1F << 10); // Inf - } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) - if (exp < 0x67) { // too small -> zero (exp_f < 103) + } else if (exp < 0x71) { // subnormal or zero (exp_f < 113) + if (exp < 0x67) { // too small -> zero (exp_f < 103) h = (sign << 15); } else { // subnormal: implicit leading 1 @@ -156,11 +157,14 @@ inline float dDequantizeNF4(unsigned char val) { return -1.0f; //*0000 } - template -void dequantizeBlockwise8bitCpu(float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n); +void dequantizeBlockwise8bitCpu( + float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n +); template -void dequantizeBlockwise4bitCpu(unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n); +void dequantizeBlockwise4bitCpu( + unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n +); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index d9914951f..f1d15cf51 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -847,11 +847,13 @@ void cdequantize_blockwise_cpu_fp32( ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_bf16( float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n ) { dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp16( float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n ) { @@ -863,11 +865,13 @@ void cdequantize_blockwise_cpu_fp4_fp32( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_bf16( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_fp4_fp16( unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { @@ -879,11 +883,13 @@ void cdequantize_blockwise_cpu_nf4_fp32( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_bf16( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + void cdequantize_blockwise_cpu_nf4_fp16( unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n ) { From 0d760b97598cbc3e162e738673d2fd9dfde6bf55 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 09:10:29 +0000 Subject: [PATCH 44/47] fix datatypr Signed-off-by: jiqing-feng --- csrc/common.h | 6 ++++++ csrc/cpu_ops.h | 6 ------ csrc/ops.cuh | 6 ------ csrc/ops_hip.cuh | 6 ------ csrc/xpu_ops.h | 6 ------ 5 files changed, 6 insertions(+), 24 deletions(-) diff --git a/csrc/common.h b/csrc/common.h index c0c9a43be..76b5d6aee 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -5,6 +5,12 @@ using namespace BinSearch; +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + struct quantize_block_args { BinAlgo* bin_searcher; float* code; diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 84971f177..8070560d9 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -6,12 +6,6 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); -typedef enum DataType_t { - General8bit = 0, - FP4 = 1, - NF4 = 2, -} DataType_t; - struct fp16_t { uint16_t v; }; diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 01e11ff31..ba07ddd00 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -77,12 +77,6 @@ typedef enum Transform_t { COL_AMPERE = 4, } Transform_t; -typedef enum DataType_t { - General8bit = 0, - FP4 = 1, - NF4 = 2, -} DataType_t; - typedef enum Funcs_t { FILL = 0, ARANGE = 1, diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 0f8db2ee4..78efd4425 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -79,12 +79,6 @@ typedef enum Transform_t { COL_AMPERE = 4, } Transform_t; -typedef enum DataType_t { - General8bit = 0, - FP4 = 1, - NF4 = 2, -} DataType_t; - typedef enum Funcs_t { FILL = 0, ARANGE = 1, diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h index 142d6c161..b5b79b5f0 100644 --- a/csrc/xpu_ops.h +++ b/csrc/xpu_ops.h @@ -27,12 +27,6 @@ static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queu q.submit(cgf); } -typedef enum DataType_t { - General8bit = 0, - FP4 = 1, - NF4 = 2, -} DataType_t; - template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream From 1e3bde6e4eb6742120aa9d8125cf9031892e571c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 09:11:54 +0000 Subject: [PATCH 45/47] fix include Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index f1d15cf51..9f3c860f3 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -17,6 +17,7 @@ #include #endif #include +#include // Compatibility between HIP/CUDA APIs #if BUILD_HIP From d531f5f33805a5ee3f24c4ef22f0ce24d0eb3a35 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 09:12:55 +0000 Subject: [PATCH 46/47] fix typo Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9f3c860f3..f88bd4ac6 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -17,7 +17,7 @@ #include #endif #include -#include +#include // Compatibility between HIP/CUDA APIs #if BUILD_HIP From af54c9d7fe5963da275cda371f2a4409ebc6fa5e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 7 Nov 2025 09:24:42 +0000 Subject: [PATCH 47/47] fix include Signed-off-by: jiqing-feng --- csrc/cpu_ops.cpp | 1 - csrc/cpu_ops.h | 1 + csrc/ops.cu | 1 - csrc/ops.cuh | 9 +-------- csrc/ops_hip.cuh | 9 +-------- csrc/pythonInterface.cpp | 1 - csrc/xpu_ops.cpp | 1 - csrc/xpu_ops.h | 1 + 8 files changed, 4 insertions(+), 20 deletions(-) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 6ec19db88..08ac59ab6 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 8070560d9..7040833a0 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,6 +1,7 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H +#include #include #include diff --git a/csrc/ops.cu b/csrc/ops.cu index 37a3191bc..6b9fa87bf 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -5,7 +5,6 @@ #include #include -#include #include #include #include diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 98243a6e6..a9c9bbb12 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -69,14 +70,6 @@ typedef enum Optimizer_t { ADEMAMIX = 6 } Optimizer_t; -typedef enum Transform_t { - ROW = 0, - COL = 1, - COL32 = 2, - COL_TURING = 3, - COL_AMPERE = 4, -} Transform_t; - typedef enum Funcs_t { FILL = 0, ARANGE = 1, diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index e2862e59a..72cdf4e01 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -71,14 +72,6 @@ typedef enum Optimizer_t { ADEMAMIX = 6, } Optimizer_t; -typedef enum Transform_t { - ROW = 0, - COL = 1, - COL32 = 2, - COL_TURING = 3, - COL_AMPERE = 4, -} Transform_t; - typedef enum Funcs_t { FILL = 0, ARANGE = 1, diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 2885cfe57..d61e486b9 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -17,7 +17,6 @@ #include #endif #include -#include // Compatibility between HIP/CUDA APIs #if BUILD_HIP diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp index aa6ac808f..48c986fc4 100644 --- a/csrc/xpu_ops.cpp +++ b/csrc/xpu_ops.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h index b5b79b5f0..a5ea80f97 100644 --- a/csrc/xpu_ops.h +++ b/csrc/xpu_ops.h @@ -2,6 +2,7 @@ #define xpu_ops_H #include +#include #include #include #include