Make hybrid cache exportable #37623
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the
@@ -317,7 +317,7 @@ def forward(
     # In case we are beyond the sliding window, we need to correctly offset the mask slicing
     offset = cache_position[-1] - effective_seq_len + 1
     # Should only be used when beyond the sliding window (i.e. offset > 0)
-    offset = max(0, offset)
+    offset = torch.clamp(offset, min=0)
There is a PR to do it automatically: pytorch/pytorch#151348
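The motivation for this one-line change can be sketched in isolation. Python's built-in `max(0, offset)` compares an int against a 0-dim tensor, which calls `Tensor.__bool__` and creates a data-dependent guard under `torch.export`; `torch.clamp` keeps the whole computation a tensor op. The helper below is a hypothetical minimal repro, not code from the PR:

```python
import torch

def mask_offset(cache_position: torch.Tensor, effective_seq_len: int) -> torch.Tensor:
    # Mirrors the changed line: beyond the sliding window the offset is
    # positive; otherwise it is clamped to 0. torch.clamp stays a traceable
    # tensor op, unlike Python's max(0, offset).
    offset = cache_position[-1] - effective_seq_len + 1
    return torch.clamp(offset, min=0)

print(mask_offset(torch.tensor([0, 1, 2, 3, 4]), 3))  # within range: positive offset
print(mask_offset(torch.tensor([0, 1]), 10))          # before the window: clamped to 0
```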
Looks mostly good to me, added a few nits :)
In parallel, I'm going to confirm that there is no speed degradation with torch.compile with the Hybrid cache.
@@ -8,7 +8,7 @@
 import torch
 from packaging import version

-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
Let's use is_torch_greater_or_equal instead. We're shifting towards this one across the library. (Equivalent usage: is_torch_greater_or_equal("2.7", accept_dev=True).)
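The `accept_dev=True` behavior can be sketched with `packaging` alone. This is a hypothetical stand-alone helper (`version_at_least` is not the library's function, just an illustration of the semantics): dev/nightly builds like `2.7.0.dev20250101` normally sort *before* `2.7`, so the flag compares the base version instead.

```python
from packaging import version

def version_at_least(current: str, minimum: str, accept_dev: bool = False) -> bool:
    # Sketch of the assumed semantics: with accept_dev=True, strip dev/rc
    # suffixes via base_version so nightlies of the target release qualify.
    parsed = version.parse(current)
    if accept_dev:
        parsed = version.parse(parsed.base_version)
    return parsed >= version.parse(minimum)

print(version_at_least("2.7.0.dev20250101", "2.7"))                   # dev sorts before release
print(version_at_least("2.7.0.dev20250101", "2.7", accept_dev=True))  # base version qualifies
```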
I don't see the implementation in the codebase... Should I create this function?
is_torch_greater_or_equal is already imported in that file :) (see the imports from .utils).
Lol i see sorry
     return hash(tuple(sorted(self.__dict__)))


+if is_torch_greater_or_equal_than_2_7:
Is this change needed? modular_gemma2.py shouldn't be imported.
Yep! Because HybridCache is in the output, export should understand this type thoroughly. Since HybridCache depends on the model config, we should tell export that it is a constant type (it doesn't have any inner tensors). This API was introduced in torch 2.7, hence the guard.
uhmm... my question here is not so much regarding the usefulness of these lines, but rather about the file they are in. modular_xxx.py is only used for scaffolding, and never imported (or at least it shouldn't be!). Perhaps it should be moved to configuration_gemma2.py?
Ahh, I was getting an error that there is a difference between modular_gemma2 and the generated modeling_gemma2, so I thought modular_gemma2 was the source of truth.
@@ -28,6 +28,7 @@
 logger = logging.get_logger(__name__)

+is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
There seems to be no throughput degradation.

Test script:

import copy
import os
import torch
from torch.utils import benchmark
from transformers import AutoTokenizer, AutoModelForCausalLM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Benchmarking settings
BSZ = [1, 4]
NEW_TOK = [16, 256]
N_ITER = 10
MODEL_ID = "google/gemma-2-2b-it"
CACHE_IMPLEMENTATION = "hybrid"
# Other constants
FRANCE_ARTICLE = (
"<s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight "
"9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille "
"prosecutor Brice Robin told CNN that \"so far no videos were used in the crash investigation.\" He added, \"A "
"person who has such a video needs to immediately give it to the investigators.\" Robin\'s comments follow claims "
"by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final "
"seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. "
"Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two "
"publications described the supposed video, but did not post it on their websites. The publications said that "
"they watched the video, which was found by a source close to the investigation. \"One can hear cries of 'My God' "
"in several languages,\" Paris Match reported. \"Metallic banging can also be heard more than three times, "
"perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy "
"shake, stronger than the others, the screaming intensifies. Then nothing.\" \"It is a very disturbing scene,\" "
"said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, "
"the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie "
"spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the "
"reports were \"completely wrong\" and \"unwarranted.\" Cell phones have been collected at the site, he said, "
"but that they \"hadn\'t been exploited yet.\" Menichini said he believed the cell phones would need to be sent "
"to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized "
"technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent "
"to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card "
"to the media, Menichini answered with a categorical \"no.\" Reichelt told \"Erin Burnett: Outfront\" that he "
"had watched the video and stood by the report, saying Bild and Paris Match are \"very confident\" that the clip "
"is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after "
"Bild and Paris Match published their reports. \"That is something we did not know before. ... Overall we can "
"say many things of the investigation weren\'t revealed by the investigation at the beginning,\" he said. What "
"was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas "
"Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s "
"accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training "
"school in 2009 that he had a \"previous episode of severe depression,\" the airline said Tuesday. Email "
"correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, "
"included medical documents he submitted in connection with resuming his flight training. The announcement "
"indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle with depression, allowed "
"him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously "
"said Lubitz was 100% fit to fly, described its statement Tuesday as a \"swift and seamless clarification\" and "
"said it was sharing the information and documents -- including training and medical records -- with public "
"prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past "
"week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center "
"set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving "
"families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human "
"remains were left at the site but recovery teams would keep searching. French President Francois Hollande, "
"speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the "
"end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the "
"victims' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be "
"more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our "
"correspondents . The details about Lubitz's correspondence with the flight school during his training were "
"among several developments as investigators continued to delve into what caused the crash and Lubitz\'s "
"possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid "
"medical certificate, had passed all his examinations and \"held all the licenses required.\" Earlier, a "
"spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz "
"suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before "
"he got his pilot's license. Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting "
"aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition "
"would cause him to lose his pilot's license, a European government official briefed on the investigation told "
"CNN on Tuesday. While flying was \"a big part of his life,\" the source said, it\'s only one theory being "
"considered. Another source, a law enforcement official briefed on the investigation, also told CNN that "
"authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not "
"be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye "
"doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had "
"psychological issues, the European government official said. But no matter what details emerge about his "
"previous mental health struggles, there's more to the story, said Brian Russell, a forensic psychologist. "
"\"Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they "
"weren't going to keep doing their job and they're upset about that and so they're suicidal,\" he said. \"But "
"there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it "
"outward on 149 other people who had nothing to do with the person's problems.\" Germanwings crash compensation: "
"What we know . Who was the captain of Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and "
"Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela "
"Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report."
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
prompt_length = tokenizer([FRANCE_ARTICLE], return_tensors="pt").input_ids.shape[1]
label_ms_per_token = f"Throughput (time/forward pass, prompt = {prompt_length} tokens)"
label_first_step = f"First call (time, prompt = {prompt_length} tokens)"
def print_results(all_results):
    print("\n")
    compare = benchmark.Compare(all_results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
def time_generate_call(model, task, ms_per_token, first_step, compile=False):
    for bsz in BSZ:
        for max_new_tokens in NEW_TOK:
            input_ids = tokenizer([FRANCE_ARTICLE] * bsz, return_tensors="pt").to("cuda")
            description = f"batch size, max_new_tokens: {bsz, max_new_tokens}"
            task_spec_ms_per_token = benchmark.TaskSpec(
                stmt="", setup="", description=task, label=label_ms_per_token, sub_label=description
            )
            task_spec_ms_first_step = benchmark.TaskSpec(
                stmt="", setup="", description=task, label=label_first_step, sub_label=description
            )
            # generate EXACTLY `max_new_tokens` tokens (no early termination due to `eos_token_id`)
            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
                "eos_token_id": None,
                "do_sample": False,
                "cache_implementation": CACHE_IMPLEMENTATION if compile else None,
            }
            generation_config = copy.deepcopy(model.generation_config)
            generation_config.update(**generation_kwargs)
            torch.compiler.reset()
            results = []
            for _ in range(N_ITER):
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                gen_out = model.generate(**input_ids, generation_config=generation_config)
                end.record()
                torch.cuda.synchronize()
                total_time = start.elapsed_time(end) / 1000  # time in seconds
                time_per_forward = total_time / max_new_tokens
                assert gen_out.shape[1] == max_new_tokens + prompt_length
                results.append(time_per_forward)
            ms_per_token.append(benchmark.Measurement(1, results[3:], task_spec_ms_per_token, metadata=None))
            first_step.append(
                benchmark.Measurement(1, [results[0] * max_new_tokens], task_spec_ms_first_step, metadata=None)
            )
    print_results(ms_per_token)
    print_results(first_step)
    print("*" * 80)
ms_per_token = []
first_step = []

# eager
with torch.compiler.set_stance("force_eager"):
    time_generate_call(model, "eager", ms_per_token, first_step)

# compiled
time_generate_call(model, "compiled", ms_per_token, first_step, compile=True)
Force-pushed from 034e5f4 to 8e9507a
Two more nits and should be ready to go 👍 (replied in the original threads above)
Thanks for the PR! Let's make sure export-specific logic is kept outside the more general cache file!
def _get_flat_dict_for_hybrid_cache(hybrid_cache: HybridCache):
    return {
        "config": getattr(hybrid_cache, "config"),
        "device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) is not None else None,
        "layer_device_map": getattr(hybrid_cache, "layer_device_map"),
        "key_cache": getattr(hybrid_cache, "key_cache"),
        "value_cache": getattr(hybrid_cache, "value_cache"),
        "max_batch_size": getattr(hybrid_cache, "max_batch_size"),
        "max_cache_len": getattr(hybrid_cache, "max_cache_len"),
        "_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) is not None else None,
    }


def _flatten_hybrid_cache(
    hybrid_cache: HybridCache,
):
    """Flattens HybridCache into flat list of tensors for `torch.export.export` to consume"""
    if not isinstance(hybrid_cache, HybridCache):
        raise RuntimeError("This pytree flattening function should only be applied to HybridCache")

    if not is_torch_greater_or_equal_than_2_7:
        logger.warning_once(
            "HybridCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
        )

    return torch.utils._pytree._dict_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _flatten_with_keys_hybrid_cache(hybrid_cache: HybridCache):
    return torch.utils._pytree._dict_flatten_with_keys(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _unflatten_hybrid_cache(
    values,
    context: torch.utils._pytree.Context,
):
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    hybrid_cache = HybridCache(
        dictionary["config"],
        dictionary["max_batch_size"],
        dictionary["max_cache_len"],
        torch.device(dictionary["device"]) if dictionary["device"] is not None else None,
        getattr(torch, dictionary["_dtype"][len("torch.") :]) if dictionary["_dtype"] is not None else None,
        dictionary["layer_device_map"],
    )

    hybrid_cache.key_cache = dictionary["key_cache"]
    hybrid_cache.value_cache = dictionary["value_cache"]
    return hybrid_cache


def _flatten_hybrid_cache_for_fx(hybrid_cache, spec):
    return torch.utils._pytree.tree_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))[0]


if is_torch_greater_or_equal("2.3"):
    torch.utils._pytree.register_pytree_node(
        HybridCache,
        _flatten_hybrid_cache,
        _unflatten_hybrid_cache,
        serialized_type_name=f"{HybridCache.__module__}.{HybridCache.__name__}",
        flatten_with_keys_fn=_flatten_with_keys_hybrid_cache,
    )
    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
    torch.fx._pytree.register_pytree_flatten_spec(HybridCache, _flatten_hybrid_cache_for_fx)
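One detail in the flatten/unflatten pair above worth calling out: dtypes are not pytree leaves, so `_get_flat_dict_for_hybrid_cache` serializes them as strings like `"torch.bfloat16"`, and `_unflatten_hybrid_cache` recovers them with `getattr` on the `torch` module. A stand-alone sketch of that round trip:

```python
import torch

# Encode: str(torch.bfloat16) yields the qualified name "torch.bfloat16".
dtype = torch.bfloat16
encoded = str(dtype)

# Decode: strip the "torch." prefix and look the dtype up on the torch module,
# exactly as the unflatten function above does.
decoded = getattr(torch, encoded[len("torch."):])
print(encoded, decoded)
```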
Ah, a lot of this has nothing to do in this file and should rather go into integrations/... We need a bit of documentation about why all of this is needed, to help people build broader support for these!
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.