[WIP] Support export of Llama with DynamicCache and transformers>=4.48 #24291
Closed
Changes from all commits — 27 commits, all by xadupre:

- 5453405 first draft to migrate to newer version of transformers
- 31e82a9 add patches
- 299f116 Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
- cdec2d0 fix import
- 827d3bd fix build and import
- 18b649e build
- 0e77ed4 fix lint
- 4633a3e lint
- b12287a lint
- 1b926cb rename
- 6646e61 lint
- a14b8b3 lint
- 9f3a816 remove args.dynamo
- 0c88e42 fix issues
- 8b60535 copy inputs
- 741285b fix shape
- f8490a5 fix validation
- dbe202c Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
- ca43041 add use_dynamo_export
- 19d4dfb lint
- 49fb806 Merge branch 'llama2' of https://github.com/xadupre/onnxruntime into …
- 8d3b0ba fix requirements
- 835b76e fix requitmeents
- a0a8c21 fix dynamic shapes
- f61c27b 2.6
- 902c6af remove duplicated section
- e3188ad lint
onnxruntime/python/tools/transformers/models/llama/requirements.txt — 4 changes: 3 additions & 1 deletion
```diff
@@ -1,5 +1,7 @@
+onnxscript>=0.2.3
+optree
 optimum>=1.14.1
-transformers>=4.33.2,<= 4.38.0
+transformers==4.48.0
 torch>=2.2.0
 onnx==1.17.0
 datasets>=2.8.0
```
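The exact pin matters because the new `torch_export_patches` module dispatches on the installed transformers version, with 4.50 as the cutoff between the two `make_dynamic_cache` implementations. A minimal sketch of that dispatch decision (hypothetical helpers, not part of the PR):

```python
def version_tuple(version: str) -> tuple[int, ...]:
    """Parse a dotted release version like '4.48.0' into a comparable tuple."""
    return tuple(int(part) for part in version.split("."))


def uses_one_shot_cache_ctor(transformers_version: str) -> bool:
    """True when transformers >= 4.50, where DynamicCache accepts the
    key/value pairs directly in its constructor; older versions (including
    the pinned 4.48.0) must fill the cache with update() calls instead."""
    return version_tuple(transformers_version) >= (4, 50)


print(uses_one_shot_cache_ctor("4.48.0"))  # pinned version -> False
print(uses_one_shot_cache_ctor("4.50.1"))  # -> True
```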
onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py — 112 additions (new file)
```python
from typing import Any

import packaging.version as pv
import torch
import transformers

from .onnx_export_errors import (
    bypass_export_some_errors,
    register_additional_serialization_functions,
)


def is_torchdynamo_exporting() -> bool:
    "Tells if torch is exporting a model."
    import torch

    if not hasattr(torch.compiler, "is_exporting"):
        # torch.compiler.is_exporting requires torch>=2.7
        return False

    try:
        return torch.compiler.is_exporting()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_exporting()  # type: ignore
        except Exception:
            return False


def string_type(anything, **args):
    # too long
    return str(anything)


if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

    def make_dynamic_cache(
        key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        ::

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        return transformers.cache_utils.DynamicCache(key_value_pairs)

else:

    def make_dynamic_cache(
        key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers < 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        ::

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache


def make_encoder_decoder_cache(
    self_attention_cache: transformers.cache_utils.DynamicCache,
    cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
    "Creates an EncoderDecoderCache."
    return transformers.cache_utils.EncoderDecoderCache(
        self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
    )
```
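The pre-4.50 branch relies on the `DynamicCache.update(key, value, layer_idx)` protocol to fill the cache one layer at a time. A minimal stand-in (not the real transformers class, and using plain strings instead of tensors so it runs anywhere) mimicking that protocol:

```python
class FakeDynamicCache:
    """Simplified mock of transformers' DynamicCache: one key/value slot
    per layer, filled via update() exactly as the pre-4.50 branch does."""

    def __init__(self, num_layers: int) -> None:
        self.key_cache: list = [None] * num_layers
        self.value_cache: list = [None] * num_layers

    def update(self, key, value, layer_idx: int):
        # Store this layer's key/value; the real class also concatenates
        # along the sequence dimension on subsequent calls.
        self.key_cache[layer_idx] = key
        self.value_cache[layer_idx] = value
        return key, value


def make_fake_dynamic_cache(key_value_pairs):
    """Mirrors the loop in the pre-4.50 make_dynamic_cache branch above."""
    cache = FakeDynamicCache(len(key_value_pairs))
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, i)
    return cache


cache = make_fake_dynamic_cache([("k0", "v0"), ("k1", "v1")])
print(cache.key_cache)  # ['k0', 'k1']
```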
Check notice — Code scanning / CodeQL: Module is imported with 'import' and 'import from'.

Copilot Autofix (AI, 13 days ago): To fix the problem, remove the `from transformers import AutoConfig, AutoTokenizer` statement and access these components directly from the `transformers` module. This makes the code more consistent and easier to understand. Update references to `AutoConfig` and `AutoTokenizer` to use `transformers.AutoConfig` and `transformers.AutoTokenizer`, respectively.
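The import-style cleanup CodeQL is asking for can be sketched as follows, illustrated here with the stdlib `json` module rather than `transformers` so the snippet runs anywhere:

```python
# CodeQL flags mixing `import x` with `from x import y` for the same module:
#
#   Before (flagged):
#       import json
#       from json import dumps
#
# After the fix: import the module once and qualify every use.
import json


def encode(obj) -> str:
    # Qualified call (json.dumps) instead of a bare imported name.
    return json.dumps(obj, sort_keys=True)


print(encode({"b": 1, "a": 2}))  # {"a": 2, "b": 1}
```

The same transformation applied to the flagged file would replace bare `AutoConfig` / `AutoTokenizer` references with `transformers.AutoConfig` / `transformers.AutoTokenizer`.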