
Make hybrid cache exportable #37623


Open · wants to merge 7 commits into base: main

Conversation

tugsbayasgalan
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot marked this pull request as draft April 18, 2025 21:52
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@@ -317,7 +317,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
-offset = max(0, offset)
+offset = torch.clamp(offset, min=0)
Contributor Author

There is a PR to automatically do this: pytorch/pytorch#151348
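
For reference, a minimal standalone sketch (not the PR's code; the module name and window size are made up for illustration) of why the tensor-native op matters: Python's built-in max() forces a data-dependent bool() on a traced tensor, while torch.clamp keeps the computation as a pure tensor op that torch.export can trace.

import torch

class SlidingOffset(torch.nn.Module):
    def forward(self, cache_position):
        effective_seq_len = 10  # hypothetical sliding-window size, illustration only
        offset = cache_position[-1] - effective_seq_len + 1
        # max(0, offset) would compare a Python int with a traced tensor and call
        # bool() on the result, which export cannot resolve symbolically;
        # torch.clamp expresses the same thing entirely in tensor ops.
        return torch.clamp(offset, min=0)

exported = torch.export.export(SlidingOffset(), (torch.arange(16),))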

@tugsbayasgalan tugsbayasgalan marked this pull request as ready for review April 21, 2025 21:18
@gante gante (Member) left a comment

Looks mostly good to me, added a few nits :)

In parallel, I'm going to confirm that there is no speed degradation with torch.compile when using the Hybrid cache.

@@ -8,7 +8,7 @@
import torch
from packaging import version

-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
Member

Let's use is_torch_greater_or_equal instead. We're shifting towards this one across the library.

(equivalent usage: is_torch_greater_or_equal("2.7", accept_dev=True))
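
For clarity, a small sketch of what the guarded check could look like with the suggested helper (the variable name below is illustrative, not the PR's final code):

from transformers.utils import is_torch_greater_or_equal

# accept_dev=True also accepts dev/pre-release builds of 2.7
_is_torch_2_7_plus = is_torch_greater_or_equal("2.7", accept_dev=True)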

Contributor Author

I don't see the implementation in the codebase... Should I create this function?

Member

is_torch_greater_or_equal is already imported in that file :) (see the imports from .utils)

Contributor Author

Lol, I see, sorry.

return hash(tuple(sorted(self.__dict__)))


if is_torch_greater_or_equal_than_2_7:
Member

Is this change needed? modular_gemma2.py shouldn't be imported

Contributor Author

Yep! Because HybridCache is in the output, export needs to understand this type thoroughly. Since HybridCache depends on the model config, we need to tell export that the config is a constant type (it doesn't hold any inner tensors). This API was introduced in 2.7, hence the guard.
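
For illustration, a rough sketch of what the guarded block could look like, assuming torch 2.7's torch.utils._pytree.register_constant API (the exact call and location in the PR may differ):

import torch
from transformers import Gemma2Config
from transformers.utils import is_torch_greater_or_equal

# Gemma2Config carries no tensors, so export can treat it as an opaque constant
# instead of trying to flatten it as a pytree.
if is_torch_greater_or_equal("2.7", accept_dev=True):
    torch.utils._pytree.register_constant(Gemma2Config)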

Member

uhmm... my question here is not so much regarding the usefulness of these lines, but rather about the file they are in. modular_xxx.py is only used for scaffolding, and never imported (or at least it shouldn't be!).

Perhaps it should be moved to configuration_gemma2.py?

Contributor Author

Ahh, I was getting an error that there is a difference between modular_gemma2 and the generated modeling_gemma2, so I thought modular_gemma2 was the source of truth.

@@ -28,6 +28,7 @@

logger = logging.get_logger(__name__)

is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
Member

Suggested change
is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)

gante (Member) commented Apr 22, 2025

There seems to be no throughput degradation vs. main.

main: [screenshot: benchmark results, 2025-04-22 10:28]

this PR: [screenshot: benchmark results, 2025-04-22 10:23]

Test script
import copy
import os
import torch
from torch.utils import benchmark

from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Benchmarking settings
BSZ = [1, 4]
NEW_TOK = [16, 256]
N_ITER = 10
MODEL_ID = "google/gemma-2-2b-it"
CACHE_IMPLEMENTATION = "hybrid"

# Other constants
FRANCE_ARTICLE = (
  "<s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight "
  "9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille "
  "prosecutor Brice Robin told CNN that \"so far no videos were used in the crash investigation.\" He added, \"A "
  "person who has such a video needs to immediately give it to the investigators.\" Robin\'s comments follow claims "
  "by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final "
  "seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. "
  "Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two "
  "publications described the supposed video, but did not post it on their websites. The publications said that "
  "they watched the video, which was found by a source close to the investigation. \"One can hear cries of 'My God' "
  "in several languages,\" Paris Match reported. \"Metallic banging can also be heard more than three times, "
  "perhaps of the pilot trying to open the cockpit door with a heavy object.  Towards the end, after a heavy "
  "shake, stronger than the others, the screaming intensifies. Then nothing.\" \"It is a very disturbing scene,\" "
  "said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, "
  "the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie "
  "spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the "
  "reports were \"completely wrong\" and \"unwarranted.\" Cell phones have been collected at the site, he said, "
  "but that they \"hadn\'t been exploited yet.\" Menichini said he believed the cell phones would need to be sent "
  "to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized "
  "technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent "
  "to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card "
  "to the media, Menichini answered with a categorical \"no.\" Reichelt told \"Erin Burnett: Outfront\" that he "
  "had watched the video and stood by the report, saying Bild and Paris Match are \"very confident\" that the clip "
  "is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after "
  "Bild and Paris Match published their reports. \"That is something we did not know before. ... Overall we can "
  "say many things of the investigation weren\'t revealed by the investigation at the beginning,\" he said. What "
  "was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas "
  "Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s "
  "accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training "
  "school in 2009 that he had a \"previous episode of severe depression,\" the airline said Tuesday. Email "
  "correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, "
  "included medical documents he submitted in connection with resuming his flight training. The announcement "
  "indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle with depression, allowed "
  "him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously "
  "said Lubitz was 100% fit to fly, described its statement Tuesday as a \"swift and seamless clarification\" and "
  "said it was sharing the information and documents -- including training and medical records -- with public "
  "prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past "
  "week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center "
  "set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving "
  "families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human "
  "remains were left at the site but recovery teams would keep searching. French President Francois Hollande, "
  "speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the "
  "end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the "
  "victims' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be "
  "more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our "
  "correspondents . The details about Lubitz's correspondence with the flight school during his training were "
  "among several developments as investigators continued to delve into what caused the crash and Lubitz\'s "
  "possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid "
  "medical certificate, had passed all his examinations and \"held all the licenses required.\" Earlier, a "
  "spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz "
  "suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before "
  "he got his pilot's license. Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting "
  "aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition "
  "would cause him to lose his pilot's license, a European government official briefed on the investigation told "
  "CNN on Tuesday. While flying was \"a big part of his life,\" the source said, it\'s only one theory being "
  "considered. Another source, a law enforcement official briefed on the investigation, also told CNN that "
  "authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not "
  "be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye "
  "doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had "
  "psychological issues, the European government official said. But no matter what details emerge about his "
  "previous mental health struggles, there's more to the story, said Brian Russell, a forensic psychologist. "
  "\"Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they "
  "weren't going to keep doing their job and they're upset about that and so they're suicidal,\" he said. \"But "
  "there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it "
  "outward on 149 other people who had nothing to do with the person's problems.\" Germanwings crash compensation: "
  "What we know . Who was the captain of Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and "
  "Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela "
  "Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report."
)


tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
prompt_length = tokenizer([FRANCE_ARTICLE], return_tensors="pt").input_ids.shape[1]
label_ms_per_token = f"Throughput (time/forward pass, prompt = {prompt_length} tokens)"
label_first_step = f"First call (time, prompt = {prompt_length} tokens)"


def print_results(all_results):
  print("\n")
  compare = benchmark.Compare(all_results)
  compare.trim_significant_figures()
  compare.colorize(rowwise=True)
  compare.print()


def time_generate_call(model, task, ms_per_token, first_step, compile=False):
  for bsz in BSZ:
      for max_new_tokens in NEW_TOK:
          input_ids = tokenizer([FRANCE_ARTICLE] * bsz, return_tensors="pt").to("cuda")
          description = f"batch size, max_new_tokens: {bsz, max_new_tokens}"
          task_spec_ms_per_token = benchmark.TaskSpec(
              stmt="", setup="", description=task, label=label_ms_per_token, sub_label=description
          )
          task_spec_ms_first_step = benchmark.TaskSpec(
              stmt="", setup="", description=task, label=label_first_step, sub_label=description
          )

          # generate EXACTLY `max_new_tokens` tokens (no early termination due to `eos_token_id`)
          generation_kwargs = {
              "max_new_tokens": max_new_tokens,
              "min_new_tokens": max_new_tokens,
              "eos_token_id": None,
              "do_sample": False,
              "cache_implementation": CACHE_IMPLEMENTATION if compile else None
          }
          generation_config = copy.deepcopy(model.generation_config)
          generation_config.update(**generation_kwargs)

          torch.compiler.reset()
          results = []
          for _ in range(N_ITER):
              start = torch.cuda.Event(enable_timing=True)
              end = torch.cuda.Event(enable_timing=True)
              start.record()
              gen_out = model.generate(**input_ids, generation_config=generation_config)
              end.record()
              torch.cuda.synchronize()
              total_time = start.elapsed_time(end) / 1000  # time in seconds
              time_per_forward = total_time / max_new_tokens
              assert gen_out.shape[1] == max_new_tokens + prompt_length
              results.append(time_per_forward)

          ms_per_token.append(benchmark.Measurement(1, results[3:], task_spec_ms_per_token, metadata=None))
          first_step.append(benchmark.Measurement(
              1, [results[0] * max_new_tokens], task_spec_ms_first_step, metadata=None)
          )
          print_results(ms_per_token)
          print_results(first_step)
          print("*" * 80)


ms_per_token = []
first_step = []

# eager
with torch.compiler.set_stance("force_eager"):
  time_generate_call(model, "eager", ms_per_token, first_step)

# compiled
time_generate_call(model, "compiled", ms_per_token, first_step, compile=True)

@tugsbayasgalan tugsbayasgalan requested a review from gante April 23, 2025 14:24
@gante gante (Member) left a comment

Two more nits and should be ready to go 👍 (replied in the original threads above)

@ArthurZucker ArthurZucker (Collaborator) left a comment

Thanks for the PR!
Let's make sure export-specific logic is kept outside the more general cache file!

Comment on lines +1814 to +1878
def _get_flat_dict_for_hybrid_cache(hybrid_cache: HybridCache):
    return {
        "config": getattr(hybrid_cache, "config"),
        "device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) is not None else None,
        "layer_device_map": getattr(hybrid_cache, "layer_device_map"),
        "key_cache": getattr(hybrid_cache, "key_cache"),
        "value_cache": getattr(hybrid_cache, "value_cache"),
        "max_batch_size": getattr(hybrid_cache, "max_batch_size"),
        "max_cache_len": getattr(hybrid_cache, "max_cache_len"),
        "_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) is not None else None,
    }


def _flatten_hybrid_cache(
    hybrid_cache: HybridCache,
):
    """Flattens HybridCache into flat list of tensors for `torch.export.export` to consume"""
    if not isinstance(hybrid_cache, HybridCache):
        raise RuntimeError("This pytree flattening function should only be applied to HybridCache")

    if not is_torch_greater_or_equal_than_2_7:
        logger.warning_once(
            "HybridCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
        )

    return torch.utils._pytree._dict_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _flatten_with_keys_hybrid_cache(hybrid_cache: HybridCache):
    return torch.utils._pytree._dict_flatten_with_keys(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _unflatten_hybrid_cache(
    values,
    context: torch.utils._pytree.Context,
):
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    hybrid_cache = HybridCache(
        dictionary["config"],
        dictionary["max_batch_size"],
        dictionary["max_cache_len"],
        torch.device(dictionary["device"]) if dictionary["device"] is not None else None,
        getattr(torch, dictionary["_dtype"][len("torch.") :]) if dictionary["_dtype"] is not None else None,
        dictionary["layer_device_map"],
    )

    hybrid_cache.key_cache = dictionary["key_cache"]
    hybrid_cache.value_cache = dictionary["value_cache"]
    return hybrid_cache


def _flatten_hybrid_cache_for_fx(hybrid_cache, spec):
    return torch.utils._pytree.tree_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))[0]


if is_torch_greater_or_equal("2.3"):
    torch.utils._pytree.register_pytree_node(
        HybridCache,
        _flatten_hybrid_cache,
        _unflatten_hybrid_cache,
        serialized_type_name=f"{HybridCache.__module__}.{HybridCache.__name__}",
        flatten_with_keys_fn=_flatten_with_keys_hybrid_cache,
    )
    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
    torch.fx._pytree.register_pytree_flatten_spec(HybridCache, _flatten_hybrid_cache_for_fx)
Collaborator

Ah, a lot of this has nothing to do in this file and should rather go into integrations/...
We need a bit of documentation about why all of this is needed, to help people build broader support for these!
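
For illustration, a rough, hedged sketch of what the registration enables downstream: once HybridCache is a registered pytree node (and the config a constant), torch.export can trace a forward pass that takes and returns the cache. The model name, cache sizes, token ids, and export flags below are assumptions for the sketch, not part of the PR:

import torch
from transformers import AutoModelForCausalLM, HybridCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", torch_dtype=torch.bfloat16)
model.eval()

# A small pre-allocated hybrid (sliding + static) cache, sized arbitrarily for the sketch.
cache = HybridCache(config=model.config, max_batch_size=1, max_cache_len=64, dtype=torch.bfloat16)
input_ids = torch.tensor([[2, 106, 1645]])  # arbitrary token ids
cache_position = torch.arange(input_ids.shape[1])

exported = torch.export.export(
    model,
    args=(),
    kwargs={
        "input_ids": input_ids,
        "past_key_values": cache,
        "cache_position": cache_position,
        "use_cache": True,
    },
    strict=False,  # non-strict mode tends to be more permissive for HF models
)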

@tugsbayasgalan tugsbayasgalan mentioned this pull request Apr 30, 2025