[WIP] Support export of Llama with DynamicCache and transformers>=4.48 #24291

Closed · wants to merge 27 commits · Changes from all commits
14 changes: 14 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -471,6 +471,12 @@ file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py"
)
file(GLOB onnxruntime_python_transformers_models_torch_export_patches_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/*.py"
)
file(GLOB onnxruntime_python_transformers_models_torch_export_patches_patches_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/patches/*.py"
)
file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py"
)
@@ -566,6 +572,8 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/torch_export_patches
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/torch_export_patches/patches
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
@@ -682,6 +690,12 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_t5_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_torch_export_patches_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/torch_export_patches/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_torch_export_patches_patches_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/torch_export_patches/patches/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_whisper_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper/
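
Note: these cmake additions mirror the existing per-model copy steps so that the new torch_export_patches package ships with the Python wheel. A quick smoke test of a local build could look like the sketch below; the import path onnxruntime.transformers.models.torch_export_patches is an assumption inferred from the copy destinations above, not something stated in this PR.

# Hypothetical smoke test for a locally built wheel; the import path below is
# inferred from the cmake copy destinations and may differ in the final package.
from onnxruntime.transformers.models.torch_export_patches import bypass_export_some_errors

with bypass_export_some_errors(patch_transformers=True):
    pass  # transformers patches are active only inside this context
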
onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -22,6 +22,10 @@
from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
from llama_parity import main as parity_check
from llama_torch import setup_torch_model

# Patch transformers before exporting (needed for transformers >= 4.45)
from models.torch_export_patches import bypass_export_some_errors
from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes, replace_dynamic_shapes
from onnx_model import OnnxModel
from optimizer import optimize_model
from packaging import version
@@ -122,7 +126,7 @@ def run_dynamo_export(
config.capture_scalar_outputs = True

# Dummy values for export
batch_size, sequence_length, past_sequence_length = 2, 8, 0
batch_size, sequence_length, past_sequence_length = 2, 8, 3
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

temp_name = args.model_name.lower().replace("-", "").replace("_", "")
@@ -141,9 +145,76 @@
)
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)

input_names = ["input_ids", "attention_mask", "position_ids"]
output_names = [
"logits",
*list(
chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
),
]
dynamic_axes = get_model_dynamic_axes(input_names, output_names)

model_args = (input_ids, attn_mask, pos_ids, past_kv)
model_args, model_kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes(
llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"}
)

if version.Version(torch.__version__) < version.Version("2.7"):
# This section is only needed for torch==2.6. The workarounds implemented here
# are not necessary with torch>=2.7:
# - strings are not allowed as dynamic shape names with torch 2.6, so we replace them with Dim.DYNAMIC
# - TypePromotion was fixed in torch==2.7
from onnxscript import opset18 as op

dynamic_shapes = replace_dynamic_shapes(
dynamic_shapes,
dict(batch_size=torch.export.Dim("batch_size")),
default_value=torch.export.Dim.DYNAMIC,
)

# TypePromotion cannot fix a type issue after the conversion.
# We insert an additional CastLike when the exporter converts aten.ge.Scalar
# (see custom_translation_table below).
def custom_aten_ge(self, other):
if isinstance(other, (int, float)):
return op.GreaterOrEqual(self, op.CastLike(other, self))
return op.GreaterOrEqual(self, other)

with bypass_export_some_errors(patch_transformers=True):
# ONNX pass TypePromotion crashes for torch 2.6.
# It can be bypassed by exporting first into an exported program.
# We then need to apply run_decompositions() before onnx conversion starts.
ep = torch.export.export(
llama,
(),
kwargs=model_kwargs,
dynamic_shapes=dynamic_shapes,
strict=False,
)
ep = ep.run_decompositions()
torch.onnx.export(
ep,
(),
temp_path,
kwargs=model_kwargs,
dynamic_shapes=dynamic_shapes,
dynamo=True,
verbose=args.verbose,
optimize=True,
custom_translation_table={torch.ops.aten.ge.Scalar: custom_aten_ge},
)
else:
with bypass_export_some_errors(patch_transformers=True):
torch.onnx.export(
llama,
(),
temp_path,
kwargs=model_kwargs,
dynamic_shapes=dynamic_shapes,
dynamo=True,
verbose=args.verbose,
optimize=True,
)

# Check decoder_with_past_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
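
Note: the conversion from the legacy dynamic_axes dictionary to torch.export dynamic_shapes is the key step in the new export path above. The sketch below only illustrates the intended mapping; it is not the implementation of convert_dynamic_axes_into_dynamic_shapes or replace_dynamic_shapes. With torch 2.6, string axis names are rejected, so everything except the shared batch dimension is mapped to torch.export.Dim.DYNAMIC, matching the code above.

# Illustrative sketch only; assumes torch >= 2.6.
import torch

legacy_dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
}

batch = torch.export.Dim("batch_size")
dynamic_shapes = {
    name: {
        axis: batch if label == "batch_size" else torch.export.Dim.DYNAMIC
        for axis, label in axes.items()
    }
    for name, axes in legacy_dynamic_axes.items()
}
# dynamic_shapes can then be passed to torch.export.export(model, args, kwargs=...,
# dynamic_shapes=dynamic_shapes) or torch.onnx.export(..., dynamo=True, ...).
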
@@ -330,6 +401,7 @@ def run_torchscript_merged_export(
temp_dir = f"./temp_{rank}"
_prepare_dir(temp_dir)
temp_path = os.path.join(temp_dir, "temp.onnx")

torch.onnx.export(
llama,
args=decoder_merged_inputs,
@@ -341,6 +413,7 @@
opset_version=torch_export_onnx_opset_version,
do_constant_folding=True,
verbose=args.verbose,
dynamo=False,
)

# Check decoder_merged_model.onnx and save all external data to one file
@@ -862,9 +935,6 @@ def main():
decoder_merged_model_fp32_opt_path,
]

if args.use_dynamo_export:
continue

# Run the optimizer script.
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths, strict=False):
@@ -970,9 +1040,6 @@ def main():
remove_existing_model(fp_path)
barrier()

if args.use_dynamo_export:
return

logger.info("Verifying parity on all ONNX models created")

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
import transformers

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'.
Module 'transformers' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.python.transformers' is imported with both 'import' and 'import from'.

Copilot Autofix suggestion:
To fix the problem, we should remove the from transformers import AutoConfig, AutoTokenizer statement and access these components directly from the transformers module. This will make the code more consistent and easier to understand.

  • Remove the from transformers import AutoConfig, AutoTokenizer statement.
  • Update the references to AutoConfig and AutoTokenizer to use transformers.AutoConfig and transformers.AutoTokenizer, respectively.
Suggested changeset: onnxruntime/python/tools/transformers/models/llama/llama_inputs.py. Run the following command in your local git repository to apply this patch:

cat << 'EOF' | git apply
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -10,3 +10,2 @@
 import transformers
-from transformers import AutoConfig, AutoTokenizer
 
@@ -32,3 +31,3 @@
 def get_sample_inputs(
-    config: AutoConfig,
+    config: transformers.AutoConfig,
     device: torch.device,
@@ -67,3 +66,3 @@
 def get_sample_with_past_kv_inputs(
-    config: AutoConfig,
+    config: transformers.AutoConfig,
     device: torch.device,
EOF
from transformers import AutoConfig, AutoTokenizer

from onnxruntime import InferenceSession, OrtValue
@@ -240,8 +241,12 @@
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
if isinstance(past_key_values, transformers.cache_utils.DynamicCache):
past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
else:
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
return past_kv
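
Note: the new key names appear to correspond to the flattened inputs the dynamo exporter produces when a DynamicCache is passed through torch.export, while the past_key_values.{i}.key/value names are kept for the TorchScript export path. A hedged illustration of the two naming schemes; it assumes flatten_past_kv_inputs from this file is in scope and a transformers version whose DynamicCache supports update() and iteration.

# Hedged illustration; run inside llama_inputs.py's module context.
import torch
import transformers

cache = transformers.cache_utils.DynamicCache()
for layer in range(2):
    cache.update(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7), layer)
print(sorted(flatten_past_kv_inputs(cache)))
# expected: ['past_key_values_key_cache_0', 'past_key_values_key_cache_1',
#            'past_key_values_value_cache_0', 'past_key_values_value_cache_1']

legacy = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))]
print(sorted(flatten_past_kv_inputs(legacy)))
# expected: ['past_key_values.0.key', 'past_key_values.0.value']
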


onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -11,7 +11,9 @@
import time

import numpy as np
import packaging.version as pv
import torch
import transformers

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'.
Module 'transformers' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.python.transformers' is imported with both 'import' and 'import from'.

Copilot Autofix suggestion:
To fix the problem, we should remove the from transformers import AutoConfig statement and access AutoConfig through the transformers module instead. This will ensure that the module is only imported once and will make the code more consistent and easier to understand.

  • Remove the from transformers import AutoConfig statement.
  • Replace all instances of AutoConfig with transformers.AutoConfig.
Suggested changeset: onnxruntime/python/tools/transformers/models/llama/llama_parity.py. Run the following command in your local git repository to apply this patch:

cat << 'EOF' | git apply
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -28,3 +28,2 @@
 from models.torch_export_patches.cache_helper import make_dynamic_cache
-from transformers import AutoConfig
 
@@ -35,3 +34,3 @@
 
-def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
+def get_sequence_lengths(args: argparse.Namespace, config: transformers.AutoConfig):
     past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
@@ -41,3 +40,3 @@
 
-def get_inputs(args: argparse.Namespace, config: AutoConfig):
+def get_inputs(args: argparse.Namespace, config: transformers.AutoConfig):
     # Dummy values for parity
@@ -104,3 +103,3 @@
     pytorch_model: None | torch.nn.Module = None,
-    config: None | AutoConfig = None,
+    config: None | transformers.AutoConfig = None,
 ):
EOF
from benchmark_helper import setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
@@ -23,6 +25,7 @@
verify_ort_inputs,
)
from llama_torch import setup_torch_model
from models.torch_export_patches.cache_helper import make_dynamic_cache
from transformers import AutoConfig

import onnxruntime as ort
@@ -71,6 +74,28 @@
return inputs


def torch_deepcopy(value):
if isinstance(value, (int, float, str)):
return value
if isinstance(value, tuple):
return tuple(torch_deepcopy(v) for v in value)
if isinstance(value, list):
return [torch_deepcopy(v) for v in value]
if isinstance(value, set):
return {torch_deepcopy(v) for v in value}
if isinstance(value, dict):
return {k: torch_deepcopy(v) for k, v in value.items()}
if isinstance(value, np.ndarray):
return value.copy()
if hasattr(value, "clone"):
return value.clone()
if isinstance(value, transformers.cache_utils.DynamicCache):
return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
# A more generic implementation would rely on the serialization/deserialization
# functions registered for the export, assuming a model cannot be exported without them.
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
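
Note: torch_deepcopy exists because a forward pass with use_cache=True grows the DynamicCache in place, which would otherwise corrupt the inputs reused for the ONNX Runtime comparison. A hedged sketch of that behaviour, assuming make_dynamic_cache imported above works for the pinned transformers version:

# Hedged sketch; the expected shapes are indicative, not asserted by this PR.
import torch

cache = make_dynamic_cache([(torch.randn(2, 4, 8, 7), torch.randn(2, 4, 8, 7))])
copied = torch_deepcopy(cache)

# Simulate what an attention layer does during a forward pass with use_cache=True.
cache.update(torch.randn(2, 4, 1, 7), torch.randn(2, 4, 1, 7), 0)

print(cache.key_cache[0].shape)   # expected: torch.Size([2, 4, 9, 7]) (mutated in place)
print(copied.key_cache[0].shape)  # expected: torch.Size([2, 4, 8, 7]) (independent copy)
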


def verify_parity(
args: argparse.Namespace,
location: str,
@@ -92,11 +117,18 @@

inputs = get_inputs(args, config)

if "past_key_values" in inputs and pv.Version(transformers.__version__) >= pv.Version("4.45"):
# Using DynamicCache
inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])

# Run inference with PyTorch
if args.execution_provider != "cpu":
torch.cuda.synchronize()
start_time = time.time()
pt_outputs = py_model(**inputs).logits.detach().cpu().numpy()
# If there is a cache in the inputs, we need to make a copy because the model modifies it in place.
# DynamicCache inherits from torch.nn.Module in some versions of transformers,
# so we make the copy manually (see torch_deepcopy above).
pt_outputs = py_model(**torch_deepcopy(inputs)).logits.detach().cpu().numpy()
if args.execution_provider != "cpu":
torch.cuda.synchronize()
end_time = time.time()
onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,5 +1,7 @@
onnxscript>=0.2.3
optree
optimum>=1.14.1
transformers>=4.33.2,<= 4.38.0
transformers==4.48.0
torch>=2.2.0
onnx==1.17.0
datasets>=2.8.0
onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py (new file)
@@ -0,0 +1,112 @@
from typing import Any

import packaging.version as pv
import torch
import transformers

from .onnx_export_errors import (
bypass_export_some_errors,
register_additional_serialization_functions,
)


def is_torchdynamo_exporting() -> bool:
"Tells if torch is exporting a model."
import torch

if not hasattr(torch.compiler, "is_exporting"):
# torch.compiler.is_exporting requires torch>=2.7
return False

try:
return torch.compiler.is_exporting()
except Exception:
try:
import torch._dynamo as dynamo

return dynamo.is_exporting() # type: ignore
except Exception:
return False


def string_type(anything, **args):
# The full implementation (pretty-printing nested structures with shapes) is too long to include here; fall back to str.
return str(anything)


if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

def make_dynamic_cache(
key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
"""
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
This version is valid for ``transformers >= 4.50``.

:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.DynamicCache`

Example:

::

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
]
)
print(string_type(past_key_values, with_shape=True))
"""
return transformers.cache_utils.DynamicCache(key_value_pairs)

else:

def make_dynamic_cache(
key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
"""
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
This version is valid for ``transformers < 4.50``.

:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.DynamicCache`

Example:

::

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
]
)
print(string_type(past_key_values, with_shape=True))
"""
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
for i, (key, value) in enumerate(key_value_pairs):
cache.update(key, value, i)
return cache


def make_encoder_decoder_cache(
self_attention_cache: transformers.cache_utils.DynamicCache,
cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
"Creates an EncoderDecoderCache."
return transformers.cache_utils.EncoderDecoderCache(
self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
)
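
Note: a hedged usage sketch for the helpers above, combining two DynamicCache instances into an EncoderDecoderCache and reading the per-layer pairs back; the shapes and layer count are arbitrary.

import torch

n_layers, bsize, nheads, slen, dim = 2, 2, 4, 3, 7
pairs = [
    (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
    for _ in range(n_layers)
]

self_cache = make_dynamic_cache(pairs)
cross_cache = make_dynamic_cache(pairs)
enc_dec = make_encoder_decoder_cache(self_cache, cross_cache)

# Recover the per-layer (key, value) pairs, e.g. to build a legacy ONNX input feed.
recovered = list(zip(self_cache.key_cache, self_cache.value_cache, strict=False))
assert all(k.shape == (bsize, nheads, slen, dim) for k, _ in recovered)
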