Cpu fused kernel #1804
Conversation
SunMarc left a comment.
Hi @matthewdouglas. bitsandbytes only loads one native lib (one of cpu/cuda/xpu), which means we can only build one .so file for bnb. But we cannot build CPU and XPU together, because the CPU backend relies on Intel OpenMP (libiomp5.so) while XPU relies on GNU OpenMP (libgomp.so), and building them together raises an error from the conflicting OpenMP runtimes. At the current stage we can only build one backend, so the cpu weight format will not be triggered on other backends. Even so, I added the reverse logic in case we want to support multi-backend builds in the future. cc @SunMarc
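For readers following along, here is a minimal sketch of the single-library selection described above. The library file names and the helper are hypothetical stand-ins, not bitsandbytes' actual loader code; the point is only that exactly one native .so is chosen per process, so the CPU and XPU backends, with their conflicting OpenMP runtimes, never coexist.

```python
# Illustrative sketch only: the .so names below are hypothetical
# stand-ins, not bitsandbytes' actual file layout or loader code.
import torch


def pick_backend_lib() -> str:
    """Pick exactly one native backend library for this process."""
    if torch.cuda.is_available():
        return "libbitsandbytes_cuda.so"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "libbitsandbytes_xpu.so"
    # The CPU build links Intel OpenMP (libiomp5); the XPU build links
    # GNU OpenMP (libgomp). Loading only one .so per process avoids
    # the runtime conflict described above.
    return "libbitsandbytes_cpu.so"


lib_name = pick_backend_lib()
print(f"Would load: {lib_name}")  # real code would call ctypes.CDLL(lib_name)
```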
Hi @jiqing-feng, if you rebase it should trigger CI to run the tests on the PR now.

In regard to building the optimized CPU code when you're using an accelerator, you're right; I suppose has_avx512() would return false regardless of your CPU, since we're only building that in the CPU-only lib? In the future I might want to change things so we always use a separate CPU library build instead of including CPU code in the accelerator libs. For now, it's a good point, so we shouldn't be rearranging weights at all if you're using an accelerator.

I'll look in more detail next week, but another question I have is the naming of the gemv: it seems like this is actually a full GEMM implementation? I realize I never made a custom op for full GEMM, so that's something I could do in a follow-up PR.
Yes, exactly! We should never rearrange weights when using an accelerator; the rearrangement is CPU-only. OK, I look forward to your next round of review. It would be better if you could give me an approximate time for when we can merge this PR so I can plan the next feature. Thanks!
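As a rough sketch of the CPU-only guard agreed on in this exchange: `has_avx512()` is the helper named in the thread, while `rearrange_for_fused_kernel()` and the exact device check are hypothetical stand-ins for the PR's actual code.

```python
# Sketch of the CPU-only repack guard discussed above.
# has_avx512() mirrors the helper mentioned in the thread;
# rearrange_for_fused_kernel() is a hypothetical placeholder.
import torch


def has_avx512() -> bool:
    # One possible check: PyTorch reports the SIMD level it detected.
    return "AVX512" in torch.backends.cpu.get_cpu_capability()


def rearrange_for_fused_kernel(weight: torch.Tensor) -> torch.Tensor:
    # Placeholder: the real routine repacks 4-bit weights into the
    # layout the fused CPU kernel expects.
    return weight


def maybe_rearrange(weight: torch.Tensor) -> torch.Tensor:
    # Repack only when running CPU-only with AVX-512 support; on
    # accelerators the optimized CPU code isn't built into the lib,
    # so the original layout must be kept.
    if weight.device.type == "cpu" and not torch.cuda.is_available() and has_avx512():
        return rearrange_for_fused_kernel(weight)
    return weight
```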
Hi @matthewdouglas. Please trigger the CI. Thanks!
(force-pushed from d2de0f5 to 0045c4b)
Hi @matthewdouglas. I've fixed and verified save and reload. You can check it with the following script:

```python
import gc
import os
import shutil

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4"
save_path = "./local_test_model"
input_text = "What is bitsandbytes?"

# Start from a clean save directory.
if os.path.exists(save_path):
    shutil.rmtree(save_path)


def run_generation(model, tokenizer, prompt, tag):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    print(f"\n[{tag}] Start Generating...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    output_str = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"[{tag}] Output: {output_str}")


print(">>> Loading original model on CPU...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
run_generation(model, tokenizer, input_text, "Run-1")

print(f"\n>>> Saving model to {save_path}...")
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Free the original model before reloading.
del model
del tokenizer
gc.collect()

print(f"\n>>> Reloading model from {save_path}...")
loaded_model = AutoModelForCausalLM.from_pretrained(
    save_path,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_path)
run_generation(loaded_model, loaded_tokenizer, input_text, "Run-2")
```

Output:
Looks good, thanks!
Merged commit 6aa9619 into bitsandbytes-foundation:main.
The fused kernel gives roughly a 4x speed-up on TPOT (time per output token) for 4-bit model inference compared with the separate dequant + matmul path. For the next optimization, targeting TTFT (time to first token), we need to bring in libxsmm.
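For context on the two metrics, here is an illustrative measurement harness (mine, not part of the PR): TTFT is the latency to the first generated token, dominated by prefill, while TPOT is the per-token latency of the decode phase, which is where the reported ~4x gain applies. It assumes `model` and `tokenizer` are loaded as in the script above.

```python
# Hedged sketch: timing TTFT and TPOT for a Hugging Face causal LM.
# Assumes `model` and `tokenizer` are already loaded as shown earlier.
import time

import torch


def measure_ttft_tpot(model, tokenizer, prompt: str, new_tokens: int = 32):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # One-token generation: prefill cost dominates -> TTFT.
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=1, do_sample=False)
        ttft = time.perf_counter() - start

        # Full generation, then attribute the remainder to decode.
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=new_tokens, do_sample=False)
        total = time.perf_counter() - start

    # Average decode-phase cost per token -> TPOT (approximate).
    tpot = (total - ttft) / (new_tokens - 1)
    return ttft, tpot
```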