Make hybrid cache exportable #37623
base: main
Changes from all commits
1d03c6b
3ee5884
238a95c
0730752
9eb49aa
8e9507a
158a218
@@ -8,7 +8,7 @@
 import torch
 from packaging import version

-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7

 from .configuration_utils import PretrainedConfig
 from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -1672,6 +1672,9 @@ def __init__(
                 "sliding window attention, please check if there is a `sliding_window` field in the model "
                 "config and it's not set to None."
             )
+        self.config = config
+        self.device = device
+        self.layer_device_map = layer_device_map
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1685,9 +1688,7 @@ def __init__(
         )

         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
@@ -1800,7 +1801,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
             "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
             "Using the `layer_idx` argument is not supported."
         )
-        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

     def reset(self):
         """Resets the cache values while preserving the objects"""
@@ -1810,6 +1811,73 @@ def reset(self):
             self.value_cache[layer_idx].zero_()


+def _get_flat_dict_for_hybrid_cache(hybrid_cache: HybridCache):
+    return {
+        "config": getattr(hybrid_cache, "config"),
+        "device": str(getattr(hybrid_cache, "device")) if getattr(hybrid_cache, "device", None) is not None else None,
+        "layer_device_map": getattr(hybrid_cache, "layer_device_map"),
+        "key_cache": getattr(hybrid_cache, "key_cache"),
+        "value_cache": getattr(hybrid_cache, "value_cache"),
+        "max_batch_size": getattr(hybrid_cache, "max_batch_size"),
+        "max_cache_len": getattr(hybrid_cache, "max_cache_len"),
+        "_dtype": str(getattr(hybrid_cache, "_dtype")) if getattr(hybrid_cache, "_dtype", None) is not None else None,
+    }
+
+
+def _flatten_hybrid_cache(
+    hybrid_cache: HybridCache,
+):
+    """Flattens HybridCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(hybrid_cache, HybridCache):
+        raise RuntimeError("This pytree flattening function should only be applied to HybridCache")
+
+    if not is_torch_greater_or_equal_than_2_7:
+        logger.warning_once(
+            "HybridCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
+        )
+
+    return torch.utils._pytree._dict_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))
+
+
+def _flatten_with_keys_hybrid_cache(hybrid_cache: HybridCache):
+    return torch.utils._pytree._dict_flatten_with_keys(_get_flat_dict_for_hybrid_cache(hybrid_cache))
+
+
+def _unflatten_hybrid_cache(
+    values,
+    context: torch.utils._pytree.Context,
+):
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    hybrid_cache = HybridCache(
+        dictionary["config"],
+        dictionary["max_batch_size"],
+        dictionary["max_cache_len"],
+        torch.device(dictionary["device"]) if dictionary["device"] is not None else None,
+        getattr(torch, dictionary["_dtype"][len("torch.") :]) if dictionary["_dtype"] is not None else None,
+        dictionary["layer_device_map"],
+    )
+
+    hybrid_cache.key_cache = dictionary["key_cache"]
+    hybrid_cache.value_cache = dictionary["value_cache"]
+    return hybrid_cache
+
+
+def _flatten_hybrid_cache_for_fx(hybrid_cache, spec):
+    return torch.utils._pytree.tree_flatten(_get_flat_dict_for_hybrid_cache(hybrid_cache))[0]
+
+
+if is_torch_greater_or_equal("2.3"):
+    torch.utils._pytree.register_pytree_node(
+        HybridCache,
+        _flatten_hybrid_cache,
+        _unflatten_hybrid_cache,
+        serialized_type_name=f"{HybridCache.__module__}.{HybridCache.__name__}",
+        flatten_with_keys_fn=_flatten_with_keys_hybrid_cache,
+    )
+    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+    torch.fx._pytree.register_pytree_flatten_spec(HybridCache, _flatten_hybrid_cache_for_fx)
Comment on lines +1814 to +1878: a lot of this has nothing to do in this file and should rather go into
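Wherever the registration ends up living, here is a minimal sketch of what it buys (assuming this branch of transformers is installed; the tiny `Gemma2Config` values below are illustrative, not a real checkpoint): once `HybridCache` is a registered pytree node, the pytree machinery can take a cache apart into a flat list of leaves plus a spec and rebuild it, which is what `torch.export.export` relies on to accept a cache as a model input or output.

```python
# Minimal sketch, assuming this branch of transformers is installed.
# The tiny Gemma2Config values are illustrative, not a real checkpoint.
import torch
from transformers import Gemma2Config
from transformers.cache_utils import HybridCache

config = Gemma2Config(
    num_hidden_layers=2, hidden_size=16, num_attention_heads=2,
    num_key_value_heads=1, head_dim=8, sliding_window=8,
)
cache = HybridCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# The registered flatten function decomposes the cache into a flat list of
# leaves (key/value tensors plus config, size, device and dtype metadata)
# and a treespec describing the structure ...
leaves, spec = torch.utils._pytree.tree_flatten(cache)

# ... and the registered unflatten function rebuilds an equivalent HybridCache.
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, HybridCache)
assert len(rebuilt.key_cache) == config.num_hidden_layers
```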


 class HybridChunkedCache(Cache):
     """
     Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
@@ -1998,7 +2066,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
             )
             if len(self.key_cache) == 0:
                 return 0
-            return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+            return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

     def reset(self):
         """Resets the cache values while preserving the objects"""
@@ -317,7 +317,7 @@ def forward(
             # In case we are beyond the sliding window, we need to correctly offset the mask slicing
             offset = cache_position[-1] - effective_seq_len + 1
             # Should only be used when beyond the sliding window (i.e. offset > 0)
-            offset = max(0, offset)
+            offset = torch.clamp(offset, min=0)
Comment: There is a PR to automatically do it: pytorch/pytorch#151348
             # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
             # but without data-dependent slicing (i.e. torch.compile friendly)
             mask_indexes = torch.arange(
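To illustrate why the `max` to `torch.clamp` swap matters, here is a small standalone sketch (the numbers are made up and independent of the model code): `offset` is a 0-d tensor, so Python's `max(0, offset)` forces a `bool()` on a tensor comparison, which tracing has to treat as data-dependent control flow, whereas `torch.clamp` stays a pure tensor op.

```python
import torch

# Made-up stand-ins for the real cache_position / effective_seq_len values.
cache_position = torch.arange(10)
effective_seq_len = 32

# 0-d tensor; negative while we are still inside the sliding window.
offset = cache_position[-1] - effective_seq_len + 1

# Python's max() would evaluate `offset > 0` as a Python bool, i.e. branch on
# the runtime value, which torch.export / torch.compile with dynamic shapes
# treats as data-dependent control flow:
# offset = max(0, offset)

# torch.clamp expresses the same clamping as a tensor op, so tracing never has
# to branch on the runtime value.
offset = torch.clamp(offset, min=0)
print(offset)  # tensor(0)
```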
@@ -20,6 +20,8 @@
 import torch.nn as nn
 import torch.utils.checkpoint

+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_7
+
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
 from ...configuration_utils import PretrainedConfig
@@ -201,6 +203,13 @@ def __init__(
         self.attn_logit_softcapping = attn_logit_softcapping
         self.cache_implementation = cache_implementation

+    def __hash__(self):
+        return hash(tuple(sorted(self.__dict__)))
+
+
+if is_torch_greater_or_equal_than_2_7:
Comment thread on this block:

- Is this change needed?
- Yep! Because HybridCache is in the output, export should understand this type thoroughly. Since HybridCache depends on the model config, we should tell export that it is a constant type (it doesn't have any inner tensors). This API is introduced in 2.7, hence the guard.
- uhmm... my question here is not so much regarding the usefulness of these lines, but rather about the file they are in. Perhaps it should be moved to
- Ahh, I was getting an error that there is a difference between modular_gemma2 and the generated modeling_gemma2, so I thought modular_gemma2 was the source of truth.
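To make the `register_constant` point concrete, here is a toy sketch of the same pattern (assumes torch 2.7+; `ToyConfig` is purely illustrative and stands in for `Gemma2Config`). The `__hash__` added above appears to exist to satisfy the hashability that constant registration needs, presumably because the config defines `__eq__`, which removes the default `__hash__`; the thread does not spell this out, so treat it as a reading of the diff.

```python
# Toy sketch of the same pattern (assumes torch >= 2.7). ToyConfig is purely
# illustrative and stands in for Gemma2Config.
import torch


class ToyConfig:
    def __init__(self, num_hidden_layers: int = 2):
        self.num_hidden_layers = num_hidden_layers

    def __eq__(self, other):
        return isinstance(other, ToyConfig) and self.__dict__ == other.__dict__

    # Defining __eq__ removes the default __hash__, and constant registration
    # needs hashable instances; this mirrors the __hash__ added above.
    def __hash__(self):
        return hash(tuple(sorted(self.__dict__)))


if hasattr(torch.utils._pytree, "register_constant"):  # only available on torch >= 2.7
    # Treat ToyConfig instances as opaque constants with no inner tensors, so a
    # cache object that carries one can still be flattened for torch.export.
    torch.utils._pytree.register_constant(ToyConfig)
```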
+    torch.utils._pytree.register_constant(Gemma2Config)


 class Gemma2RMSNorm(GemmaRMSNorm):
     pass
@@ -364,7 +373,7 @@ def forward(
             # In case we are beyond the sliding window, we need to correctly offset the mask slicing
             offset = cache_position[-1] - effective_seq_len + 1
             # Should only be used when beyond the sliding window (i.e. offset > 0)
-            offset = max(0, offset)
+            offset = torch.clamp(offset, min=0)
             # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
             # but without data-dependent slicing (i.e. torch.compile friendly)
             mask_indexes = torch.arange(
@@ -28,6 +28,7 @@
 logger = logging.get_logger(__name__)

+is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
 is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
Suggested change, with review thread:

- Let's use `is_torch_greater_or_equal` instead. We're shifting towards this one across the library. (Equivalent usage: `is_torch_greater_or_equal("2.7", accept_dev=True)`.)
- I don't see the implementation in the codebase... Should I create this function?
- `is_torch_greater_or_equal` is already imported in that file :) (see the imports from `.utils`)
- Lol, I see, sorry.
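For reference, the equivalent spelling with the helper the reviewer points to, which is already importable from `transformers.utils`:

```python
from transformers.utils import is_torch_greater_or_equal

# Same check as the module-level constant added in this file; accept_dev=True
# lets dev/pre-release builds of 2.7 pass the check as well.
is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
```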