Skip to content

Fixes #37219 : RecurrentGemma crashes for inputs longer than sliding window length #37613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
# Crop the attention mask to the target length.
attention_mask = attention_mask[:, -target_length:]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
Expand Down
37 changes: 24 additions & 13 deletions tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,12 @@ def test_initialization(self):
@slow
class RecurrentGemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
input_long_text = ['<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'] # fmt: skip
model_id = "google/recurrentgemma-2b"

@require_read_token
def test_2b_generate(self):
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a new app that allows you to make money by watching videos.\n\nThe app is very simple to use and you can earn money by watching videos.\n\nThe app is available for both Android and iOS devices and you can download it from the Google Play Store or the App Store.\n\nOnce you have downloaded the app'] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Expand All @@ -300,7 +301,7 @@ def test_2b_generate(self):
self.assertEqual(output_text, EXPECTED_TEXTS)

tokenizer.padding_side = "left"
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best <strong><em>free online video editing software</em></strong>.\n\n<h2><strong>Best Free Online Video Editing Software</strong></h2>\n\n<strong>1.</strong> <strong>Wondershare Filmora</strong>\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I’m going to show you how to make a simple and easy to make a <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY'] # fmt: skip

inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
Expand All @@ -320,7 +321,7 @@ def test_2b_generate(self):
@require_read_token
def test_2b_sample(self):
set_seed(0)
EXPECTED_TEXT = ['Where is Paris ?\n\nAnswer this question "yes" or "no": Could a person pass out in subzero temperatures?\n\nFor the sentence below, underline the pronoun in parentheses that agrees with its antecedent.\n\nExample 1. Mary and Pam will have the opportunity to prove (herself, $\\underline{\\text{themselves}}$) at the concert.\n\nThe waiters and the manager at the restaurant will do <em>(his, their)</em> best to assist you.\n\nA vocabulary word appears in italics in the short passage below. Think about how the word is used. Then write a definition for the vocabulary word.\n\nAfter a one-hour $'] # fmt: skip
EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Expand All @@ -333,13 +334,13 @@ def test_2b_sample(self):
@require_bitsandbytes
@require_read_token
def test_model_2b_8bit(self):
EXPECTED_TEXTS = ['<bos>Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "<bos>Hi today<pad><pad> I'm going to show you how to make a simple and easy to use <strong><em><u>"] # fmt: skip
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip

model = AutoModelForCausalLM.from_pretrained(
"gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
Expand All @@ -349,18 +350,28 @@ def test_model_2b_8bit(self):

@require_read_token
def test_long_context(self):
input_text = [
'<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'
]
EXPECTED_GENERATION = [
' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could'
]
EXPECTED_GENERATION = [' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." He added that the BEA is "not aware of any video footage that could have been taken on board the plane." The BEA is the French equivalent of the National Transportation Safety Board'] # fmt: skip

model = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
print(output_text)
self.assertEqual(output_text, EXPECTED_GENERATION)

@require_read_token
def test_longer_than_window(self):
EXPECTED_GENERATION = [" Robin's comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the"] # fmt: skip

model = AutoModelForCausalLM.from_pretrained(
self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
).to(torch_device)
model.config.attention_window_size = 256 # Make the attention window size shorter than the current prompt
tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_GENERATION)