
Conversation

@jiqing-feng
Contributor

The fused kernel speeds up 4-bit model inference by about 4x on TPOT compared to the dequant + matmul path. For the next optimization, targeting TTFT, we need to bring in libxsmm.
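For reference, here is a minimal sketch of how the TPOT / TTFT numbers can be measured on CPU with transformers (the model id, prompt, and token counts are just examples, not part of this PR):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any bnb 4-bit checkpoint works; this one is used elsewhere in this thread.
model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cpu", torch_dtype=torch.bfloat16
)
inputs = tokenizer("What is bitsandbytes?", return_tensors="pt")

def timed_generate(max_new_tokens):
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=max_new_tokens,
                       do_sample=False, pad_token_id=tokenizer.eos_token_id)
    return time.perf_counter() - start

t1 = timed_generate(1)     # roughly TTFT (prefill + first token)
t33 = timed_generate(33)   # prefill + 32 extra decode steps
tpot = (t33 - t1) / 32     # roughly time per output token during decode
print(f"TTFT ~ {t1:.2f}s, TPOT ~ {tpot * 1000:.1f} ms/token")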

Signed-off-by: jiqing-feng <[email protected]>
Contributor

@SunMarc left a comment

Left a comment!

@jiqing-feng
Contributor Author

jiqing-feng commented Nov 20, 2025

Hi @matthewdouglas. BNB only loads one native lib at a time (one of cpu/cuda/xpu), which means we can only build one .so file for bnb. But we cannot build CPU and XPU together, because the CPU kernel relies on Intel OpenMP (libiomp5.so) while XPU relies on GNU OpenMP (libgomp.so); building them together raises an error like `libbitsandbytes_xpu.so: undefined symbol: __kmpc_for_static_init_8`. I suppose it's the same for CUDA. Without OpenMP, though, the CPU kernel might be even slower than the Python op, and there might be other incompatible build flags across backends.
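(As an illustration only, not part of this PR: one way to check which OpenMP runtime a built library links against is to inspect its dynamic dependencies; the .so path below is a placeholder.)

import subprocess

# Placeholder path to a built native library.
lib = "bitsandbytes/libbitsandbytes_cpu.so"

# `ldd` lists the shared objects the library links against;
# look for the OpenMP runtime (Intel libiomp5 vs GNU libgomp).
deps = subprocess.run(["ldd", lib], capture_output=True, text=True).stdout
for line in deps.splitlines():
    if "libiomp5" in line or "libgomp" in line or "libomp" in line:
        print(line.strip())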

At the current stage we can only build one backend, so the cpu weight format will not be triggered on other backends. Even so, I added the reverse logic in case we want to support multiple backends in the future.

cc @SunMarc

Signed-off-by: jiqing-feng <[email protected]>
@matthewdouglas
Member

Hi @jiqing-feng

If you rebase it should trigger CI to run the tests on the PR now.

Regarding building the optimized CPU code when you're using an accelerator: you're right, I suppose has_avx512() would return false regardless of your CPU, since we only build that code into the CPU-only lib. In the future I might want to change things so we always ship a separate CPU library build instead of including CPU code in the accelerator libs. For now it's a good point: we shouldn't be rearranging weights at all if you're using an accelerator.
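A rough sketch of that guard (the helper names below are placeholders, not the actual bitsandbytes API; the CPU-capability check assumes a recent PyTorch):

import torch

def has_avx512() -> bool:
    # Sketch only: recent PyTorch exposes the detected CPU ISA as a string.
    try:
        return "AVX512" in torch.backends.cpu.get_cpu_capability()
    except AttributeError:
        return False

def should_rearrange_weights(device: torch.device) -> bool:
    # Only repack into the CPU-optimized layout when we actually run on CPU
    # and the optimized kernel is usable; never when the target device is an
    # accelerator (CUDA/XPU).
    return device.type == "cpu" and has_avx512()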

I'll look in more detail next week, but another question I have is the naming of the gemv op. It seems like this is actually a full GEMM implementation? I realize I never made a custom op for full GEMM, so that's something I can maybe do in a follow-up PR.
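For context on the naming, GEMV is just the M = 1 (single-token decode) case of GEMM, so a kernel that handles arbitrary M covers both:

import torch

W = torch.randn(4096, 4096)  # weight, already dequantized here for illustration
x = torch.randn(1, 4096)     # decode step: one token's hidden state (M = 1)
X = torch.randn(32, 4096)    # prefill: 32 tokens at once (M = 32)

y = x @ W.T   # (1, 4096)  -- the "gemv" case
Y = X @ W.T   # (32, 4096) -- the same kernel with M > 1 is a full GEMM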

@jiqing-feng
Contributor Author

Yes, exactly! We should never rearrange weights if we are using an accelerator; the rearranged layout is for CPU only.
Yes, the CPU gemv kernel does contain a full GEMM implementation, because it's a fused kernel: we do dequant and GEMM in one kernel, and that's where the speed-up comes from.
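Conceptually the fused kernel computes the same thing as this unfused reference, but without ever materializing the dequantized weight (a sketch with made-up shapes; NF4_CODEBOOK stands in for the real 16-value NF4 lookup table):

import torch

blocksize = 64
out_features, in_features = 4096, 4096

# Stand-ins for the stored 4-bit state: one code index per weight element,
# one absmax scale per block of `blocksize` elements.
codes = torch.randint(0, 16, (out_features, in_features))
absmax = torch.rand(out_features, in_features // blocksize)
NF4_CODEBOOK = torch.linspace(-1.0, 1.0, 16)  # placeholder values

x = torch.randn(1, in_features, dtype=torch.bfloat16)

# Unfused path: dequantize the whole weight, then matmul (two passes over memory).
w = NF4_CODEBOOK[codes] * absmax.repeat_interleave(blocksize, dim=1)
y = x @ w.to(x.dtype).T

# The fused kernel performs the codebook lookup and scaling inside the GEMM loop,
# so the full bf16 weight matrix is never written out.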

OK, I'm looking forward to your next round of review. It would also help to know roughly when we can merge this PR, so I can plan the next feature. Thanks!

@jiqing-feng
Contributor Author

Hi @matthewdouglas . Please trigger the CI. Thanks!

Signed-off-by: jiqing-feng <[email protected]>
@matthewdouglas added this to the v0.49.0 milestone Nov 25, 2025
@jiqing-feng
Contributor Author

Hi @matthewdouglas. I've fixed and verified save and reload. You can check it with the following script.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import shutil
import os

model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4"
save_path = "./local_test_model"
input_text = "What is bitsandbytes?"

if os.path.exists(save_path):
    shutil.rmtree(save_path)

def run_generation(model, tokenizer, prompt, tag):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    print(f"\n[{tag}] Start Generating...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs, 
            max_new_tokens=10, 
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    
    output_str = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"[{tag}] Output: {output_str}")
print(">>> Loading original model on CPU...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True
)
run_generation(model, tokenizer, input_text, "Run-1")
print(f"\n>>> Saving model to {save_path}...")
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
del model
del tokenizer
import gc
gc.collect()
print(f"\n>>> Reloading model from {save_path}...")
loaded_model = AutoModelForCausalLM.from_pretrained(
    save_path,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True
)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_path)

run_generation(loaded_model, loaded_tokenizer, input_text, "Run-2")

Output:

>>> Loading original model on CPU...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 35.76it/s]

[Run-1] Start Generating...
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
[Run-1] Output: What is bitsandbytes? Bitsandbytes is a free online tool that allows

>>> Saving model to ./local_test_model...

>>> Reloading model from ./local_test_model...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 35.76it/s]

[Run-2] Start Generating...
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
[Run-2] Output: What is bitsandbytes? Bitsandbytes is a free online tool that allows

Signed-off-by: jiqing-feng <[email protected]>
@matthewdouglas
Member

Looks good, thanks!

@matthewdouglas merged commit 6aa9619 into bitsandbytes-foundation:main Nov 26, 2025
129 of 132 checks passed