Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ nppBackup

# Models
models/*
gpu/checkpoints/*

# Python

Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/llama.cpp
Submodule llama.cpp updated 1 files
+37 −3 src/llama.cpp
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ endif()
find_package(Threads REQUIRED)

add_subdirectory(src)
# Force-enable llama.cpp's server target, presumably because the bitnet
# tooling depends on the server binary.
# NOTE(review): FORCE overrides any user-supplied cache value for
# LLAMA_BUILD_SERVER — confirm this stomp is intended before keeping it.
set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build llama.cpp server" FORCE)
add_subdirectory(3rdparty/llama.cpp)

# install
Expand Down Expand Up @@ -74,4 +75,4 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)

set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
install(TARGETS llama LIBRARY PUBLIC_HEADER)
install(TARGETS llama LIBRARY PUBLIC_HEADER)
46 changes: 39 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
![version](https://img.shields.io/badge/version-1.0-blue)

<img src="./assets/header_model_release.png" alt="BitNet Model on Hugging Face" width="800"/>
[<img src="./assets/header_model_release.png" alt="BitNet Model on Hugging Face" width="800"/>](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)

bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next).
Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).

bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support is coming next).

The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.

Expand All @@ -20,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1

## What's New:
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ![NEW](https://img.shields.io/badge/NEW-red)
- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) ![NEW](https://img.shields.io/badge/NEW-red)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
- 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
Expand Down Expand Up @@ -158,7 +161,7 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
### Build from source

> [!IMPORTANT]
> If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands
> If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues.

1. Clone the repo
```bash
Expand All @@ -179,9 +182,6 @@ pip install -r requirements.txt
huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s

# Or you can download a model from Hugging Face, convert it to quantized gguf format, and build the project
python setup_env.py --hf-repo tiiuae/Falcon3-7B-Instruct-1.58bit -q i2_s

```
<pre>
usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
Expand Down Expand Up @@ -278,4 +278,36 @@ python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile
# Run benchmark with the generated model, use -m to specify the model path, -p to specify the prompt processed, -n to specify the number of token to generate
python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128
```
### FAQ (Frequently Asked Questions)📌

#### Q1: The build fails while compiling llama.cpp with `std::chrono` errors in `log.cpp`. How can I fix this?

**A:**
This is an issue introduced in a recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in the [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix this issue.

#### Q2: How to build with clang in conda environment on windows?

**A:**
Before building the project, verify your clang installation and access to Visual Studio tools by running:
```
clang -v
```

This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as:
```
'clang' is not recognized as an internal or external command, operable program or batch file.
```

It indicates that your command line window is not properly initialized for Visual Studio tools.

• If you are using Command Prompt, run:
```
"C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
```

• If you are using Windows PowerShell, run the following commands:
```
Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll"
Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64"
```

These steps will initialize your environment and allow you to use the correct Visual Studio tools.
93 changes: 93 additions & 0 deletions gpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# BitNet Inference Kernel

This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.

## Features

- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput

## Usage

Installation and kernel performance tests:

```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu

# Install dependencies
pip install -r requirements.txt

# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..

# Run performance tests
python test.py
```

End-to-end inference:

```bash
# Download and convert the BitNet-b1.58-2B model
mkdir checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt

# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format
```

## Optimizations

### Weight Permutation

The weight matrix is divided into 16×32 blocks to optimize memory access patterns.

Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.

See `convert_checkpoint.py` for details.

### Fast Decoding

Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```

This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.

### `dp4a` Instruction

We use the `dp4a` instruction to accelerate low-precision dot product operations.

This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.

It significantly improves GEMV throughput when processing quantized weights and activations.


## Performance

Kernel performance (tested on NVIDIA A100 40GB GPU):

| Shape (N×K) | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio |
|---------------------|-------------------|-------------------|----------------------|
| 2560 × 2560 | 13.32 | 18.32 | 1.38 |
| 3840 × 2560 | 14.90 | 18.87 | 1.27 |
| 13824 × 2560 | 18.75 | 59.51 | 3.17 |
| 2560 × 6912 | 14.49 | 37.78 | 2.61 |
| 3200 × 3200 | 14.61 | 19.08 | 1.31 |
| 4800 × 3200 | 13.09 | 21.84 | 1.67 |
| 3200 × 10240 | 19.64 | 60.79 | 3.10 |
| 20480 × 3200 | 30.99 | 112.39 | 3.63 |

Generation throughput:

| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
|---|---|---|
| 10.9 | 213.3 | 19.6 |
37 changes: 37 additions & 0 deletions gpu/bitnet_kernels/bitnet_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "bitnet_kernels.h"

extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
if (M == 1 && N == 3840 && K == 2560){
ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 2560 && K == 2560){
ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 13824 && K == 2560){
ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 2560 && K == 6912){
ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 4800 && K == 3200){
ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 3200 && K == 3200){
ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 20480 && K == 3200){
ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 3200 && K == 10240){
ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 5120 && K == 27648){
ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 55296 && K == 5120){
ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else{
std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
}
}
83 changes: 83 additions & 0 deletions gpu/bitnet_kernels/bitnet_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include <cuda_runtime.h>
#include <math_constants.h>
#include <math.h>
#include <mma.h>
#include <iostream>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>


// Flag set when compiling with nvcc 11.4 or newer; presumably gates
// L2-prefetch code paths — the macro is not referenced in this header.
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

// Fast shared-memory pointer-cast path, enabled only for sm_80 (Ampere).
// NOTE(review): "ENBALE" is a typo for "ENABLE"; the macro is defined but
// not referenced in this header, so rename only after checking all users.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
#else
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
#endif

// Unpack 16 two-bit values, packed into one 32-bit word in the interleaved
// order {e0,e4,e8,e12, e1,e5,e9,e13, ...}, into 16 signed int8 values.
//   _i2s: pointer to the packed word (read as a single uint)
//   _i8s: destination, written as N/4 uint words (4 int8 lanes per word)
//   N:    number of values to decode (default 16 -> 4 loop iterations)
template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
{
    // convert 8 int2b_t to 8 int8b_t -> 2 int32
    uint *i8s = reinterpret_cast<uint *>(_i8s);

    // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
    uint const i2s = *_i2s;

    // lop3 immediate LUT for (a & b) | c: mask the 2-bit field, merge magic.
    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
    static constexpr uint BOTTOM_MASK = 0x03030303;          // 0xf -> 0b11 select 0,3
    static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;

#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        // Shift by 2*i so each byte lane's low 2 bits hold the next element,
        // then mask with BOTTOM_MASK via lop3 to isolate them into i8s[i].
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
        // Map the unsigned 2-bit range {0..3} to signed {-2..1} per lane
        // (saturating SIMD byte subtract of 2 from each of the 4 lanes).
        i8s[i] = __vsubss4(i8s[i], 0x02020202);
    }
}

// W2A8 GEMV kernel: int8 activations (A) times packed 2-bit weights (B),
// producing a bf16 output vector scaled by s[0] and a per-group weight
// scale ws[ws_idx].
// Template parameters:
//   M, N, K        GEMM dimensions (M is 1 in all launches seen in this PR)
//   ws_num         number of weight-scale groups (ws_idx = out / (N/ws_num))
//   K_block_size   threads along x; each reduces a strided slice of K
//   N_block_size   threads along y; one output element per y-lane
// Each thread accumulates dp4a partial sums; a shuffle reduction across the
// x-lanes finishes the dot product, and lane x==0 writes the result.
template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
__global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
    constexpr int K_per_loop = 16;  // int8 elements each thread loads per iteration (one int4 vector)
    constexpr int wmma_K = 32;      // K extent of the 16x32 weight permutation block
    constexpr int wmma_N = 16;      // N extent of the 16x32 weight permutation block
    int in_thread_C_local[1];       // per-thread int32 accumulator
    signed char A_local[K_per_loop];        // staged activation bytes
    int B_reshape_local[1];                 // one packed word = 16 two-bit weights
    signed char B_decode_local[K_per_loop]; // decoded int8 weights
    int red_buf0[1];                        // shuffle-reduction buffer
    in_thread_C_local[0] = 0;
#pragma unroll
    for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
        // Vectorized 16-byte load of activations for this thread's K slice.
        *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
        // Fetch one packed 32-bit weight word. The index arithmetic follows
        // the 16x32-block permutation applied in convert_checkpoint.py
        // (see gpu/README.md "Weight Permutation") — each term selects the
        // block, sub-block, and lane; divisions by 4 convert bytes to words.
        B_reshape_local[0] = *(int*)(B +
            (((int)blockIdx.x) * N_block_size * K / 4) +
            (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
            ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
            ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
            ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
            ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
            );
        // Expand 16 two-bit weights into 16 int8 values (interleaved layout).
        decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
#pragma unroll
        for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
            // dp4a: 4-way int8 dot product accumulated into the int32 sum.
            in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
        }
    }
    red_buf0[0] = in_thread_C_local[0];
    // Butterfly reduction over the K_block_size x-lanes via shuffle-down.
#pragma unroll
    for (int offset = K_block_size/2; offset > 0; offset /= 2) {
        red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
    }
    int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
    int ws_idx = out_idx / (N / ws_num);  // which weight-scale group this output belongs to
    // Lane 0 of each y-row holds the full sum; dequantize and store as bf16.
    if (threadIdx.x == 0)
        dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
}
3 changes: 3 additions & 0 deletions gpu/bitnet_kernels/compile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Build the BitNet W2A8 CUDA kernels into a shared library (libbitnet.so).
#   -Xcudafe --diag_suppress=177 : silence cudafe diagnostic 177
#   --compiler-options -fPIC     : position-independent code for the .so
#   -lineinfo                    : keep source line info for profiling tools
#   -gencode arch=compute_80     : target sm_80 (Ampere, e.g. A100)
nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so


13 changes: 13 additions & 0 deletions gpu/bitnet_kernels/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Build the BitNet CUDA kernels as a torch C++/CUDA extension module.
# The compiled module is importable as `bitlinear_cuda`.
kernel_extension = CUDAExtension(
    name='bitlinear_cuda',
    sources=['bitnet_kernels.cu'],
)

setup(
    name='bitlinear_cpp',
    ext_modules=[kernel_extension],
    cmdclass={'build_ext': BuildExtension},
)
Loading