Make hybrid cache exportable #37623

Open · wants to merge 7 commits into main
80 changes: 74 additions & 6 deletions src/transformers/cache_utils.py
@@ -8,7 +8,7 @@
import torch
from packaging import version

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
Member:
Let's use is_torch_greater_or_equal instead. We're shifting towards this one across the library.

(equivalent usage: is_torch_greater_or_equal("2.7", accept_dev=True))
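A minimal sketch of what the suggested form could look like (hedged; where exactly the check would live in cache_utils.py is an assumption):

```python
# Hypothetical sketch: call the helper directly instead of adding a new module-level flag.
from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.7", accept_dev=True):
    # e.g. only enable the export/pytree integration on torch >= 2.7
    ...
```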

Contributor Author:

I don't see the implementation in the codebase... Should I create this function?

Member:

is_torch_greater_or_equal is already imported in that file :) (see the imports from .utils)

Contributor Author:

Lol, I see. Sorry!


from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -1672,6 +1672,9 @@ def __init__(
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.config = config
self.device = device
self.layer_device_map = layer_device_map
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1685,9 +1688,7 @@
)

layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
@@ -1800,7 +1801,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

def reset(self):
"""Resets the cache values while preserving the objects"""
@@ -1810,6 +1811,73 @@ def reset(self):
self.value_cache[layer_idx].zero_()


def _get_flat_dict_for_hybrid_cache(hybrid_cache: HybridCache):
return {
"config": getattr(hybrid_cache, "config"),
"device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) is not None else None,
"layer_device_map": getattr(hybrid_cache, "layer_device_map"),
"key_cache": getattr(hybrid_cache, "key_cache"),
"value_cache": getattr(hybrid_cache, "value_cache"),
"max_batch_size": getattr(hybrid_cache, "max_batch_size"),
"max_cache_len": getattr(hybrid_cache, "max_cache_len"),
"_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) is not None else None,
}


def _flatten_hybrid_cache(
hybrid_cache: HybridCache,
):
"""Flattens HybridCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(hybrid_cache, HybridCache):
raise RuntimeError("This pytree flattening function should only be applied to HybridCache")

if not is_torch_greater_or_equal_than_2_7:
logger.warning_once(
"HybridCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
)

return torch.utils._pytree._dict_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _flatten_with_keys_hybrid_cache(hybrid_cache: HybridCache):
return torch.utils._pytree._dict_flatten_with_keys(_get_flat_dict_for_hybrid_cache(hybrid_cache))


def _unflatten_hybrid_cache(
values,
context: torch.utils._pytree.Context,
):
dictionary = torch.utils._pytree._dict_unflatten(values, context)
hybrid_cache = HybridCache(
dictionary["config"],
dictionary["max_batch_size"],
dictionary["max_cache_len"],
torch.device(dictionary["device"]) if dictionary["device"] is not None else None,
getattr(torch, dictionary["_dtype"][len("torch.") :]) if dictionary["_dtype"] is not None else None,
dictionary["layer_device_map"],
)

hybrid_cache.key_cache = dictionary["key_cache"]
hybrid_cache.value_cache = dictionary["value_cache"]
return hybrid_cache


def _flatten_hybrid_cache_for_fx(hybrid_cache, spec):
return torch.utils._pytree.tree_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))[0]


if is_torch_greater_or_equal("2.3"):
torch.utils._pytree.register_pytree_node(
HybridCache,
_flatten_hybrid_cache,
_unflatten_hybrid_cache,
serialized_type_name=f"{HybridCache.__module__}.{HybridCache.__name__}",
flatten_with_keys_fn=_flatten_with_keys_hybrid_cache,
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(HybridCache, _flatten_hybrid_cache_for_fx)
Comment on lines +1814 to +1878
Collaborator:

Ah, a lot of this doesn't belong in this file and should rather go into integrations/...
We need a bit of documentation about why all of this is needed, to help people get broader support for these!
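For context, a minimal, self-contained sketch of why the pytree registration is needed (ToyCache below is a hypothetical stand-in, not the real HybridCache): torch.export only sees tensors that pytree flattening exposes as leaves, so an unregistered cache object is treated as a single opaque leaf.

```python
import torch
import torch.utils._pytree as pytree

# Hypothetical stand-in for HybridCache: holds per-layer key/value tensors.
class ToyCache:
    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache

cache = ToyCache([torch.zeros(2, 2)], [torch.zeros(2, 2)])

# Without registration the whole object is one leaf, so export cannot trace its tensors.
leaves, _ = pytree.tree_flatten(cache)
print(len(leaves))  # 1

# Registering flatten/unflatten functions (as the PR does for HybridCache) exposes the
# inner tensors as leaves, which is what lets torch.export consume and return the cache.
pytree.register_pytree_node(
    ToyCache,
    lambda c: ((c.key_cache, c.value_cache), None),                # flatten -> (children, context)
    lambda children, context: ToyCache(children[0], children[1]),  # unflatten
)
leaves, _ = pytree.tree_flatten(cache)
print(len(leaves))  # 2
```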



class HybridChunkedCache(Cache):
"""
Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
@@ -1998,7 +2066,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
)
if len(self.key_cache) == 0:
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

def reset(self):
"""Resets the cache values while preserving the objects"""
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -168,5 +168,8 @@ def __init__(
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation

def __hash__(self):
return hash(tuple(sorted(self.__dict__)))


__all__ = ["Gemma2Config"]
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -317,7 +317,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
Contributor Author:

There is a PR to do this automatically: pytorch/pytorch#151348
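For illustration, a small sketch of the difference (assuming offset is a 0-dim tensor derived from cache_position, as in the surrounding code):

```python
import torch

cache_position = torch.arange(10)
effective_seq_len = 16
offset = cache_position[-1] - effective_seq_len + 1   # 0-dim tensor, negative here

# Python's max() compares the tensor with 0 and converts the result to a Python bool,
# a data-dependent branch that can break or specialize under torch.compile/torch.export.
eager_only = max(0, offset)

# torch.clamp keeps the operation inside the graph, so it stays traceable.
traceable = torch.clamp(offset, min=0)
```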

# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
11 changes: 10 additions & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -20,6 +20,8 @@
import torch.nn as nn
import torch.utils.checkpoint

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_7

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
@@ -201,6 +203,13 @@ def __init__(
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation

def __hash__(self):
return hash(tuple(sorted(self.__dict__)))


if is_torch_greater_or_equal_than_2_7:
Member:

Is this change needed? modular_gemma2.py shouldn't be imported

Contributor Author:

Yep! Because HybridCache is in the output, export needs to understand this type thoroughly. Since HybridCache depends on the model config, we should tell export that the config is a constant type (it doesn't hold any inner tensors). This API was introduced in 2.7, hence the guard.
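A minimal sketch of the idea, using a hypothetical ToyConfig stand-in (the PR registers Gemma2Config; presumably the added __hash__ on the config classes is related, since registered constants end up in the TreeSpec, which gets compared and hashed):

```python
import torch
import torch.utils._pytree as pytree
from dataclasses import dataclass

# Hypothetical stand-in for the model config; frozen dataclasses are hashable.
@dataclass(frozen=True)
class ToyConfig:
    sliding_window: int = 4096
    num_hidden_layers: int = 26

if hasattr(pytree, "register_constant"):  # the API only exists on recent torch (>= 2.7)
    pytree.register_constant(ToyConfig)
    # Export now treats ToyConfig instances as compile-time constants: they carry no
    # tensors, so a flattened HybridCache can keep its config without it being traced.
```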

Member:

uhmm... my question here is not so much regarding the usefulness of these lines, but rather about the file they are in. modular_xxx.py is only used for scaffolding, and never imported (or at least it shouldn't be!).

Perhaps it should be moved to configuration_gemma2.py?

Contributor Author:

Ahh, I was getting an error that there is a difference between modular_gemma2 and the generated modeling_gemma2, so I thought modular_gemma2 was the source of truth.

torch.utils._pytree.register_constant(Gemma2Config)


class Gemma2RMSNorm(GemmaRMSNorm):
pass
@@ -364,7 +373,7 @@ def forward(
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
3 changes: 3 additions & 0 deletions src/transformers/models/gemma3/configuration_gemma3.py
@@ -234,6 +234,9 @@ def __init__(
self.rope_scaling = rope_scaling
rope_config_validation(self)

def __hash__(self):
return hash(tuple(sorted(self.__dict__)))


class Gemma3Config(PretrainedConfig):
r"""
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@

logger = logging.get_logger(__name__)

is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
Member:

Suggested change (remove the added line):
is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)

is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
112 changes: 112 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -14,6 +14,8 @@
"""Testing suite for the PyTorch Gemma2 model."""

import unittest
from contextlib import contextmanager
from unittest.mock import patch

import pytest
from packaging import version
@@ -337,6 +339,116 @@ def test_export_static_cache(self):
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

@slow
@require_read_token
def test_export_hybrid_cache(self):
if version.parse(torch.__version__) < version.parse("2.7.0"):
self.skipTest(reason="This test requires torch >= 2.7 to run.")

model_id = "google/gemma-2-2b"

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
).to(torch_device)
self.assertEqual(model.config._attn_implementation, "sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model(**inputs, max_new_tokens=20, do_sample=False)

@contextmanager
def _detect_attribute_assignment_hacked_up(mod: torch.nn.Module):
# Do not allow assignment of tensor attributes during export unless
# the attribute is registered as a buffer.

from torch.utils import _pytree as pytree

NN_MODULE_STD_ATTRS = [
"_backward_hooks",
"_backward_pre_hooks",
"_buffers",
"_forward_hooks",
"_forward_hooks_always_called",
"_forward_hooks_with_kwargs",
"_forward_pre_hooks",
"_forward_pre_hooks_with_kwargs",
"_is_full_backward_hook",
"_load_state_dict_post_hooks",
"_load_state_dict_pre_hooks",
"_modules",
"_non_persistent_buffers_set",
"_parameters",
"_state_dict_hooks",
"_state_dict_pre_hooks",
"training",
]
NN_MODULE_LAZY_STD_ATTRS = [
"_initialize_hook",
"_load_hook",
]
STD_ATTRS = {
*NN_MODULE_STD_ATTRS,
*NN_MODULE_LAZY_STD_ATTRS,
}

def _get_attributes(mod):
# return any attributes of a module that are not standard attributes
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}

def is_leaf(x):
# Ideally is_leaf should not be needed when mapping, but it seems that
# subclasses of a standard container X may sometimes map to X, which
# destroys information and can cause future mapping to fail.
known_subclasses_that_lose_info = (
torch.Size,
# add more here if needed
)
return isinstance(x, known_subclasses_that_lose_info)

# save state of attributes before enter
snapshot = pytree.tree_map(lambda x: x, _get_attributes(mod), is_leaf=is_leaf)
try:
yield
finally:
# after exit, compare state of attributes with snapshot
# to detect which tensor attributes were assigned
assigned_tensor_attributes = []

def _collect_assigned_tensor_attributes(kp, v, _v):
if _v is not v:
attr, *rest = kp
if isinstance(v, torch.Tensor):
assigned_tensor_attributes.append(f"self.{attr.key}{pytree.keystr(rest)}")
# TODO(avik): Assigning all other types are allowed right now.
# Maybe in the future we want to limit this to primitive types?
return v

pytree.tree_map_with_path(_collect_assigned_tensor_attributes, snapshot, _get_attributes(mod))
# restore state of all attributes (including, e.g., of primitive types)
mod.__dict__.update(snapshot)

if assigned_tensor_attributes:
if len(assigned_tensor_attributes) > 1:
noun, verb = "attributes", "were"
else:
noun, verb = "attribute", "was"
raise ValueError(
f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
"Such attributes must be registered as buffers using the `register_buffer` API "
"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)

# FIXME this should be gone in torch 2.7/2.8
with patch(
"torch._functorch.aot_autograd._detect_attribute_assignment", _detect_attribute_assignment_hacked_up
):
from torch.export import export_for_training

with torch.no_grad():
ep = export_for_training(model, (), {**inputs, "max_new_tokens": 20, "do_sample": False}, strict=False)
ep_out = ep.module()(**inputs, max_new_tokens=20, do_sample=False)
self.assertTrue(torch.allclose(output.logits, ep_out.logits))

@require_read_token
@tooslow
def test_model_9b_bf16_flex_attention(self):