[Perf] [CPU] eliminate redundant memory access in group query attention #13319

Open
wants to merge 3 commits into
base: master

Conversation

ZelinMa557

@ZelinMa557 ZelinMa557 commented May 5, 2025

Modern LLMs (Llama 3, Qwen 2.5, etc.) usually use grouped-query attention (GQA), which significantly reduces the memory used by the KV cache. With GQA, the query rows of neighboring query heads share the KV rows of the same KV head, so we can reorder the loops as follows:

# Python-style pseudocode
for group_id in range(0, group_num):                        # one group per KV head
    for seq_id in range(0, seq_length):
        k = load_k(group_id, seq_id)                        # K/V row loaded once per position...
        v = load_v(group_id, seq_id)
        for head_id in range(group_id * n_gqa, (group_id + 1) * n_gqa):
            q = load_q(head_id, seq_id)                     # ...and reused by all n_gqa query heads of the group
            compute(q, k, v)

This improves the spatial locality of memory accesses. However, the original implementation of the CPU flash attention kernel didn't take this into account, and this PR improves it.
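
For illustration only (not the actual ggml kernel), here is a minimal runnable NumPy sketch of single-token decode with GQA; all names and shapes are assumptions. It compares the naive per-head order, where each query head re-reads the K/V rows of its shared KV head, with the reordered group-first loop that fetches K/V once per group, and checks that both give the same result:

import numpy as np

n_head, n_head_kv, head_dim, kv_len = 8, 2, 16, 32
n_gqa = n_head // n_head_kv                                # query heads per KV head

rng = np.random.default_rng(0)
q = rng.standard_normal((n_head, head_dim))                # one decoded token, all query heads
k = rng.standard_normal((n_head_kv, kv_len, head_dim))     # KV cache
v = rng.standard_normal((n_head_kv, kv_len, head_dim))

def attend(q_row, k_rows, v_rows):
    s = k_rows @ q_row / np.sqrt(head_dim)                 # scores against all cached positions
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v_rows                          # softmax-weighted sum of V rows

# naive order: each query head walks the K/V rows of its shared KV head again
out_naive = np.stack([attend(q[h], k[h // n_gqa], v[h // n_gqa]) for h in range(n_head)])

# reordered: fetch the K/V of one group once, then sweep the heads that share it
out_grouped = np.empty_like(out_naive)
for g in range(n_head_kv):
    k_g, v_g = k[g], v[g]
    for h in range(g * n_gqa, (g + 1) * n_gqa):
        out_grouped[h] = attend(q[h], k_g, v_g)

assert np.allclose(out_naive, out_grouped)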

This is my test command:

./build/bin/llama-cli -t 4 -fa --ctx-size 8192 -m models/Qwen2.5-Coder-7B-Instruct-Q2_K.gguf -f convert_lora_to_gguf.py

The master branch result:

llama_perf_sampler_print:    sampling time =      45.59 ms /  4647 runs   (    0.01 ms per token, 101939.19 tokens per second)
llama_perf_context_print:        load time =     687.54 ms
llama_perf_context_print: prompt eval time =  588053.13 ms /  4412 tokens (  133.28 ms per token,     7.50 tokens per second)
llama_perf_context_print:        eval time =   71929.76 ms /   234 runs   (  307.39 ms per token,     3.25 tokens per second)
llama_perf_context_print:       total time =  660956.03 ms /  4646 tokens
Interrupted by user

With the optimization, the result is:

llama_perf_sampler_print:    sampling time =      56.22 ms /  4717 runs   (    0.01 ms per token, 83901.03 tokens per second)
llama_perf_context_print:        load time =     870.17 ms
llama_perf_context_print: prompt eval time =  574061.97 ms /  4415 tokens (  130.03 ms per token,     7.69 tokens per second)
llama_perf_context_print:        eval time =   71333.37 ms /   301 runs   (  236.99 ms per token,     4.22 tokens per second)
llama_perf_context_print:       total time =  646281.74 ms /  4716 tokens
Interrupted by user

We can see a slight speed-up in prefill (7.50 → 7.69 tokens per second) and roughly a 30% speed-up in decode (3.25 → 4.22 tokens per second)!

Further work:

  1. Flash decoding: in this PR, when n_kv_head is smaller than the number of threads and there is only one concurrent request, this CPU kernel cannot use all the threads. We can solve this with flash decoding (see the sketch after this list).
  2. Load balancing between threads: in causal attention, the amount of computation differs between rows, but the current implementation doesn't take that into account, which slows down multi-threaded long-context prefill.
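
As a rough illustration of the flash-decoding idea in item 1 (an assumption about future work, not code from this PR): each thread would own one chunk of the KV cache, compute a partial output together with its softmax statistics (running max and sum of exponentials), and the partials would be merged afterwards. A minimal single-threaded NumPy sketch, with all names invented for the example:

import numpy as np

head_dim, kv_len, n_chunks = 16, 48, 4
rng = np.random.default_rng(1)
q = rng.standard_normal(head_dim)                          # one query row of one head
k = rng.standard_normal((kv_len, head_dim))
v = rng.standard_normal((kv_len, head_dim))

# reference: full softmax attention over the whole KV cache
s = k @ q / np.sqrt(head_dim)
p = np.exp(s - s.max())
ref = (p / p.sum()) @ v

# per-chunk partials: (unnormalized output, chunk max, chunk sum of exponentials)
partials = []
for k_c, v_c in zip(np.array_split(k, n_chunks), np.array_split(v, n_chunks)):
    s_c = k_c @ q / np.sqrt(head_dim)
    m_c = s_c.max()
    p_c = np.exp(s_c - m_c)
    partials.append((p_c @ v_c, m_c, p_c.sum()))

# merge: rescale every partial to a common max, then normalize once
m = max(m_c for _, m_c, _ in partials)
num = sum(o_c * np.exp(m_c - m) for o_c, m_c, _ in partials)
den = sum(sum_c * np.exp(m_c - m) for _, m_c, sum_c in partials)
assert np.allclose(ref, num / den)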

My test environment:

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         39 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  8
  On-line CPU(s) list:   0-7
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
    CPU family:          6
    Model:               142
    Thread(s) per core:  2
    Core(s) per socket:  4
    Socket(s):           1
    Stepping:            12
    BogoMIPS:            4607.99
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 xsaves md_clear flush_l1d arch_capabilities
Virtualization features: 
  Hypervisor vendor:     Microsoft
  Virtualization type:   full
Caches (sum of all):     
  L1d:                   128 KiB (4 instances)
  L1i:                   128 KiB (4 instances)
  L2:                    1 MiB (4 instances)
  L3:                    8 MiB (1 instance)
Vulnerabilities:         
  Itlb multihit:         KVM: Mitigation: VMX unsupported
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
  Srbds:                 Mitigation; TSX disabled
  Tsx async abort:       Not affected

@github-actions github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) label May 5, 2025
@MaggotHATE
Contributor

Just tested this out of curiosity: Qwen 3 degrades in quality (ignores /no_think, for example), Mistral Small 2 outputs empty characters. Does it break compatibility with older models? Windows 10, i7-8700, CPU backend only.

@ZelinMa557
Author

ZelinMa557 commented May 6, 2025

Just tested this out of curiosity: Qwen 3 degrades in quality (ignores /no_think, for example), Mistral Small 2 outputs empty characters. Does it break compatibility with older models? Windows 10, i7-8700, CPU backend only.

Hi, thanks for your reply! It should not break compatibility with older models in theory, but there might be small bugs in my implementation. In my test it works with Qwen 2.5 7B. Can you tell me the size of the Qwen 3 model you used for testing? I will test both Qwen 3 and Mistral to debug.

@MaggotHATE
Contributor

Can you tell me the size of the Qwen 3 model you used for testing? I will test both Qwen 3 and Mistral to debug.

I've tested both the 8B and 4B models in Q6; both worked correctly without this PR. Mistral Small 2 is in Q5_K_L and works correctly on main too.

@ZelinMa557
Author

I've tested both the 8B and 4B models in Q6; both worked correctly without this PR. Mistral Small 2 is in Q5_K_L and works correctly on main too.

Thanks, I have reproduced the same problem. I will try to fix it.

Signed-off-by: ZelinMa557 <[email protected]>
@ZelinMa557
Author

I have fixed the bug. Are there any scripts to format the code locally? This PR cannot pass the code lint at the moment.

@MaggotHATE
Contributor

I have fixed the bug. Are there any scripts to format the code locally? This PR cannot pass the code lint at the moment.

Thank you! I've already deleted the Qwen models, unfortunately, but Mistral Small 2 generates text correctly now. I'll test it a bit more with other models, but so far it seems to be fixed.

On an i7-8700 with Mistral Small 3 (the 24B one, Q4_K_M) I get 2.08 t/s with this PR vs 1.97 t/s on current main.

Signed-off-by: ZelinMa557 <[email protected]>
@ZelinMa557
Author

(screenshot of the CI message) The CI says there is trailing whitespace at line 7045, but I cannot find any trailing whitespace on that line. That is quite strange.

Labels: ggml (changes relating to the ggml tensor library for machine learning)