Support for KV caching and batched inference #1934
Conversation
Hey, great work @mseeger . Can we decouple things a lot, though? Some initial thoughts:
Again, super good stuff in the PR! I think there are a few things to split out and consider individually, and then maybe we can have a video call about the core KVCache things, wdyt? Thanks for taking the initiative for better KV caching!
Hello, sure, we can have a call. I am in the Central European (Germany) time zone.
My impression was that batched generation is not really there. But if it is, I don't ask to change it. One thing is important, though: KV caches really work by filling positions sequentially. So, you filled positions
Also, the implementation right now allows you to send in KV cache objects from the start. If you do not do that, it will create them by default. Note that prefill here means that I can do a single pass and the cache can take it all, without having to evict anything. It does not mean that this will encode even the shortest prompt in the batch: if prompts are longer than the max prefill length, you need to do it sequentially in chunks. Maybe there is an easier way; we can discuss.
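For illustration, a minimal sketch of sequential chunked prefill. The helper `prefill_in_chunks` and the `model(chunk, input_pos=...)` call signature are assumptions for the sake of the example, not this PR's actual API:

```python
import torch

def prefill_in_chunks(model, prompt: torch.Tensor, max_prefill_length: int) -> torch.Tensor:
    # prompt: (batch, prompt_len) token ids. Positions must be filled strictly
    # left to right, so the prompt is fed in consecutive slices that never exceed
    # what the cache can absorb in a single pass.
    T = prompt.size(1)
    logits = None
    for start in range(0, T, max_prefill_length):
        chunk = prompt[:, start : start + max_prefill_length]
        # Hypothetical call: each pass appends the chunk's keys/values to the caches.
        logits = model(chunk, input_pos=start)
    # The logits at the last position of the final chunk predict the first new token.
    return logits
```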
It is annoying that I cannot show you the KV cache code I have. But in a talk, I could explain why a few things are the way they are. Of course, I am not on top of other constraints you guys have.
You may ask why this is needed. We could arrange things so that the very first call to the model, with the full prompt, does the prefill; after that, you'd call the model step by step for the new tokens.
This I could do. That would indeed be a little simpler.
@t-vi Let me know what the next steps here should be. If I understand correctly, I could:
Hi, so I think we should try to break things down. We could either start with the core caching itself and see how to integrate it with minimal changes, or first see what the deal is with batching and prefill.
Hello @t-vi , let me try to break things down. Changes are these:
If I understand you correctly, you complain about 2., especially the automatic creation of the default cache when nothing is passed in, and the change of
Would that be what you prefer?
As for 1. and 3., in the end they go together, but I can try to split them into two. I'd first do 1., keeping the generation code in place, which would however not work for batches and would not properly support the sequential processing of prompts. Doing 3. first is not really sensible, because it requires things from 1. What do you think?
Note that with DeepSeek (I am involved in trying to bring this to Hugging Face), there is a lot of movement now not to ignore KV caching in the future. They even released a paper on how they can train with large contexts.
OK, I did 2), as far as I understand it. I'd work on 1) once I find time.
No idea why all these tests are failing. Tests work for me locally. |
@t-vi Maybe I can change your mind about first keeping the current generation code in place and only contributing the KV cache support? This is quite a bit of extra work for me, and my new code has a number of improvements. In particular, the current code does not really do batch generation; it is marked with several TODOs and is not used. If we could have a chat, I'd appreciate that.
Your CI system seems to be broken still. |
Out of curiosity: why do you object to batch prompts being a list of tensors? In the end, they can have wildly different lengths, and there is not much you can do about that (sure, if you get lots of requests, you can maybe cluster them, but doing this too much delays requests, and so increases latency). Also, you really don't want to push PAD tokens into models just because a prompt in a batch happens to be shorter than others. The model, not being trained on this, would certainly get confused. And since you need to start token-by-token forwards for generation anyway, you really gain nothing by padding prompts. I always thought of this as some kind of TensorFlow artefact from when all tensors had to be allocated up front, etc. But I thought we had overcome this with PyTorch.
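To make the padding concern concrete, a toy example with made-up token ids (pad id 0):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three prompts of very different lengths (3, 1, and 6 tokens).
prompts = [
    torch.tensor([11, 12, 13]),
    torch.tensor([21]),
    torch.tensor([31, 32, 33, 34, 35, 36]),
]
padded = pad_sequence(prompts, batch_first=True, padding_value=0)
print(padded.shape)  # torch.Size([3, 6]): 8 of the 18 positions are PAD tokens
```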
Hey, sorry, I am totally swamped; I still want to have a video call to chat.
Because lists are a lot less nice to work with in various setups, passing to kernels, CUDA graphs, etc. For somewhat homogeneous sequence lengths, padding works fine. We are using it in production, so I'm doubting claims that it does not work. It does have limitations with inhomogeneous sequence lengths, which we want to support. But the proper way to support this is packed sequences, i.e. pass in flat tensors. This is hugely more flexible. It needs FlexAttention or some such (https://pytorch.org/docs/stable/nn.attention.flex_attention.html) to make it work efficiently in stock PyTorch.
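For illustration, a minimal sketch of the packed-sequence idea with FlexAttention (`flex_attention` and `create_block_mask` from the linked PyTorch docs, assuming torch >= 2.6). The sequence lengths, shapes, and the document-mask construction below are made up for the example, not code from this PR:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Three prompts of different lengths, concatenated ("packed") into one flat
# sequence instead of being padded to a common length.
seq_lens = torch.tensor([60, 20, 48])
doc_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)  # (128,)
T = int(seq_lens.sum())

def document_causal(b, h, q_idx, kv_idx):
    # Tokens may only attend causally and within their own document.
    return (doc_ids[q_idx] == doc_ids[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal, B=None, H=None,
                               Q_LEN=T, KV_LEN=T, device="cpu")

# Toy projections: batch 1, 4 heads, head dim 16, all prompts in one flat sequence.
q = torch.randn(1, 4, T, 16)
k = torch.randn(1, 4, T, 16)
v = torch.randn(1, 4, T, 16)
out = flex_attention(q, k, v, block_mask=block_mask)  # (1, 4, T, 16)
```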
Let me know when a good time is. I am in the European time zone.
After our call, I think I understand better what you mean. Something like an abstraction in multi-head attention, where the inputs are the keys, values, and queries for the current input chunk, all of the same size, which are then bundled:
This makes a lot of sense, and is quite elegant.
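Roughly, the bundling could look like the following minimal sketch. The names (`DenseKVCache`, `attend_with_cache`, `update`) are hypothetical illustrations, not the PR's actual interface:

```python
import torch
import torch.nn.functional as F

class DenseKVCache:
    """Baseline cache: keeps every past key/value, never evicts."""

    def __init__(self):
        self.k = None  # (batch, heads, cached_len, head_dim)
        self.v = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, heads, chunk_len, head_dim) for the current chunk.
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def attend_with_cache(q, k, v, cache: DenseKVCache):
    # q, k, v: (batch, heads, chunk_len, head_dim) for the current chunk only;
    # the cache decides what is stored and what the chunk attends over.
    k_all, v_all = cache.update(k, v)
    # Causal mask omitted for brevity; for multi-token chunks you would pass a
    # mask offset by the cached length.
    return F.scaled_dot_product_attention(q, k_all, v_all)
```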
@mseeger should be fixed now, thank you for your patience :)
As discussed with @t-vi , I'll refactor this as stated in the comment above. Makes total sense |
OK, I've taken out the batched inference code. Still working on fixing the tests (and need to refactor speculative decoding), but this is essentially it. |
@t-vi, it would be great to get some feedback on this one before I spend time on fixing tests for code which I need to change afterwards anyway.
BTW: this could be a real differentiating feature of the library. If you know another open source library that focuses on KV caching, and which you'd like to integrate with instead, please let me know.
@t-vi, @Borda: Any chance there will be some progress here? I recognize this is a big PR. On the other hand, decent support for selective (sparse) KV caching could be a real differentiator. I made quite some progress, also on fine-tuning with long contexts. I am trying to get approval to open source this. In my team, we are starting to use it.
One thing still missing is good support for batch inference without excessive padding.
@t-vi , @Borda: Any sign of life on this PR? A few things:
If things go well, I'd love to publish this at a conference with a deadline in the fall, and it would help if everything was open sourced and in the library by then. I don't know if your customers struggle with fine-tuning on long context widths. We certainly do. The library I am writing is a solution to that.
@mseeger, so is this ready for review?
@Borda I can work on making the tests pass. Just after a call with @t-vi, he asked for a different abstraction for KV caches, and I changed it accordingly, so I was hoping for more comments on whether what I am doing here is the right thing. The PR also grew quite large over time, since I added small changes I need to make gradient computations work. I'll go over it and check how it can be split into several ones.
For me, it'd be most important if you could comment on whether the KV cache abstraction is good, including factoring out the multi-head attention code.
OK, this PR contains the following parts:
In the library I am writing, there are a number of additional, more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here enables all of that. If these changes are not made, I'd have to copy and change quite a bit of your code. This would be hard to maintain, and would run the risk that KV caches are implemented differently at a later point, and then things really diverge. As I said in the comments above, I found KV caching to be super important to make large-context inference work on a moderate GPU budget, which should be of interest to your customers as well.
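For illustration, a hedged sketch of the kind of extra hook such caches need (hypothetical names, not this PR's API): an H2O-style cache must be told the attention weights of each step so it can track which cached positions are "heavy hitters" and later evict the least-attended ones.

```python
import torch

class KVCacheWithScores:
    """Sketch: maintain per-slot statistics (accumulated attention mass)
    that an eviction policy such as H2O can consult when the cache is full."""

    def __init__(self):
        self.scores = None  # (batch, heads, cached_len)

    def update_scores(self, attn_weights: torch.Tensor):
        # attn_weights: (batch, heads, q_len, cached_len), the softmax output
        # of the current step over all cached positions.
        step_mass = attn_weights.sum(dim=2)
        self.scores = step_mass if self.scores is None else self.scores + step_mass
```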
I could now work to fix the broken tests. But I would also like to get some input, related to the comment "Can we decouple things a lot, though?" at the very top here. I could break this into two, first the "small things" and "refactoring MHA", then the "KV cache abstraction". This is more work for me, but I'd be OK doing it if it means getting it merged. While "refactoring MHA" has its own benefits, it is of course motivated by the latter goal.
I am continuing this work in #2061 |
Adds an abstraction for key-value caches and implements batched inference.
I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.
The abstraction contains methods not used by these baselines, but they are required to implement more advanced KV caches such as Heavy Hitter Oracle (H2O).
I have implemented some of these, but I may not be allowed to contribute them here (working for a company). I'll see what I can do.
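For illustration only, a rough sketch of what a last-recent (sliding-window) cache could look like under this kind of abstraction; the class name and `update` method are assumptions, not this PR's code:

```python
import torch

class LastRecentKVCache:
    """Fixed-budget cache that keeps only the most recent cache_length keys/values,
    overwriting the oldest slots ring-buffer style."""

    def __init__(self, batch, heads, cache_length, head_dim, dtype=torch.float32):
        self.cache_length = cache_length
        self.k = torch.zeros(batch, heads, cache_length, head_dim, dtype=dtype)
        self.v = torch.zeros(batch, heads, cache_length, head_dim, dtype=dtype)
        self.next_pos = 0  # next slot to write (modulo cache_length)
        self.filled = 0    # how many slots hold valid entries

    def update(self, k_new, v_new):
        # k_new, v_new: (batch, heads, num_new, head_dim); positions arrive in order.
        num_new = k_new.size(2)
        for i in range(num_new):
            slot = (self.next_pos + i) % self.cache_length
            self.k[:, :, slot] = k_new[:, :, i]
            self.v[:, :, slot] = v_new[:, :, i]
        self.next_pos = (self.next_pos + num_new) % self.cache_length
        self.filled = min(self.filled + num_new, self.cache_length)
        # Return only the valid slots. Assuming positional information (e.g. RoPE)
        # is already applied to the keys, attention is permutation-invariant over
        # the returned slots, so their order in the buffer does not matter.
        return self.k[:, :, : self.filled], self.v[:, :, : self.filled]
```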