
feat: Refactor LLM model zoo and add KV cache support #3527


Open
wants to merge 32 commits into base: main
32 commits
e5723db
chore: kv cache prototyping
peri044 Apr 25, 2025
3e0b46a
chore: add sdpa converter/lowering
peri044 Apr 27, 2025
e30fa42
feat: implement static/dynamic kv cache in Torch-TRT
May 7, 2025
a3a202f
chore: updates
May 13, 2025
3642614
chore: updates
peri044 May 14, 2025
3688630
chore: updates
peri044 May 16, 2025
c9f5f27
chore: refactor updates
peri044 May 20, 2025
0cb0dcc
chore: refactor updates
peri044 May 20, 2025
6cbb1bd
chore: updates
peri044 May 20, 2025
f539b55
chore: updates
May 27, 2025
0dc3a7e
chore: updates
peri044 May 28, 2025
095b5cf
chore: rebase with main
peri044 May 28, 2025
600e363
feat: Refactor LLM runner and implemented support for Qwen family
peri044 May 31, 2025
9309725
chore: updates
peri044 Jun 4, 2025
a50e7ac
chore: updates
Jun 5, 2025
cbf0d43
chore: set use_fp32_acc to False
peri044 Jun 5, 2025
817be62
chore: updates
peri044 Jun 6, 2025
7a06635
chore: updates
peri044 Jun 7, 2025
f47d6ff
chore: add static_cache_v3
peri044 Jun 7, 2025
535c6a8
chore: remove conditional branching for causal attention
peri044 Jun 11, 2025
7b7ac04
chore: remove conditional branching for causal attention
peri044 Jun 11, 2025
c1f0053
chore: remove is_causal input now that causal attention enhancement i…
peri044 Jun 11, 2025
8301ee6
chore: refactor
peri044 Jun 12, 2025
ecf88d1
chore: move code to tools/llm
peri044 Jun 12, 2025
723d1b2
chore: remove files not considered for release
peri044 Jun 12, 2025
b835ae9
chore: rebase with main
peri044 Jun 12, 2025
5eabf65
chore: updates
peri044 Jun 13, 2025
806616d
chore: Add README.md
peri044 Jun 13, 2025
53a2c56
chore: rebase with main
peri044 Jun 13, 2025
c2cbd5a
chore: add docs
peri044 Jun 13, 2025
81114f8
chore: Add a tutorial
peri044 Jun 13, 2025
57da513
chore: fix model name
peri044 Jun 13, 2025
6 changes: 2 additions & 4 deletions docsrc/index.rst
@@ -140,11 +140,10 @@ Model Zoo
* :ref:`torch_compile_resnet`
* :ref:`torch_compile_transformer`
* :ref:`torch_compile_stable_diffusion`
* :ref:`compile_hf_models`
* :ref:`torch_compile_gpt2`
* :ref:`torch_export_gpt2`
* :ref:`torch_export_llama2`
* :ref:`torch_export_sam2`
* :ref:`torch_export_flux_dev`
* :ref:`notebooks`

.. toctree::
@@ -155,11 +154,10 @@ Model Zoo
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/compile_hf_models
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
tutorials/_rendered_examples/dynamo/torch_compile_gpt2
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2
tutorials/_rendered_examples/dynamo/torch_export_sam2
tutorials/_rendered_examples/dynamo/torch_export_flux_dev
tutorials/notebooks
218 changes: 218 additions & 0 deletions docsrc/tutorials/compile_hf_models.rst
@@ -0,0 +1,218 @@
.. _compile_hf_models:

Compiling LLM models from Huggingface
======================================

This tutorial walks you through compiling LLM models from Huggingface using Torch-TensorRT. It also introduces KV caching in Torch-TensorRT, which can greatly improve the performance of LLM inference.
The code is available in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure performance.

.. note::
    This is an **experimental release** and APIs may change in future versions.

.. note::
    The compilation scripts and tutorials for Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory.

Overview of tools/llm Directory
-------------------------------

The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:

* **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking
* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering the scaled dot-product attention converter and its lowering pass
* **Testing Components**: Model-specific test files for validation
* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations

Supported Models
----------------
We have officially verified support for the following LLM families:

.. list-table::
   :widths: 20 40 20 20
   :header-rows: 1

   * - Model Series
     - HuggingFace Model Card
     - Precision
     - KV Cache Support?
   * - GPT-2
     - gpt2
     - FP16, FP32
     - Yes
   * - LLaMA 2
     - meta-llama/Llama-2-7b-chat-hf
     - FP16, FP32
     - Yes
   * - LLaMA 3.1
     - meta-llama/Llama-3.1-8B-Instruct
     - FP16, FP32
     - Yes
   * - LLaMA 3.2
     - | meta-llama/Llama-3.2-1B-Instruct
       | meta-llama/Llama-3.2-3B-Instruct
     - FP16, FP32
     - Yes
   * - Qwen 2.5
     - | Qwen/Qwen2.5-0.5B-Instruct
       | Qwen/Qwen2.5-1.5B-Instruct
       | Qwen/Qwen2.5-3B-Instruct
       | Qwen/Qwen2.5-7B-Instruct
     - FP16, FP32
     - Yes

Getting Started with run_llm.py
-------------------------------

The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking.

Basic Usage
^^^^^^^^^^^

.. code-block:: bash

    python tools/llm/run_llm.py \
      --model meta-llama/Llama-3.2-1B-Instruct \
      --prompt "What is parallel programming?" \
      --precision FP16 \
      --num_tokens 128 \
      --cache static_v2 \
      --benchmark

Key Arguments
^^^^^^^^^^^^^

* ``--model``: Name or path of the HuggingFace LLM
* ``--tokenizer``: (Optional) Tokenizer name; defaults to model name
* ``--prompt``: Input prompt for text generation
* ``--precision``: Precision mode (``FP16``, ``FP32``)
* ``--num_tokens``: Number of output tokens to generate
* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching)
* ``--benchmark``: Enable benchmarking mode for performance comparison
* ``--enable_pytorch_run``: Also run and compare PyTorch baseline


Other Usage Examples
^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash

    # Compare the performance of different models
    python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run

    # Generate outputs (benchmarking disabled) with a specified number of tokens to generate (default: 128)
    python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128

    # Test different caching approaches
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2

    # Compare FP16 vs FP32 performance
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark


KV Caching in Torch-TensorRT
---------------------------------

We provide two versions of static KV caching: `static_cache_v1 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v1.py>`_ and `static_cache_v2 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v2.py>`_.
In both implementations, the static KV cache tensors are added as model inputs/outputs rather than being stored as external memory.
The length of the KV cache equals the input sequence length plus the output sequence length (specified by ``--num_tokens``); the number of heads and the head dimension are determined by the model config.
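
As a quick illustration of the resulting cache shape, the sketch below derives it from a generation setting and a Llama-3.2-1B-style config. This is a minimal sketch for intuition only; the variable names and config values are assumptions, not the exact ones used in ``static_cache_v1.py``/``static_cache_v2.py``.

.. code-block:: python

    import torch

    # Assumed example values -- not taken from the actual implementation.
    input_seq_len = 32        # length of the tokenized prompt
    num_output_tokens = 128   # value passed via --num_tokens
    max_seq_len = input_seq_len + num_output_tokens

    batch_size = 1
    num_kv_heads = 8          # e.g. num_key_value_heads from the HF model config
    head_dim = 64             # hidden_size // num_attention_heads

    # A zero-initialized key cache and value cache of this shape is added as an
    # extra model input/output for each attention layer in the network.
    key_cache = torch.zeros(batch_size, num_kv_heads, max_seq_len, head_dim, dtype=torch.float16)
    value_cache = torch.zeros_like(key_cache)
    print(key_cache.shape)    # torch.Size([1, 8, 160, 64])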

Static Cache v1
^^^^^^^^^^^^^^^^

``static_cache_v1.py`` implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV1Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            # Concatenate new key/value pairs with existing cache
            new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)

            # Compute attention using the updated cache
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q,
                new_key_cache[:, :, :end_idx, :],
                new_value_cache[:, :, :end_idx, :],
                dropout_p=0.0,
                is_causal=is_causal,
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the new key/value pairs with the existing cache to produce the updated cache. To compute attention, we gather the keys/values from the updated cache up to and including the current token index.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v1.py`` module is imported.

.. note::
    The ``start_idx`` and ``end_idx`` are the start and end indices of the current token(s) in the cache. In the prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length.
    In the decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. Both indices advance by one per generated token until the end of the sequence or until the maximum number of tokens to generate is reached.
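
To make the index bookkeeping concrete, here is a plain-Python sketch (illustrative only) of how the two indices move across the prefill and decode phases:

.. code-block:: python

    # Illustrative only: index progression for a 5-token prompt and 3 new tokens.
    input_seq_len = 5
    max_new_tokens = 3

    # Prefill: the entire prompt is written into the cache in a single step.
    start_idx, end_idx = 0, input_seq_len

    # Decode: one token per step; both indices advance by one each iteration.
    for step in range(max_new_tokens):
        start_idx = input_seq_len + step
        end_idx = start_idx + 1
        print(f"decode step {step}: start_idx={start_idx}, end_idx={end_idx}")
    # decode step 0: start_idx=5, end_idx=6
    # decode step 1: start_idx=6, end_idx=7
    # decode step 2: start_idx=7, end_idx=8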


Static Cache v2
^^^^^^^^^^^^^^^^

``static_cache_v2.py`` is similar to ``static_cache_v1.py`` but uses fewer slice operations. It implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV2Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
            concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
            new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the existing key/value cache with the current token's key/value. The concatenated tensors are used directly to compute attention, and the key/value cache is updated by inserting the current key/value.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v2.py`` module is imported.
The definitions of ``start_idx`` and ``end_idx`` are the same as in ``static_cache_v1.py``.
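
To see that both variants attend over the same keys/values, the small standalone check below compares the v1 slice with the v2 concatenation (purely illustrative; the shapes and indices are arbitrary):

.. code-block:: python

    import torch

    B, H, S_MAX, D = 1, 2, 16, 8   # batch, heads, max cache length, head dim (arbitrary)
    start_idx, end_idx = 4, 6      # two new tokens being inserted

    key_cache = torch.randn(B, H, S_MAX, D)
    k = torch.randn(B, H, end_idx - start_idx, D)

    # v1: rebuild the full cache, then slice up to end_idx for attention.
    new_key_cache_v1 = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
    keys_for_attn_v1 = new_key_cache_v1[:, :, :end_idx, :]

    # v2: build the attention keys with a single concatenation and reuse them for the cache update.
    concat_keys_v2 = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)

    assert torch.equal(keys_for_attn_v1, concat_keys_v2)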

After the model is compiled with a static KV cache, its input signature changes to ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``.
The number of key/value cache tensor pairs corresponds to the number of attention layers in the model. We can use the ``generate_with_static_cache`` function to generate the outputs.

Generating Outputs
-------------------
We use a custom `generate <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.

The ``generate_with_static_cache`` function takes care of preparing the inputs for the model compiled with a static KV cache.
The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, and ``end_idx``.
We initialize the key/value cache tensors with zeros, and for every generated token the model returns the updated key/value cache tensors as outputs.
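
The decode loop can be sketched roughly as follows. This is a simplified illustration, not the actual implementation: the function and argument names (``trt_model``, ``kv_caches``, etc.) are assumptions, and details such as sampling, batching, and how the indices are passed may differ in ``generate_with_static_cache``.

.. code-block:: python

    import torch

    def sketch_generate_with_static_cache(trt_model, input_ids, kv_caches, max_seq_len, eos_token_id):
        """Illustrative sketch of greedy decoding with a static KV cache."""
        seq_len = input_ids.shape[1]

        # Prefill: write the whole prompt into the cache in one forward pass.
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        logits, *kv_caches = trt_model(input_ids, position_ids, *kv_caches, 0, seq_len)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([input_ids, next_token], dim=-1)

        # Decode: feed one token at a time; the updated caches come back as outputs.
        start_idx = seq_len
        while start_idx + 1 < max_seq_len and next_token.item() != eos_token_id:
            position_ids = torch.tensor([[start_idx]], device=input_ids.device)
            logits, *kv_caches = trt_model(next_token, position_ids, *kv_caches, start_idx, start_idx + 1)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            start_idx += 1

        return generated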

SDPA Converter (sdpa_converter.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* Converts the scaled dot-product attention operation using the TensorRT Python API.
* Supports causal and standard self-attention.

SDPA Registration (register_sdpa.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``.
* Registers the SDPA converter, which is used to convert the ``torch.nn.functional.scaled_dot_product_attention`` operation.
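
As a rough illustration of what this pass operates on, the snippet below exports a tiny attention module and lists the SDPA nodes in the resulting graph. It is an inspection sketch only and does not run the lowering pass itself; the module and shapes are made up for the example.

.. code-block:: python

    import torch
    import torch.nn as nn

    class TinyAttention(nn.Module):
        def forward(self, q, k, v):
            return nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    q = k = v = torch.randn(1, 8, 16, 64)
    ep = torch.export.export(TinyAttention(), (q, k, v))

    # register_sdpa.py's lowering pass looks for SDPA variants like these and
    # rewrites them into the canonical form that the SDPA converter handles.
    for node in ep.graph.nodes:
        if node.op == "call_function" and "scaled_dot_product" in str(node.target):
            print(node.target)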


Limitations and Known Issues
----------------------------

* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
* Some model architectures (e.g. Phi-4) have issues with exporting the torch model.

Requirements
^^^^^^^^^^^^

* Torch-TensorRT 2.8.0 or later
* Transformers v4.52.3
9 changes: 9 additions & 0 deletions examples/dynamo/aot_plugin.py
@@ -1,3 +1,12 @@
"""
.. _aot_plugin:

AOT Plugin
==========

This example demonstrates how to use an AOT plugin in Torch-TensorRT.
"""

import argparse
from typing import Tuple, Union

98 changes: 0 additions & 98 deletions examples/dynamo/torch_export_gpt2.py

This file was deleted.
