diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index ca65c02a40c3b..2743c1e522c9f 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -471,6 +471,12 @@ file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) +file(GLOB onnxruntime_python_transformers_models_torch_export_patches_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/*.py" +) +file(GLOB onnxruntime_python_transformers_models_torch_export_patches_patches_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/patches/*.py" +) file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py" ) @@ -566,6 +572,8 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/sam2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/torch_export_patches + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/torch_export_patches/patches COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators @@ -682,6 +690,12 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_t5_src} $/onnxruntime/transformers/models/t5/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_torch_export_patches_src} + $/onnxruntime/transformers/models/torch_export_patches/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_torch_export_patches_patches_src} + $/onnxruntime/transformers/models/torch_export_patches/patches/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_whisper_src} $/onnxruntime/transformers/models/whisper/ diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 89fd613ecbbc2..5c778424709c8 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -22,6 +22,10 @@ from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check from llama_torch import setup_torch_model + +# to patch transformers before exporting for transformers >= 4.45 +from models.torch_export_patches import bypass_export_some_errors +from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes, replace_dynamic_shapes from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -122,7 +126,7 @@ def run_dynamo_export( config.capture_scalar_outputs = True # Dummy values for export - batch_size, sequence_length, past_sequence_length = 2, 8, 0 + batch_size, sequence_length, past_sequence_length = 2, 8, 3 device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") temp_name = 
args.model_name.lower().replace("-", "").replace("_", "") @@ -141,9 +145,76 @@ def run_dynamo_export( ) temp_dir = tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir.name, "temp.onnx") - torch.onnx.dynamo_export( - llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) - ).save(temp_path) + + input_names = ["input_ids", "attention_mask", "position_ids"] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_model_dynamic_axes(input_names, output_names) + + model_args = (input_ids, attn_mask, pos_ids, past_kv) + model_args, model_kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes( + llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"} + ) + + if version.Version(torch.__version__) < version.Version("2.7"): + # This section is only needed for torch==2.6. The workaround implemented here + # to fix bugs is not necessary with torch>=2.7. + # - strings are not allowed with torch 2.6, so we replace them by DYNAMIC + # - TypePromotion was fixed in torch==2.7 + from onnxscript import opset18 as op + + dynamic_shapes = replace_dynamic_shapes( + dynamic_shapes, + dict(batch_size=torch.export.Dim("batch_size")), + default_value=torch.export.Dim.DYNAMIC, + ) + + # TypePromotion cannot fix a type issue after the conversion. + # We insert an additional CastLike when the exporter + def custom_aten_ge(self, other): + if isinstance(other, (int, float)): + return op.GreaterOrEqual(self, op.CastLike(other, self)) + return op.GreaterOrEqual(self, other) + + with bypass_export_some_errors(patch_transformers=True): + # ONNX pass TypePromotion crashes for torch 2.6. + # It can be bypassed by exporting first into an exported program. + # We then need to apply run_decompositions() before onnx conversion starts. + ep = torch.export.export( + llama, + (), + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + ep = ep.run_decompositions() + torch.onnx.export( + ep, + (), + temp_path, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + verbose=args.verbose, + optimize=True, + custom_translation_table={torch.ops.aten.ge.Scalar: custom_aten_ge}, + ) + else: + with bypass_export_some_errors(patch_transformers=True): + torch.onnx.export( + llama, + (), + temp_path, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + verbose=args.verbose, + optimize=True, + ) # Check decoder_with_past_model.onnx and save all external data to one file onnx.checker.check_model(temp_path) @@ -330,6 +401,7 @@ def run_torchscript_merged_export( temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) temp_path = os.path.join(temp_dir, "temp.onnx") + torch.onnx.export( llama, args=decoder_merged_inputs, @@ -341,6 +413,7 @@ def run_torchscript_merged_export( opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, + dynamo=False, ) # Check decoder_merged_model.onnx and save all external data to one file @@ -862,9 +935,6 @@ def main(): decoder_merged_model_fp32_opt_path, ] - if args.use_dynamo_export: - continue - # Run the optimizer script. 
logger.info("Optimizing models...") for orig_path, opt_path in zip(old_paths, new_paths, strict=False): @@ -970,9 +1040,6 @@ def main(): remove_existing_model(fp_path) barrier() - if args.use_dynamo_export: - return - logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 025d57f0b2d5d..5c9ccb118bc61 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -7,6 +7,7 @@ import numpy as np import torch +import transformers from transformers import AutoConfig, AutoTokenizer from onnxruntime import InferenceSession, OrtValue @@ -240,8 +241,12 @@ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, u def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]): past_kv = {} for i, (past_k, past_v) in enumerate(past_key_values): - past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() - past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() + if isinstance(past_key_values, transformers.cache_utils.DynamicCache): + past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy() + else: + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() return past_kv diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index eab55154b50b1..27e9db54bd0a4 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -11,7 +11,9 @@ import time import numpy as np +import packaging.version as pv import torch +import transformers from benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( @@ -23,6 +25,7 @@ verify_ort_inputs, ) from llama_torch import setup_torch_model +from models.torch_export_patches.cache_helper import make_dynamic_cache from transformers import AutoConfig import onnxruntime as ort @@ -71,6 +74,28 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): return inputs +def torch_deepcopy(value): + if isinstance(value, (int, float, str)): + return value + if isinstance(value, tuple): + return tuple(torch_deepcopy(v) for v in value) + if isinstance(value, list): + return [torch_deepcopy(v) for v in value] + if isinstance(value, set): + return {torch_deepcopy(v) for v in value} + if isinstance(value, dict): + return {k: torch_deepcopy(v) for k, v in value.items()} + if isinstance(value, np.ndarray): + return value.copy() + if hasattr(value, "clone"): + return value.clone() + if isinstance(value, transformers.cache_utils.DynamicCache): + return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False)))) + # We should have a code using serialization, deserialization assuming a model + # cannot be exported without them. 
+ raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}") + + + def verify_parity( args: argparse.Namespace, location: str, @@ -92,11 +117,18 @@ def verify_parity( inputs = get_inputs(args, config) + + if "past_key_values" in inputs and pv.Version(transformers.__version__) >= pv.Version("4.45"): + # Using DynamicCache + inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"]) + # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() start_time = time.time() - pt_outputs = py_model(**inputs).logits.detach().cpu().numpy() + # If there is a cache in the inputs, we need to make a copy because the model modifies it in place. + # DynamicCache inherits from torch.nn.Module in some versions of transformers. + # We need to make the copy manually. + pt_outputs = py_model(**torch_deepcopy(inputs)).logits.detach().cpu().numpy() if args.execution_provider != "cpu": torch.cuda.synchronize() end_time = time.time() diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index c965cc5dab58a..40f3ab1c92f16 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,5 +1,7 @@ +onnxscript>=0.2.3 +optree optimum>=1.14.1 -transformers>=4.33.2,<= 4.38.0 +transformers==4.48.0 torch>=2.2.0 onnx==1.17.0 datasets>=2.8.0 diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py new file mode 100644 index 0000000000000..caa6638e7f749 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -0,0 +1,112 @@ +from typing import Any + +import packaging.version as pv +import torch +import transformers + +from .onnx_export_errors import ( + bypass_export_some_errors, + register_additional_serialization_functions, +) + + +def is_torchdynamo_exporting() -> bool: + "Tells whether torch is exporting a model." + import torch + + if not hasattr(torch.compiler, "is_exporting"): + # torch.compiler.is_exporting requires torch>=2.7 + return False + + try: + return torch.compiler.is_exporting() + except Exception: + try: + import torch._dynamo as dynamo + + return dynamo.is_exporting() # type: ignore + except Exception: + return False + + +def string_type(anything, **args): + # too long + return str(anything) + + +if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): + + def make_dynamic_cache( + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + """ + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers >= 4.50``.
+ + :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + + Example: + + :: + + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(n_layers) + ] + ) + print(string_type(past_key_values, with_shape=True)) + """ + return transformers.cache_utils.DynamicCache(key_value_pairs) + +else: + + def make_dynamic_cache( + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + """ + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers < 4.50``. + + :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + + Example: + + :: + + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(n_layers) + ] + ) + print(string_type(past_key_values, with_shape=True)) + """ + cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) + for i, (key, value) in enumerate(key_value_pairs): + cache.update(key, value, i) + return cache + + +def make_encoder_decoder_cache( + self_attention_cache: transformers.cache_utils.DynamicCache, + cross_attention_cache: transformers.cache_utils.DynamicCache, +) -> transformers.cache_utils.EncoderDecoderCache: + "Creates an EncoderDecoderCache." + return transformers.cache_utils.EncoderDecoderCache( + self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache + ) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py new file mode 100644 index 0000000000000..0cbe1e58a9e02 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py @@ -0,0 +1,74 @@ +import packaging.version as pv +import torch +import transformers +import transformers.cache_utils + + +def is_cache_dynamic_registered(fast: bool = False) -> bool: + """ + Tells class :class:`transformers.cache_utils.DynamicCache` can be + serialized and deserialized. Only then, :func:`torch.export.export` + can export a model. + + :param fast: if True, do not check the serialization is ok as well + :return: result + """ + if fast: + return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES + bsize, nheads, slen, dim = 2, 4, 3, 7 + cache = make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(2) + ] + ) + values, spec = torch.utils._pytree.tree_flatten(cache) + cache2 = torch.utils._pytree.tree_unflatten(values, spec) + return len(cache2.key_cache) == len(cache.value_cache) + + +if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): + + def make_dynamic_cache( + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + """ + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers >= 4.50``. 
+ :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + """ + return transformers.cache_utils.DynamicCache(key_value_pairs) + +else: + + def make_dynamic_cache( + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + """ + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers < 4.50``. + + :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + """ + cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) + for i, (key, value) in enumerate(key_value_pairs): + cache.update(key, value, i) + return cache + + +def make_encoder_decoder_cache( + self_attention_cache: transformers.cache_utils.DynamicCache, + cross_attention_cache: transformers.cache_utils.DynamicCache, +) -> transformers.cache_utils.EncoderDecoderCache: + """ + Creates an EncoderDecoderCache. + """ + return transformers.cache_utils.EncoderDecoderCache( + self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache + ) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py new file mode 100644 index 0000000000000..5dd3b38a8232a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -0,0 +1,472 @@ +import contextlib +import pprint +from collections.abc import Callable +from typing import Any + +from .onnx_export_serialization import ( + flatten_dynamic_cache, + flatten_mamba_cache, + flatten_with_keys_dynamic_cache, + flatten_with_keys_mamba_cache, + unflatten_dynamic_cache, + unflatten_mamba_cache, +) +from .patches import patch_transformers as patch_transformers_list + + +def patch_module(mod, verbose: int = 0) -> dict[type, dict[type, Callable]]: + """ + Applies all patches defined in classes prefixed by ``patched_``. + ``cls._PATCHED_CLASS_`` defines the class to patch, + ``cls._PATCHES_`` defines the methods to patch. + The returned information needs to be sent to :func:`unpatch_module` + to revert the changes. + """ + to_patch = [] + for k in dir(mod): + if k.startswith("patched_"): + v = getattr(mod, k) + if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): + to_patch.append(v) + + res = {} + for cls in to_patch: + original = cls._PATCHED_CLASS_ + methods = cls._PATCHES_ + if verbose: + print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}") + + keep = {n: getattr(original, n, None) for n in methods} + for n in methods: + setattr(original, n, getattr(cls, n)) + res[cls] = keep + + return res + + +def unpatch_module(mod, info: dict[type, dict[type, Callable]], verbose: int = 0): + """Reverts the modifications made by :func:`patch_module`.""" + to_patch = [] + for k in dir(mod): + if k.startswith("patched_"): + v = getattr(mod, k) + if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"): + to_patch.append(v) + set_patch = set(to_patch) + + for cls, methods in info.items(): + assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})" + if verbose: + print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}") + original = cls._PATCHED_CLASS_ + for n, v in methods.items(): + if v is None: + # The method did not exist. We remove it.
+ delattr(original, n) + else: + setattr(original, n, v) + + +def _register_cache_serialization(verbose: int = 0) -> dict[str, bool]: + # Cache serialization: to be moved into appropriate packages + import packaging.version as pv + import torch + import transformers + + try: + from transformers.cache_utils import DynamicCache + except ImportError: + DynamicCache = None + + try: + from transformers.cache_utils import MambaCache + except ImportError: + MambaCache = None + + # MambaCache + unregistered_mamba_cache = True + if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES: + if verbose > 1: + print(f"[_register_cache_serialization] {MambaCache} already registered") + # It is already registered because bypass_export_some_errors was called + # within a section already calling bypass_export_some_errors or transformers + # has updated its code to do it. + # No need to register and unregister then. + unregistered_mamba_cache = False + else: + if verbose: + print("[_register_cache_serialization] register MambaCache") + torch.utils._pytree.register_pytree_node( + MambaCache, + flatten_mamba_cache, + unflatten_mamba_cache, + serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_mamba_cache, + ) + + # DynamicCache serialization is different in transformers and does not + # play way with torch.export.export. + # This is caused by this line: + # torch.fx._pytree.register_pytree_flatten_spec( + # DynamicCache, _flatten_dynamic_cache_for_fx) + # so we remove it anyway + if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(transformers.__version__) >= pv.Version("2.7"): + if verbose: + print("[_register_cache_serialization] DynamicCache is unregistered first.") + _unregister(DynamicCache) + + unregistered_dynamic_cache = True + if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: + if verbose > 1: + print(f"[_register_cache_serialization] {DynamicCache} already registered") + unregistered_dynamic_cache = False + else: + if verbose: + print("[_register_cache_serialization] register DynamicCache") + torch.utils._pytree.register_pytree_node( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_dynamic_cache, + ) + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, lambda x, _: [x.key_cache, x.value_cache]) + + # check + from .cache_helper import make_dynamic_cache + + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) + values, spec = torch.utils._pytree.tree_flatten(cache) + cache2 = torch.utils._pytree.tree_unflatten(values, spec) + # torch.fx._pytree.tree_flatten(cache) + assert len(cache2.key_cache) == 1 + + return dict(DynamicCache=unregistered_dynamic_cache, MambaCache=unregistered_mamba_cache) + + +def _unregister(cls: type, verbose: int = 0): + import optree + import torch + + # torch.fx._pytree._deregister_pytree_flatten_spec(cls) + if cls in torch.fx._pytree.SUPPORTED_NODES: + del torch.fx._pytree.SUPPORTED_NODES[cls] + if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH: + del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls] + if hasattr(torch.utils._pytree, "_deregister_pytree_node"): + # torch >= 2.7 + torch.utils._pytree._deregister_pytree_node(cls) + optree.unregister_pytree_node(cls, namespace="torch") + if cls in torch.utils._pytree.SUPPORTED_NODES: + import packaging.version as pv + + if 
pv.Version(torch.__version__) < pv.Version("2.7.0"): + del torch.utils._pytree.SUPPORTED_NODES[cls] + assert cls not in torch.utils._pytree.SUPPORTED_NODES, ( + f"{cls} was not successfully unregistered " + f"from torch.utils._pytree.SUPPORTED_NODES=" + f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}" + ) + if verbose: + print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") + + +def _unregister_cache_serialization(undo: dict[str, bool], verbose: int = 0): + if undo.get("MambaCache", False): + from transformers.cache_utils import MambaCache + + _unregister(MambaCache, verbose) + elif verbose > 1: + print("[_unregister_cache_serialization] skip unregister MambaCache") + + if undo.get("DynamicCache", False): + from transformers.cache_utils import DynamicCache + + _unregister(DynamicCache, verbose) + elif verbose > 1: + print("[_unregister_cache_serialization] skip unregister DynamicCache") + + +@contextlib.contextmanager +def register_additional_serialization_functions(patch_transformers: bool = False, verbose: int = 0) -> Callable: + """The necessary modifications to run the fx Graph.""" + fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) + done = _register_cache_serialization(verbose=verbose) + try: + yield fct_callable + finally: + _unregister_cache_serialization(done, verbose=verbose) + + +@contextlib.contextmanager +def bypass_export_some_errors( + patch_sympy: bool = True, + patch_torch: bool = True, + patch_transformers: bool = False, + catch_constraints: bool = True, + stop_if_static: bool = False, + verbose: int = 0, + patch: bool = True, +) -> Callable: + """ + Tries to bypass some situations :func:`torch.export.export` does not support. + + :param patch_sympy: fix missing method ``name`` for IntegerConstant + :param patch_torch: patches :epkg:`torch` with a supported implementation + :param patch_transformers: patches :epkg:`transformers` with a supported implementation + :param catch_constraints: catch constraints related to dynamic shapes, + as a result, some dynamic dimensions may turn into static ones, + the environment variable ``SKIP_SOLVE_CONSTRAINTS=0`` + can be set to stop at that stage. + :param stop_if_static: see example :ref:`l-plot-export-locale-issue`, + to stop the export as soon as an issue is detected with dynamic shapes + and show a stack trace indicating the exact location of the issue + :param patch: if False, disable all patches except the registration of + serialization functions + :param verbose: to show which patches are applied + + The list of available patches: + + * ``torch.jit.isinstance`` + * ``torch._dynamo.mark_static_address`` + * ``torch._subclasses.fake_impls.infer_size`` + * fix missing method ``name`` for ``sympy.S.IntegerConstant`` + * ``AttentionMaskConverter._make_causal_mask`` + * Serialization of ``MambaCache`` (in :epkg:`transformers`) + * Serialization of ``DynamicCache`` (in :epkg:`transformers`) + * reduce errors due to shape inference + * fixes some transformers classes + + Serialization issues happen when a module takes an input or returns an output + whose type :func:`torch.export.export` cannot serialize. + + Examples: + + :: + + with bypass_export_some_errors(patch_transformers=True) as modificator: + inputs = modificator(inputs) + onx = to_onnx(..., inputs, ...) + + :: + + with bypass_export_some_errors(patch_transformers=True) as modificator: + inputs = modificator(inputs) + onx = torch.onnx.export(..., inputs, ...)
+ + It can be used as well to fix the torch export: + + :: + + with bypass_export_some_errors(patch_transformers=True) as modificator: + inputs = modificator(inputs) + ep = torch.export.export(..., inputs, ...) + + When running the model through the exported program, only the + serialization functions need to be restored: + + :: + + with register_additional_serialization_functions() as modificator: + inputs = modificator(inputs) + ep = torch.export.export(..., inputs, ...) + + When exporting a model with a cache, the following error message + may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``. + It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`. + """ + if not patch: + fct_callable = lambda x: x # noqa: E731 + done = _register_cache_serialization(verbose=verbose) + try: + yield fct_callable + finally: + _unregister_cache_serialization(done, verbose=verbose) + else: + import torch + import torch._export.non_strict_utils # produce_guards_and_solve_constraints + import torch.jit + + if verbose: + print("[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address") + + ######## + # caches + ######## + + cache_done = _register_cache_serialization(verbose=verbose) + + ############# + # patch sympy + ############# + + if patch_sympy: + import sympy + + f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None) + + if verbose: + print("[bypass_export_some_errors] patch sympy") + + sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{self!s}" + + ############### + # patch pytorch + ############### + # the linter gets confused if not initialized + f_jit_isinstance = f_mark_static_address = f_infer_size = ShapeEnv = None + f__broadcast_shapes = f_shape_env__set_replacement = revert_patches_info = None + + if patch_torch: + from .patches.patch_torch import ( + _catch_produce_guards_and_solve_constraints, + patch__check_input_constraints_for_graph, + patched__broadcast_shapes, + patched_infer_size, + ) + + if verbose: + print("[bypass_export_some_errors] patch pytorch") + + # torch.jit.isinstance + f_jit_isinstance = torch.jit.isinstance + torch.jit.isinstance = isinstance + + # torch._dynamo.mark_static_address + f_mark_static_address = torch._dynamo.mark_static_address + torch._dynamo.mark_static_address = lambda *_, **y_: None + + # torch._subclasses.fake_impls.infer_size + f_infer_size = torch._subclasses.fake_impls.infer_size + torch._subclasses.fake_impls.infer_size = patched_infer_size + + # torch._refs._broadcast_shapes + f__broadcast_shapes = torch._refs._broadcast_shapes + torch._refs._broadcast_shapes = patched__broadcast_shapes + torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes + + # torch._export.non_strict_utils.produce_guards_and_solve_constraints + if catch_constraints: + if verbose: + print("[bypass_export_some_errors] modifies shape constraints") + f_produce_guards_and_solve_constraints = torch._export.non_strict_utils.produce_guards_and_solve_constraints + f__check_input_constraints_for_graph = torch._export.utils._check_input_constraints_for_graph + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints( + f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs + ) + ) + torch._export.utils._check_input_constraints_for_graph = ( + lambda *args, **kwargs: patch__check_input_constraints_for_graph( + f__check_input_constraints_for_graph, 
*args, verbose=verbose, **kwargs + ) + ) + + if stop_if_static: + if verbose: + print("[bypass_export_some_errors] assert when a dynamic dimension turns static") + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from .patches.patch_torch import patched_ShapeEnv + + f_shape_env__set_replacement = ShapeEnv._set_replacement + ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement + + #################### + # patch transformers + #################### + + if patch_transformers: + revert_patches_info = patch_module(patch_transformers_list, verbose=verbose) + + ######## + # export + ######## + + fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) + + if verbose: + print("[bypass_export_some_errors] done patching") + + try: + yield fct_callable + finally: + ####### + # sympy + ####### + + if verbose: + print("[bypass_export_some_errors] remove patches") + + if patch_sympy: + # tracked by https://github.com/pytorch/pytorch/issues/143494 + if f_sympy_name: + sympy.core.numbers.IntegerConstant.name = f_sympy_name + else: + delattr(sympy.core.numbers.IntegerConstant, "name") + + if verbose: + print("[bypass_export_some_errors] restored sympy functions") + + ####### + # torch + ####### + + if patch_torch: + # this should disappear when torch.jit is removed + torch.jit.isinstance = f_jit_isinstance + torch._dynamo.mark_static_address = f_mark_static_address + # tracked by https://github.com/pytorch/pytorch/issues/143495 + torch._subclasses.fake_impls.infer_size = f_infer_size + torch._refs._broadcast_shapes = f__broadcast_shapes + torch._meta_registrations._broadcast_shapes = f__broadcast_shapes + + if verbose: + print("[bypass_export_some_errors] restored pytorch functions") + + if stop_if_static: + if verbose: + print("[bypass_export_some_errors] restored ShapeEnv._set_replacement") + + ShapeEnv._set_replacement = f_shape_env__set_replacement + + if catch_constraints: + # to catch or skip dynamic_shapes issues + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + f_produce_guards_and_solve_constraints + ) + torch._export.utils._check_input_constraints_for_graph = f__check_input_constraints_for_graph + if verbose: + print("[bypass_export_some_errors] restored shape constraints") + + ############## + # transformers + ############## + + if patch_transformers: + unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose) + + ######## + # caches + ######## + + _unregister_cache_serialization(cache_done, verbose=verbose) + + +def replacement_before_exporting(args: Any) -> Any: + """ + Does replacements on the given inputs if needed. 
+ """ + if args is None: + return None + if isinstance(args, (int, float)): + return args + if isinstance(args, dict): + return {k: replacement_before_exporting(v) for k, v in args.items()} + if isinstance(args, tuple): + return tuple(replacement_before_exporting(v) for v in args) + if isinstance(args, list): + return [replacement_before_exporting(v) for v in args] + + return args diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py new file mode 100644 index 0000000000000..d109dd3059480 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py @@ -0,0 +1,132 @@ +from typing import Any + +import torch +import transformers + +############ +# MambaCache +############ + + +# self.conv_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.conv_kernel_size, +# device=device, +# dtype=dtype, +# ) +# self.ssm_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.ssm_state_size, +# device=device, +# dtype=dtype, +# ) +def flatten_mamba_cache( + mamba_cache: transformers.cache_utils.MambaCache, +) -> tuple[list[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + flat = [ + (k, getattr(mamba_cache, k)) + for k in [ + # "max_batch_size", # new in transformers==4.47 + # "intermediate_size", + # "ssm_state_size", + # "conv_kernel_size", + "conv_states", + "ssm_states", + ] + if hasattr(mamba_cache, k) + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: list[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + from transformers.cache_utils import MambaCache + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values, strict=False)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache( + d: dict[Any, Any], +) -> tuple[ + list[tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + import torch + + values, context = flatten_mamba_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context + + +############## +# DynamicCache +############## + + +def flatten_dynamic_cache( + dynamic_cache: transformers.cache_utils.DynamicCache, +) -> tuple[list[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + flat = [(k, 
getattr(dynamic_cache, k)) for k in ["key_cache", "value_cache"] if hasattr(dynamic_cache, k)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_dynamic_cache( + d: dict[Any, Any], +) -> tuple[ + list[tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + import torch + + values, context = flatten_dynamic_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context + + +def unflatten_dynamic_cache( + values: list[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.DynamicCache: + """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" + from transformers.cache_utils import DynamicCache + + cache = DynamicCache() + values = dict(zip(context, values, strict=False)) + for k, v in values.items(): + setattr(cache, k, v) + return cache diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py new file mode 100644 index 0000000000000..ded05b8c37be5 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -0,0 +1,166 @@ +import inspect +from typing import Any + +import torch +import transformers + +from . import string_type +from .cache_helper import make_dynamic_cache + + +def _process_cache(k: str, v): + assert k != "position_ids" or isinstance(k, torch.Tensor), ( + f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}" + ) + if isinstance(v, list) and all(isinstance(i, tuple) for i in v) and {len(t) for t in v} == {2}: + # A dynamicCache + cache = make_dynamic_cache(v) + return cache + if isinstance(v, torch.Tensor): + return v + raise NotImplementedError(f"Unable to process parameter {k!r} with v={string_type(v, with_shape=True)}") + + +def _make_shape(subset: dict, cls: type, value: Any) -> Any: + if cls is transformers.cache_utils.DynamicCache: + assert subset, "DynamicCache cannot be empty" + values = set(map(str, subset.values())) + assert len(values) == 1, ( + f"Inconsistencies in subset={subset}, found={values}, it cannot be a {cls}, value={string_type(value)}" + ) + cache_length = len(value.key_cache) + for v in subset.values(): + axes = v + break + new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]] + return new_shape + raise NotImplementedError(f"_make_shape not implemented for cls={cls}, subset={subset}, value={string_type(value)}") + + +def convert_dynamic_axes_into_dynamic_shapes( + model: torch.nn.Module, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + dynamic_axes: dict[str, dict[int, str]] | None = None, + prefix_mapping: dict[str, str] | None = None, + verbose: int = 0, +) -> tuple[tuple[Any, ...], dict[str, Any], dict[str, Any]]: + """ + Converts the input from an export to something :func:`torch.export.export` can handle. 
+ + :param model: model to convert (used to extract the signature) + :param args: positional arguments + :param kwargs: named arguments + :param dynamic_axes: dynamic axes + :param prefix_mapping: prefix mapping + :param verbose: verbosity + :return: (args, kwargs, dynamic shapes) + """ + new_kwargs = {} + if args: + assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}" + plus = 0 if isinstance(model, torch.nn.Module) else 1 + print( + f"[convert_dynamic_axes_into_dynamic_shapes] " + f"mapping args to kwargs for model=" + f"{model if plus else model.__class__.__name__}" + ) + pars = inspect.signature(model.forward).parameters + assert len(pars) >= len(args), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}" + + for i, p in enumerate(pars): + if i < plus: + continue + if i - plus >= len(args): + break + if verbose: + print( + f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i - plus}] " + f"to {p!r} ({string_type(args[i - plus])})" + ) + new_kwargs[p] = args[i - plus] + + if kwargs: + for k, v in kwargs.items(): + assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args." + new_kwargs[k] = v + + # process + updated_kwargs = {} + changes = {} + for k, v in new_kwargs.items(): + if isinstance(v, torch.Tensor): + updated_kwargs[k] = v + continue + if isinstance(v, list): + # cache? + updated_kwargs[k] = _process_cache(k, v) + if type(updated_kwargs[k]) is not type(v): + # A cache was introduced. + if verbose: + print( + f"[convert_dynamic_axes_into_dynamic_shapes] parameter " + f"{k!r} was changed into {type(updated_kwargs[k])}" + ) + changes[k] = type(updated_kwargs[k]) + continue + raise NotImplementedError(f"Unexpected type {type(v)} for parameter {k!r} ({string_type(v, with_shape=True)})") + + # process dynamic axes + if changes: + dynamic_shapes = {} + done = set() + for k, v in dynamic_axes.items(): + if k not in changes and k in updated_kwargs and isinstance(v, dict): + dynamic_shapes[k] = v + continue + if "." in k: + # something like present.0.key + prefix = k.split(".")[0] + if prefix in done: + continue + args_prefix = prefix_mapping[prefix] if prefix_mapping and prefix in prefix_mapping else prefix + if args_prefix in updated_kwargs and args_prefix in changes: + # A cache. + cls = changes[args_prefix] + dynamic_shapes[args_prefix] = _make_shape( + {_: __ for _, __ in dynamic_axes.items() if _.startswith(f"{prefix}.")}, + cls, + updated_kwargs[args_prefix], + ) + done.add(prefix) + continue + if k not in updated_kwargs: + # dynamic axes not in the given inputs, should be raise an exception? 
+ if verbose: + print( + f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes " + f"{k!r}-{v!r}, not found in {set(updated_kwargs)}" + ) + continue + raise NotImplementedError( + f"Unable to process dynamic axes {k!r}, axes={v}, " + f"value={string_type(updated_kwargs[k], with_shape=True)}, " + f"dynamic axes={dynamic_axes}, " + f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}" + ) + + return (), updated_kwargs, dynamic_shapes + + +def replace_dynamic_shapes(ds, mapping, default_value): + if isinstance(ds, dict) and all(isinstance(k, int) for k in ds): + new_ds = {} + for k, v in ds.items(): + if isinstance(v, str): + new_ds[k] = mapping.get(v, default_value) + else: + new_ds[k] = v + return new_ds + if isinstance(ds, tuple): + return tuple(replace_dynamic_shapes(d, mapping, default_value) for d in ds) + if isinstance(ds, list): + return [replace_dynamic_shapes(d, mapping, default_value) for d in ds] + if isinstance(ds, dict): + return {k: replace_dynamic_shapes(v, mapping, default_value) for k, v in ds.items()} + return ds diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py new file mode 100644 index 0000000000000..c30a58f4290f1 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py @@ -0,0 +1,323 @@ +import inspect +import os +from collections.abc import Callable, Sequence +from typing import Any + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + + +def _catch_produce_guards_and_solve_constraints( + previous_function: Callable, + fake_mode: "FakeTensorMode", + gm: "torch.fx.GraphModule", + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, + equalities_inputs: "EqualityConstraint", # noqa: F821 + original_signature: inspect.Signature, + _is_torch_jit_trace: bool = False, + verbose: int = 0, +): + try: + return previous_function( + fake_mode=fake_mode, + gm=gm, + dynamic_shapes=dynamic_shapes, + equalities_inputs=equalities_inputs, + original_signature=original_signature, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + except Exception as e: + if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")): + raise + if verbose: + print( + f"[_catch_produce_guards_and_solve_constraints] ERROR" + f"produce_guards_and_solve_constraints failed, " + f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n" + f"fake_mode={fake_mode}\n" + f"dynamic_shapes={dynamic_shapes}\n" + f"equalities_inputs={equalities_inputs}\n" + f"original_signature={original_signature}\n" + f"_is_torch_jit_trace={_is_torch_jit_trace}\n" + f"exc={e}\ngm={gm}" + ) + + +def patch__check_input_constraints_for_graph( + previous_function: Callable, + input_placeholders: list[torch.fx.Node], + flat_args_with_path, + range_constraints, + verbose: int = 0, +) -> None: + try: + return previous_function(input_placeholders, flat_args_with_path, range_constraints) + except Exception as e: + if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")): + raise + if verbose: + print( + f"[_check_input_constraints_for_graph] ERROR" + f"_check_input_constraints_for_graph failed, " + f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n" + 
f"input_placeholders={input_placeholders}\n" + f"range_constraints={range_constraints}\n" + f"exc={e}" + ) + + +def patched_infer_size(a, b): + """Patches ``torch._subclasses.fake_impls.infer_size``.""" + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + # NB: It is very important to test for broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + try: + b1 = guard_size_oblivious(sizeA == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b1 = False + try: + b2 = guard_size_oblivious(sizeB == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b2 = False + try: + b3 = guard_size_oblivious(sizeA == sizeB) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b3 = False + if b1 or b2 or b3: + expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA + else: + # In this case, the current implementation of torch fails (17/12/2024). + # Try model SmolLM. + expandedSizes[i] = torch.sym_max(sizeA, sizeB) + return tuple(expandedSizes) + + +def patched__broadcast_shapes(*_shapes): + """Patches ``torch._refs._broadcast_shapes``.""" + from functools import reduce + + from torch._prims_common import IntLike + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple((x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape = [ # list[Union[int, torch.SymInt]] + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for _arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError("Attempting to broadcast a dimension with negative length!") + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) + + return common_shape + + +class patched_ShapeEnv: + def _set_replacement( + self, + a: "sympy.Symbol", # noqa: F821 + tgt: "sympy.Expr", # noqa: F821 + msg: str, + ) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. 
+ """ + if tgt == self.replacements.get(a, None): + return + + if a in tgt.free_symbols: + return + + import sympy + from torch._guards import TracingContext + from torch._logging import structured, trace_structured + from torch.fx.experimental.symbolic_shapes import ( + ValueRanges, + _is_supported_equivalence, + ) + from torch.utils._sympy.functions import CeilToInt, FloorToInt + from torch.utils._sympy.solve import try_solve + from torch.utils._traceback import CapturedTraceback + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): + # continuing leads to placeholder shapes + # having complex expressions that we can't resolve + return + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug( + "set_replacement: solve for %s in %s == %s gives %s", + b, + a, + tgt, + r, + ) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) + self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset(src_bound), f"{tgt_bound=} not a subset of {src_bound=}" + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. 
+ # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s]", + a, + tgt, + msg, + tgt_bound, + src_bound, + ) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s (size-oblivious conditions)]", + a, + tgt, + msg, + tgt_bound_so, + src_bound_so, + ) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": (structured.from_traceback(user_tb) if user_tb else None), + }, + ) + + # if config.print_specializations: + # self.log.warning( + # "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt + # ) + # self.log.debug("SPECIALIZATION", stack_info=True) + assert msg != "range_refined_to_singleton", ( + f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" + ) + # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + # NB: the replacement may get refined, but the user will find the + # FIRST one most useful (TODO: Maybe we could consider tracking all of + # them) + if a not in self.replacements_slocs: + self.replacements_slocs[a] = self._get_sloc() + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. 
+ self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py new file mode 100644 index 0000000000000..828a883b7ab12 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -0,0 +1,474 @@ +import inspect +import sys +from dataclasses import dataclass +from typing import Any + +import torch +import transformers +import transformers.modeling_attn_mask_utils +from transformers.cache_utils import Cache, DynamicCache, StaticCache + + +def _patch_make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: int | None = None, +): + """Patched method.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), + mask, + ], + dim=-1, + ) + + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # In this case, the current implementation of torch fails (17/12/2024). + # Try model Phi-3.5-Mini-Instruct. + mask = mask.masked_fill(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +if sys.version_info[:2] <= (3, 11): + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + _PATCHES_ = ["_make_causal_mask"] + _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: int | None = None, + ): + """Patched method.""" + return _patch_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length, sliding_window) + +else: + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + _PATCHES_ = ["_make_causal_mask"] + _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + + @staticmethod + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: int | None = None, + ): + """Patched method.""" + return _patch_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length, sliding_window) + + +class patched_DynamicCache: + """ + Applies modifications implemented in PR + `transformers/#36652 `_. + """ + + _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"] + _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. 
+ A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or self.key_cache[layer_idx].numel() == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` + and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif not self.key_cache[layer_idx].numel(): # prefers not t.numel() to len(t) == 0 to export the model + # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` + in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. + This is used in assisted decoding and contrastive search. + """ + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + @classmethod + def from_batch_splits(cls, splits: list[DynamicCache]) -> DynamicCache: + """This is the opposite of the above `batch_split()` method. 
+ This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + +class patched_GenerationMixin: + """ + Applies modifications implemented in PR + `transformers/#36652 `_. + """ + + _PATCHES_ = [ + "_cache_dependant_input_preparation", + "_cache_dependant_input_preparation_exporting", + "prepare_inputs_for_generation", + ] + _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin + + def _cache_dependant_input_preparation( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor | None, + cache_position: torch.LongTensor | None, + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + """ + Generic cache-dependent input preparation + The code is put in a separate function to allow granular unit testing + as it needs a different implementation to be exportable. + + If we have cache: let's slice `input_ids` through `cache_position`, + to keep only the unprocessed tokens + - Exception 1: when passing input_embeds, + input_ids may be missing entries + - Exception 2: some generation methods do special slicing of input_ids, + so we don't need to do it here + - Exception 3: with synced GPUs cache_position may go out of bounds, + but we only want dummy token in that case. + - Exception 4: If input_embeds are passed then slice it through + `cache_position`, to keep only the unprocessed tokens and + generate the first token for each sequence. + Later use the generated Input ids for continuation. + + The current implementation does not rely on ``self`` and could be + a class method. It is left as a standard method to be easily rewritten. + Original code: + + .. code-block:: python + + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1 + cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + return inputs_embeds, input_ids + """ + return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position) + + def _cache_dependant_input_preparation_exporting( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor | None, + cache_position: torch.LongTensor | None, + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + """ + This method implements method ``_cache_dependant_input_preparation`` + with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. + The code is put in a separate function to allow granular unit testing. + """ + if inputs_embeds is None: + input_ids = input_ids[:, cache_position] + else: + # This is the code we need to implemented with torch.cond. 
+ # if input_ids.shape[1] == 0: + # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + # else: + # if cache_position[-1] >= input_ids.shape[1]: + # input_ids = input_ids[:, -cache_position.shape[0] :] + # else: + # if input_ids.shape[1] != cache_position.shape[0]: + # input_ids = input_ids[:, cache_position] + def branch_1(inputs_embeds, cache_position): + return inputs_embeds[:, -cache_position.shape[0] :] + + def branch_2(input_ids, cache_position): + return input_ids[:, -cache_position.shape[0] :] + + def branch_3(input_ids, cache_position): + return input_ids[:, cache_position] + + inputs_embeds, input_ids = torch.cond( + input_ids.shape[1] == 0, + ( + lambda input_ids, inputs_embeds, cache_position: ( + branch_1(inputs_embeds, cache_position), + input_ids, + ) + ), + ( + lambda input_ids, inputs_embeds, cache_position: ( + inputs_embeds, + torch.cond( + cache_position[-1] >= input_ids.shape[1], + branch_2, + lambda input_ids, cache_position: ( + torch.cond( + input_ids.shape[1] != cache_position.shape[0], + branch_3, + (lambda input_ids, cache_position: input_ids), + [input_ids, cache_position], + ) + ), + [input_ids, cache_position], + ), + ) + ), + [input_ids, inputs_embeds, cache_position], + ) + return inputs_embeds, input_ids + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Cache | None = None, + attention_mask: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ): + """ + Prepare the model inputs for generation. + In includes operations like computing the 4D attention mask or + slicing inputs given the existing cache. + + See the forward pass in the model documentation + for expected arguments (different models might have different + requirements for e.g. `past_key_values`). + This function should work as is for most LLMs. + """ + + # 1. Handle BC: + model_inputs = {} + # - some models don't have `Cache` support + # (which implies they don't expect `cache_position` in `forward`) + if self._supports_cache_class: + model_inputs["cache_position"] = cache_position + # - `cache_position` was not a mandatory input in + # `prepare_inputs_for_generation` for those models, and this + # function may be called outside of `generate`. + # Handle most use cases by creating `cache_position` on the fly + # (this alternative is not as robust as calling + # `generate` and letting it create `cache_position`) + elif cache_position is None: + past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + + # 2. Generic cache-dependent input preparation + if past_key_values is not None: + model_inputs["past_key_values"] = past_key_values + inputs_embeds, input_ids = self._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) + + # 3. Prepare base model inputs + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + # if `inputs_embeds` are passed, we only want + # to use them in the 1st generation step for every prompt. + if not self.config.is_encoder_decoder: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs[input_ids_key] = None + model_inputs["inputs_embeds"] = inputs_embeds + else: + # `clone` calls in this function ensure a consistent stride. 
See #32227 + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + model_inputs["inputs_embeds"] = None + else: + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + + # 4. Create missing `position_ids` on the fly + encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None + attention_mask = ( + kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask + ) + attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" + position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" + if ( + attention_mask is not None + and kwargs.get(position_ids_key) is None + and position_ids_key in set(inspect.signature(self.forward).parameters.keys()) + ): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below) + + # 5. Slice model inputs if it's an input + # that should have the same length as `input_ids` + for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]: + model_input = kwargs.get(model_input_name) + if model_input is not None: + if past_key_values is not None: + current_input_length = ( + model_inputs["inputs_embeds"].shape[1] + if model_inputs.get("inputs_embeds") is not None + else model_inputs[input_ids_key].shape[1] + ) + model_input = model_input[:, -current_input_length:] + model_input = model_input.clone(memory_format=torch.contiguous_format) + model_inputs[model_input_name] = model_input + + # 6. Create 4D attention mask is we are using a + # `StaticCache` (important for performant compiled forward pass) + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs[input_ids_key].shape + device = model_inputs[input_ids_key].device + + # Create the causal mask with fixed shape in advance, + # to reduce recompilations. If the function to create + # the 4D causal mask exists, + # it should be present in the base model (XXXModel class). + base_model = getattr(self, self.base_model_prefix, None) + if base_model is None: + causal_mask_creation_function = getattr( + self, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + else: + causal_mask_creation_function = getattr( + base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + if causal_mask_creation_function is None: + pass + # logger.warning_once( + # f"{self.__class__.__name__} has no " + # "`_prepare_4d_causal_attention_mask_with_cache_position` method " + # "defined in its base modeling class. " + # "Compiled forward passes will be sub-optimal. If you're " + # "writing code, see Llama for an example implementation. " + # "If you're a user, please report this " + # "issue on GitHub." 
+ # ) + else: + attention_mask = causal_mask_creation_function( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + if attention_mask is not None: + model_inputs[attention_mask_key] = attention_mask + + if encoder_attention_mask is not None: + model_inputs["attention_mask"] = encoder_attention_mask + + # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + model_inputs.pop("labels", None) + return model_inputs diff --git a/pyproject.toml b/pyproject.toml index 09a203772aaf9..eed772f6e3cfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ ignore = [ "tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix "onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_triton.py" = ["N806"] # use of Q, K and V in triton script "onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_triton.py" = ["N806"] # use of Q, K and V in triton script +"onnxruntime/python/tools/transformers/models/torch_export_patches/*" = ["F401", "PLW0211", "N801", "N806", "RUF012"] # patches are based on pytorch code "onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix "onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix "orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo. 
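The patch classes defined in patch_transformers.py above all follow one convention: `_PATCHED_CLASS_` names the transformers class to modify and `_PATCHES_` lists the methods to swap in. The helper that actually consumes these attributes lives elsewhere in models/torch_export_patches and is not shown in this hunk; the sketch below only illustrates the convention. The context-manager name and its behaviour are assumptions, not the PR's implementation.

import contextlib

@contextlib.contextmanager
def _apply_patch_class(patch_cls):
    # Illustrative only: install the attributes listed in _PATCHES_ on the
    # class named by _PATCHED_CLASS_, then restore the originals on exit.
    # A real implementation also has to preserve staticmethod/classmethod
    # descriptors and handle patched module-level functions.
    target = patch_cls._PATCHED_CLASS_
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    try:
        for name in patch_cls._PATCHES_:
            setattr(target, name, vars(patch_cls)[name])
        yield target
    finally:
        for name, original in originals.items():
            setattr(target, name, original)

# Hypothetical usage: export with the DynamicCache fixes active.
# with _apply_patch_class(patched_DynamicCache):
#     exported = torch.export.export(model, (), kwargs=model_kwargs, strict=False)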
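`_cache_dependant_input_preparation_exporting` above replaces data-dependent Python branching with nested `torch.cond` calls, because `torch.export` cannot trace an `if` whose condition depends on tensor values. A minimal, self-contained sketch of the same pattern follows; the module name, shapes, and values are made up for illustration, and a recent PyTorch with `torch.cond` is assumed.

import torch

class TailOrGather(torch.nn.Module):
    def forward(self, input_ids, cache_position):
        # Both branches take the same operands (passed explicitly as a list)
        # and must return tensors with the same structure, which is why the
        # patched method above wraps every branch in a lambda returning the
        # full (inputs_embeds, input_ids) tuple.
        return torch.cond(
            cache_position[-1] >= input_ids.shape[1],   # tensor-valued predicate
            lambda ids, pos: ids[:, -pos.shape[0]:],    # keep only the last tokens
            lambda ids, pos: ids[:, pos],               # gather at cache positions
            [input_ids, cache_position],
        )

ids = torch.arange(12).reshape(2, 6)
pos = torch.tensor([4, 5])
ep = torch.export.export(TailOrGather(), (ids, pos))   # both branches end up in the graph
print(ep)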
diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index f4658f3a22c33..122c7651907b0 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -332,7 +332,7 @@ stages: python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ - python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --small_gp;\ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --small_gp --use_dynamo_export;\ ls -l llama2-7b-fp16; \ du -sh llama2-7b-fp16; \ popd ; \ @@ -353,7 +353,7 @@ stages: python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ - python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda;\ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda --use_dynamo_export;\ ls -l llama2-7b-fp32-gpu; \ du -sh llama2-7b-fp32-gpu; \ popd ; \ @@ -374,7 +374,7 @@ stages: python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ - python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --use_gqa;\ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --use_gqa --use_dynamo_export;\ ls -l llama2-7b-int4-gpu; \ du -sh llama2-7b-int4-gpu; \ popd ; \
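Once one of the CI conversions above has produced a model, a quick way to confirm that the dynamic dimensions survived the dynamo export is to open the model with onnxruntime and print its I/O signature. This is only an illustrative check, not part of the pipeline; the model path is a placeholder.

import onnxruntime as ort

sess = ort.InferenceSession("path/to/exported_model.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input ", inp.name, inp.shape, inp.type)   # dynamic axes show up as symbolic names
for out in sess.get_outputs():
    print("output", out.name, out.shape, out.type)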