Fixes #37219: RecurrentGemma crashes for inputs longer than sliding window length #37613
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The sample and greedy tests have been failing since at least version 4.45.2, so I updated the expected outputs to match the current behavior. Are we expected to investigate why these tests have been broken for so long, or is it acceptable to move forward with the updated outputs? @gante
@manueldeprada let's find the first commit that caused the divergence, then we can better see whether we can safely update the tests or not 🤗 At first glance, the change looks good to me. For reference, there was a similar fix in other models with sliding window attention, which was more complex due to the fixed-size attention mask.
The commit that breaks the greedy test is c215523 (which is itself a fix).

EDIT: when I say that c215523 breaks the greedy decoding test, I mean that it breaks it given that you cherry-pick this PR's more permissive attention-mask handling on top of it.
c215523 is definitely a fix, so the expected output after that commit is the correct one!

TL;DR: the original model commit didn't default `position_ids`; c215523 adds `if position_ids is None: position_ids = cache_position.unsqueeze(0)`.

👉 your test updates in this PR make sense 👍 (if we look at your diff, we see that non-sample tests are only changed on the second input, the one that has padding)
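For readers following along, here is a minimal, self-contained sketch of the defaulting pattern quoted above; `resolve_position_ids` is a hypothetical helper for illustration, not the actual code in the model file.

```python
import torch

# Illustrative sketch (not the exact diff from c215523): if the caller does not
# supply position_ids, derive them from cache_position so every decoding step
# sees consistent token positions.
def resolve_position_ids(position_ids, cache_position):
    if position_ids is None:
        # cache_position has shape (seq_len,); unsqueeze adds the batch dimension.
        position_ids = cache_position.unsqueeze(0)
    return position_ids

cache_position = torch.arange(5)                                  # positions of the current tokens
print(resolve_position_ids(None, cache_position))                 # tensor([[0, 1, 2, 3, 4]])
print(resolve_position_ids(torch.tensor([[7]]), cache_position))  # an explicit value is kept
```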
Aha, now I get it. Thanks a lot! I also fixed the quantized test and separated the long-context case into its own test. I checked that all the new tests pass on c215523.
Thank you for fixing this one 🤗
…an sliding window length (huggingface#37613)

* fix: RecurrentGemma crashes during inference for inputs longer than sliding window width
* fix recurrentgemma tests; add long test bigger than context window
Fix: Long-context bug in RecurrentGemma generation (#37219)
This PR resolves a shape mismatch error during generation with long prompts in `RecurrentGemma`, caused by the `attention_mask` not being cropped to the model's sliding attention window.

The model expects a fixed-size attention window (`attention_window_size`) during decoding, but the generic `prepare_inputs_for_generation` does not enforce this. This leads to a mismatch in `_update_causal_mask()` when the prompt length exceeds the window.

Fix:
Crop `attention_mask` to the last `attention_window_size` tokens inside `_update_causal_mask()`. This makes `RecurrentGemma` compatible with the general `prepare_inputs_for_generation` and restores long-prompt generation without regressions.
Tests

There was a previous `long_context` test, but it used only 313 tokens, shorter than the 2048-token context window. Added a new sequence to the test that covers this case.
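A rough sketch of the kind of long-prompt check the new test adds (the checkpoint name, prompt length, and decoding settings are assumptions for illustration, not the PR's exact test code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/recurrentgemma-2b"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Repeat a sentence until the prompt comfortably exceeds the 2048-token attention window.
long_prompt = "All work and no play makes Jack a dull boy. " * 400
inputs = tokenizer(long_prompt, return_tensors="pt")

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Before the fix this raised a shape mismatch; now it should decode normally.
print(tokenizer.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```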
Who can review?
@gante