Update whisper transformer module to 4.48.0 #24382

Open · wants to merge 71 commits into base: main
Commits (71)
5453405
first draft to migrate to newer version of transformers
xadupre Mar 28, 2025
31e82a9
add patches
xadupre Mar 28, 2025
299f116
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Mar 31, 2025
cdec2d0
fix import
xadupre Mar 31, 2025
827d3bd
fix build and import
xadupre Mar 31, 2025
18b649e
build
xadupre Mar 31, 2025
0e77ed4
fix lint
xadupre Mar 31, 2025
4633a3e
lint
xadupre Mar 31, 2025
b12287a
lint
xadupre Mar 31, 2025
1b926cb
rename
xadupre Mar 31, 2025
6646e61
lint
xadupre Mar 31, 2025
a14b8b3
lint
xadupre Mar 31, 2025
9f3a816
remove args.dynamo
xadupre Apr 1, 2025
0c88e42
fix issues
xadupre Apr 1, 2025
8b60535
copy inputs
xadupre Apr 1, 2025
741285b
fix shape
xadupre Apr 1, 2025
f8490a5
fix validation
xadupre Apr 2, 2025
dbe202c
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Apr 2, 2025
ca43041
add use_dynamo_export
xadupre Apr 2, 2025
19d4dfb
lint
xadupre Apr 2, 2025
49fb806
Merge branch 'llama2' of https://github.com/xadupre/onnxruntime into …
xadupre Apr 3, 2025
8d3b0ba
fix requirements
xadupre Apr 3, 2025
835b76e
fix requitmeents
xadupre Apr 3, 2025
a0a8c21
fix dynamic shapes
xadupre Apr 3, 2025
f61c27b
2.6
xadupre Apr 3, 2025
902c6af
remove duplicated section
xadupre Apr 3, 2025
e3188ad
lint
xadupre Apr 3, 2025
ca8233f
Update whisper transformer module to 4.48.0
jchen351 Apr 10, 2025
6378cc0
Merge remote-tracking branch 'origin/xadupre/llama' into Cjian/whisper
jchen351 Apr 10, 2025
d5eedbc
Merge remote-tracking branch 'origin/main' into Cjian/whisper
jchen351 Apr 12, 2025
1d1c650
Install ninja in requirements.txt
jchen351 Apr 14, 2025
531d1cf
Add check if ninjia is installed
jchen351 Apr 15, 2025
a3c1bc1
merge conflicts
xadupre Apr 18, 2025
bd78ca0
update for whisper
xadupre Apr 18, 2025
2538793
remove wrong import
xadupre Apr 18, 2025
255f393
remove unusd import
xadupre Apr 18, 2025
6044947
lint
xadupre Apr 18, 2025
40aed64
almost there
xadupre Apr 18, 2025
70374db
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
89e18b7
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
9b1d9bb
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
ad03686
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
xadupre Apr 18, 2025
f8a5910
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
31640bb
other fixes
xadupre Apr 18, 2025
5cda3f4
decoder ok
xadupre Apr 18, 2025
e5e032b
fix encoder
xadupre Apr 18, 2025
0fc8d15
+copyright
xadupre Apr 18, 2025
4446827
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
c890afd
Update onnxruntime/python/tools/transformers/models/whisper/whisper_i…
xadupre Apr 18, 2025
6e8b37f
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
xadupre Apr 18, 2025
bf364b6
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
xadupre Apr 18, 2025
6ff2ccb
Update onnxruntime/python/tools/transformers/models/whisper/convert_t…
xadupre Apr 18, 2025
2b1ea1f
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
xadupre Apr 18, 2025
670e504
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
xadupre Apr 18, 2025
61ff3ad
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
xadupre Apr 18, 2025
7d73351
Format
justinchuby Apr 18, 2025
a987583
Adding optree to the pip
jchen351 Apr 24, 2025
ccecdac
Adding optree to the pip
jchen351 Apr 25, 2025
aea3af0
Merge branch 'main' into Cjian/whisper
jchen351 Apr 27, 2025
708bf66
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
jchen351 Apr 27, 2025
b900016
Update onnxruntime/python/tools/transformers/models/whisper/common_on…
jchen351 Apr 27, 2025
e6c9e55
Merge with main
jchen351 Apr 27, 2025
145dc5d
Merge remote-tracking branch 'origin/Cjian/whisper' into Cjian/whisper
jchen351 Apr 27, 2025
1df40d6
Merge with main
jchen351 Apr 27, 2025
16b62d6
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Apr 28, 2025
160cec4
Merge branch 'main' into Cjian/whisper
jchen351 Apr 28, 2025
75fb629
Merge with main
jchen351 Apr 29, 2025
d0edc48
Merge branch 'Cjian/wishper2' into Cjian/whisper
jchen351 Apr 29, 2025
75d841e
Merge with main
jchen351 Apr 29, 2025
79621a6
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
jchen351 Apr 30, 2025
dabb0d0
Update onnxruntime/python/tools/transformers/models/torch_export_patc…
jchen351 Apr 30, 2025
@@ -1,12 +1,21 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from dataclasses import fields, is_dataclass
from typing import Any

import numpy as np
import packaging.version as pv
import torch
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from transformers import __version__ as transformers_version
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from .onnx_export_errors import (
    bypass_export_some_errors,
    register_additional_serialization_functions,
@@ -32,6 +41,52 @@ def is_torchdynamo_exporting() -> bool:
        return False


def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
    """
    Converts a torch dtype into an ONNX element type.

    :param to: torch dtype
    :return: onnx type
    """
    if to == torch.float32:
        return TensorProto.FLOAT
    if to == torch.float16:
        return TensorProto.FLOAT16
    if to == torch.bfloat16:
        return TensorProto.BFLOAT16
    if to == torch.float64:
        return TensorProto.DOUBLE
    if to == torch.int64:
        return TensorProto.INT64
    if to == torch.int32:
        return TensorProto.INT32
    if to == torch.uint64:
        return TensorProto.UINT64
    if to == torch.uint32:
        return TensorProto.UINT32
    if to == torch.bool:
        return TensorProto.BOOL
    if to == torch.SymInt:
        return TensorProto.INT64
    if to == torch.int16:
        return TensorProto.INT16
    if to == torch.uint16:
        return TensorProto.UINT16
    if to == torch.int8:
        return TensorProto.INT8
    if to == torch.uint8:
        return TensorProto.UINT8
    if to == torch.SymFloat:
        return TensorProto.FLOAT
    if to == torch.complex64:
        return TensorProto.COMPLEX64
    if to == torch.complex128:
        return TensorProto.COMPLEX128
    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
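The chain of comparisons above can also be expressed as a table lookup. The sketch below is a hypothetical stand-alone variant, not code from this PR: the names `_DTYPE_TO_ONNX` and `dtype_name_to_onnx` are invented, it keys on dtype name strings rather than `torch.dtype` objects, and it hard-codes the `onnx.TensorProto` enum integers so it runs without `torch` or `onnx` installed.

```python
# Hypothetical table-driven variant of torch_dtype_to_onnx_dtype.
# Keys are torch dtype names (str(dtype) without the "torch." prefix);
# values are onnx.TensorProto enum integers.
_DTYPE_TO_ONNX = {
    "float32": 1,    # TensorProto.FLOAT
    "float16": 10,   # TensorProto.FLOAT16
    "bfloat16": 16,  # TensorProto.BFLOAT16
    "float64": 11,   # TensorProto.DOUBLE
    "int64": 7,      # TensorProto.INT64
    "int32": 6,      # TensorProto.INT32
    "bool": 9,       # TensorProto.BOOL
    "uint8": 2,      # TensorProto.UINT8
}


def dtype_name_to_onnx(name: str) -> int:
    """Looks up the ONNX element type for a torch dtype name."""
    try:
        return _DTYPE_TO_ONNX[name]
    except KeyError:
        # Mirror the original function's behavior for unknown dtypes.
        raise NotImplementedError(f"Unable to convert dtype {name!r} to onnx dtype.")
```

A dict keeps the mapping declarative and makes the unsupported-dtype case a single `KeyError` path instead of falling off the end of an if-chain.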


def string_type(
    obj: Any,
    with_shape: bool = False,
@@ -178,8 +233,6 @@ def string_type(
return f"dict({s})"
# array
if isinstance(obj, np.ndarray):
from .onnx_helper import np_dtype_to_tensor_dtype

if with_min_max:
s = string_type(obj, with_shape=with_shape)
if len(obj.shape) == 0:
@@ -257,16 +310,12 @@ def string_type(

# Tensors
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
from .onnx_helper import torch_dtype_to_onnx_dtype

i = torch_dtype_to_onnx_dtype(obj.dtype)
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
if not with_shape:
return f"{prefix}F{i}r{len(obj.shape)}"
return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
if isinstance(obj, torch.Tensor):
from .onnx_helper import torch_dtype_to_onnx_dtype

if with_min_max:
s = string_type(obj, with_shape=with_shape, with_device=with_device)
if len(obj.shape) == 0:
@@ -307,6 +356,57 @@ def string_type(
return f"OV{dt}s{'x'.join(map(str, shape))}"
return f"OV{dt}r{len(shape)}"

    if obj.__class__.__name__ == "MambaCache":
        c = string_type(
            obj.conv_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        d = string_type(
            obj.ssm_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"MambaCache(conv_states={c}, ssm_states={d})"

    if obj.__class__.__name__ == "DynamicCache":
        kc = string_type(
            obj.key_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        vc = string_type(
            obj.value_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"

    if obj.__class__.__name__ == "EncoderDecoderCache":
        att = string_type(
            obj.self_attention_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        cross = string_type(
            obj.cross_attention_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"{obj.__class__.__name__}(self_attention_cache={att}, cross_attention_cache={cross})"

    # other classes

    if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
@@ -367,59 +467,6 @@ def string_type(
    ):
        return repr(obj).replace(" ", "").replace("\n", " ")

    # to avoid failures

    if obj.__class__.__name__ == "MambaCache":
        c = string_type(
            obj.conv_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        d = string_type(
            obj.ssm_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"MambaCache(conv_states={c}, ssm_states={d})"

    if obj.__class__.__name__ == "DynamicCache":
        kc = string_type(
            obj.key_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        vc = string_type(
            obj.value_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"

    if obj.__class__.__name__ == "EncoderDecoderCache":
        att = string_type(
            obj.self_attention_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        cross = string_type(
            obj.cross_attention_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"{obj.__class__.__name__}(self_attention_cache={att}, cross_attention_cache={cross})"

    raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
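`string_type` builds compact, human-readable descriptors for arbitrarily nested inputs by recursing over containers and special-casing known classes. As a rough stand-alone illustration of that pattern (not code from this PR; `describe` is a hypothetical name and covers only builtin containers, with none of the shape/min-max/device options):

```python
# Minimal sketch of the recursive descriptor idea behind string_type,
# restricted to builtin Python types so it runs without torch.
def describe(obj) -> str:
    if isinstance(obj, (int, float, str, bool)):
        return type(obj).__name__  # scalars: just the type name
    if isinstance(obj, list):
        # lists carry their length, like string_type's "#N[...]" form
        return f"#{len(obj)}[{','.join(describe(o) for o in obj)}]"
    if isinstance(obj, tuple):
        return f"({','.join(describe(o) for o in obj)})"
    if isinstance(obj, dict):
        return f"dict({','.join(f'{k}:{describe(v)}' for k, v in obj.items())})"
    raise AssertionError(f"Unsupported type {type(obj).__name__!r}")
```

The same skeleton extends naturally with extra branches per class, which is how the cache classes above are handled.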


@@ -461,3 +508,39 @@ def make_encoder_decoder_cache(
) -> EncoderDecoderCache:
    "Creates an EncoderDecoderCache."
    return EncoderDecoderCache(self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache)


def torch_deepcopy(value: Any) -> Any:
    """Makes a deep copy."""
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, tuple):
        return tuple(torch_deepcopy(v) for v in value)
    if isinstance(value, list):
        return [torch_deepcopy(v) for v in value]
    if isinstance(value, set):
        return {torch_deepcopy(v) for v in value}
    if isinstance(value, dict):
        if type(value) is dict:
            return {k: torch_deepcopy(v) for k, v in value.items()}
        # for BaseModelOutput
        return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()})
    if isinstance(value, np.ndarray):
        return value.copy()
    if hasattr(value, "clone"):
        return value.clone()
    if value.__class__.__name__ == "DynamicCache":
        return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
    if value.__class__.__name__ == "EncoderDecoderCache":
        return make_encoder_decoder_cache(
            torch_deepcopy(value.self_attention_cache),
            torch_deepcopy(value.cross_attention_cache),
        )
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        args, spec = torch.utils._pytree.tree_flatten(value)
        new_args = torch_deepcopy(args)
        return torch.utils._pytree.tree_unflatten(new_args, spec)

    # A fallback based on serialization/deserialization could go here,
    # assuming a model cannot be exported without those functions.
    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
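`torch_deepcopy` dispatches on type and recurses, instead of relying on `copy.deepcopy`, so that tensors go through `.clone()` and cache classes are rebuilt through their helper constructors. A minimal pure-Python sketch of the same dispatch pattern, restricted to builtin containers so it runs without torch (`simple_deepcopy` is a hypothetical name, not part of this PR):

```python
# Type-dispatched recursive deep copy over builtin containers,
# mirroring the structure of torch_deepcopy above.
def simple_deepcopy(value):
    if isinstance(value, (int, float, str, bytes, type(None))):
        return value  # immutables can be returned as-is
    if isinstance(value, tuple):
        return tuple(simple_deepcopy(v) for v in value)
    if isinstance(value, list):
        return [simple_deepcopy(v) for v in value]
    if isinstance(value, set):
        return {simple_deepcopy(v) for v in value}
    if isinstance(value, dict):
        return {k: simple_deepcopy(v) for k, v in value.items()}
    # In torch_deepcopy, tensors, caches, and pytree-registered classes
    # get dedicated branches here.
    raise NotImplementedError(f"not implemented for type {type(value)}")
```

The explicit type ladder makes it obvious which classes are copyable, and raising on anything unknown surfaces unsupported inputs early rather than silently sharing state.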
@@ -1,3 +1,9 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import packaging.version as pv
import torch
from transformers import __version__ as transformers_version
@@ -30,6 +36,25 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
    return len(cache2.key_cache) == len(cache.value_cache)


def flatten_unflatten_for_dynamic_shapes(obj):
    """
    Returns the object in a different structure similar to what
    the definition of the dynamic shapes should use.

    :param obj: object from a custom class
    :return: the serialized object
    """
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        subtrees.append(subspec.unflatten(flat[start:end]))
        start = end
    return subtrees
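The function walks `spec.children_specs`, slicing the flat leaf list by each child's `num_leaves` and unflattening one slice per top-level child. A stand-alone sketch of that regrouping over plain nested lists (`flatten` and `regroup_per_child` are hypothetical names standing in for the pytree calls):

```python
# Flatten a nested list structure into its leaves.
def flatten(obj):
    if isinstance(obj, (list, tuple)):
        out = []
        for o in obj:
            out.extend(flatten(o))
        return out
    return [obj]


# Regroup the flat leaves into one slice per top-level child,
# mirroring the start/end bookkeeping in the function above.
def regroup_per_child(obj):
    flat = flatten(obj)
    subtrees, start = [], 0
    for child in obj:
        n = len(flatten(child))  # this child's leaf count (num_leaves)
        subtrees.append(flat[start:start + n])
        start += n
    return subtrees
```

Each returned slice corresponds to one immediate child of the input, which is the granularity at which dynamic shapes are declared.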


if pv.Version(transformers_version) > pv.Version("4.49.99999"):

    def make_dynamic_cache(