Fixes #37219 : RecurrentGemma crashes for inputs longer than sliding window length #37613

Merged 7 commits into huggingface:main on Apr 22, 2025

Conversation

@manueldeprada (Contributor) commented Apr 18, 2025

Fix: Long-context bug in RecurrentGemma generation (#37219)

This PR resolves a shape mismatch error during generation with long prompts in RecurrentGemma, caused by the attention_mask not being cropped to the model’s sliding attention window.

The model expects a fixed-size attention window (attention_window_size) during decoding, but the generic prepare_inputs_for_generation does not enforce this. This leads to a mismatch in _update_causal_mask() when the prompt length exceeds the window.

Fix:
Crop attention_mask to the last attention_window_size tokens inside _update_causal_mask().

This makes RecurrentGemma compatible with the general prepare_inputs_for_generation, and restores long-prompt generation without regressions.
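
For illustration, a minimal sketch of the cropping idea, assuming a 2D attention_mask of shape (batch, seq_len); the function name is made up here, and this is not the literal diff applied by the PR:

    # Sketch only: keep the last `attention_window_size` positions of the mask so its
    # length matches the fixed-size sliding window expected during decoding.
    import torch

    def crop_to_window(attention_mask: torch.Tensor, attention_window_size: int) -> torch.Tensor:
        if attention_mask is not None and attention_mask.shape[-1] > attention_window_size:
            attention_mask = attention_mask[:, -attention_window_size:]
        return attention_mask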

Tests

There was a previous long_context test, but its prompt was only 313 tokens, shorter than the 2048-token context window. Added a new sequence to the test that covers prompts longer than the window.
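
A rough sketch of the kind of check this enables (the checkpoint name, prompt, and thresholds below are illustrative assumptions, not the code added to the test file):

    # Illustrative: generate from a prompt longer than the 2048-token attention window.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "google/recurrentgemma-2b"  # assumed checkpoint for the example
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    prompt = "All that glitters is not gold. " * 500   # comfortably exceeds the window
    inputs = tokenizer(prompt, return_tensors="pt")
    assert inputs.input_ids.shape[-1] > 2048

    # Before the fix this raised a shape mismatch in _update_causal_mask();
    # after the fix it should generate normally.
    outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(outputs[0, inputs.input_ids.shape[-1]:]))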

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@manueldeprada manueldeprada requested a review from gante April 21, 2025 11:15
@manueldeprada manueldeprada marked this pull request as ready for review April 21, 2025 11:15
@manueldeprada (Contributor, Author) commented:

The sample and greedy tests have been failing since at least version 4.45.2, so I updated the expected outputs to match the current behavior.

Are we expected to investigate why these tests have been broken for so long, or is it acceptable to move forward with the updated outputs? @gante

@gante (Member) commented Apr 21, 2025

@manueldeprada let's find the first commit that caused the divergence, then we can better see if we can safely update the tests or not 🤗 (git bisect is the best tool here, lmk on slack if you'd like pointers)

At first glance, the change looks good to me. For reference, there was a similar fix in other models with sliding-window attention, which was more complex due to the fixed-size attention mask required by torch.compile (see this PR and then this one)

@manueldeprada (Contributor, Author) commented Apr 21, 2025

> @manueldeprada let's find the first commit that caused the divergence, then we can better see if we can safely update the tests or not 🤗 (git bisect is the best tool here, lmk on slack if you'd like pointers)
>
> At first glance, the change looks good to me. For reference, there was a similar fix in other models with sliding-window attention, which was more complex due to the fixed-size attention mask required by torch.compile (see this PR and then this one)

The commit that breaks the greedy test is c215523. It is a fix for an unexpected keyword argument 'position_ids' error that appeared due to #31549. That strict signature was reverted much later in 3d6e55c. I don't understand very well what the correct handling of position_ids would be; c215523 is obscure to me.

EDIT: when I say that c215523 breaks the greedy decoding test, I mean that it breaks once you cherry-pick the permissive forward() signature. Ideally, we would want a strict-signature fix that doesn't break the test.

@gante (Member) commented Apr 21, 2025

c215523 is definitely a fix, so the expected output after that commit is the correct one!

TL;DR: the original model commit didn't pass position_ids from RecurrentGemmaForCausalLM to RecurrentGemmaModel. position_ids == cache_position when there is no padding, and we have

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

in RecurrentGemmaModel. So the expected outputs of the test cases with padding became incorrect after that commit, because the correct position_ids are now being piped through :)
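
To make the padding point concrete, a toy illustration (assumed tensors, not library code) of how the two quantities diverge once left padding is present:

    # With left padding, position_ids derived from the attention mask differ from
    # cache_position, so the `cache_position.unsqueeze(0)` fallback only matches
    # the unpadded case.
    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
    cache_position = torch.arange(5)                  # tensor([0, 1, 2, 3, 4])

    position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)  # tensor([[0, 0, 0, 1, 2]])
    fallback = cache_position.unsqueeze(0)                       # tensor([[0, 1, 2, 3, 4]])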

👉 your test updates in this PR make sense 👍 (if we look at your diff, the non-sample tests only change for the second input, the one that has padding)

@manueldeprada (Contributor, Author) commented Apr 21, 2025

> c215523 is definitely a fix, so the expected output after that commit is the correct one!
>
> TL;DR: the original model commit didn't pass position_ids from RecurrentGemmaForCausalLM to RecurrentGemmaModel. position_ids == cache_position when there is no padding, and we have
>
>         if position_ids is None:
>             position_ids = cache_position.unsqueeze(0)
>
> in RecurrentGemmaModel. So the expected outputs of the test cases with padding became incorrect after that commit, because the correct position_ids are now being piped through :)
>
> 👉 your test updates in this PR make sense 👍 (if we look at your diff, the non-sample tests only change for the second input, the one that has padding)

Aha, now I get it. Thanks a lot! I also fixed the quantized test and split the longer_than_window case into a separate test, as you suggested :)

I checked, and all the new tests pass on c215523.

@manueldeprada manueldeprada requested a review from gante April 22, 2025 09:39
@gante (Member) left a review comment:

Thank you for fixing this one 🤗

@manueldeprada manueldeprada changed the title fix: RecurrentGemma crashes for inputs longer than sliding window length Fixes #37219: RecurrentGemma crashes for inputs longer than sliding window length Apr 22, 2025
@manueldeprada manueldeprada changed the title Fixes #37219: RecurrentGemma crashes for inputs longer than sliding window length Fixes #37219 : RecurrentGemma crashes for inputs longer than sliding window length Apr 22, 2025
@manueldeprada manueldeprada merged commit 413f9bb into huggingface:main Apr 22, 2025
14 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…an sliding window length (huggingface#37613)

* fix: RecurrentGemma crashes during inference for inputs longer than sliding window width

* fix recurrentgemma tests; add long test bigger than context window
Successfully merging this pull request may close this issue: RecurrentGemma crashes during inference for inputs longer than sliding window width (#37219).