Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
586 changes: 372 additions & 214 deletions lmdeploy/pytorch/backends/cuda/attention/default.py

Large diffs are not rendered by default.

485 changes: 369 additions & 116 deletions lmdeploy/pytorch/backends/cuda/attention/fa3.py

Large diffs are not rendered by default.

155 changes: 121 additions & 34 deletions lmdeploy/pytorch/backends/cuda/attention/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import torch

from lmdeploy.pytorch.compile_util import custom_op, get_custom_op_manager
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata
from .default import TritonAttentionImpl, TritonAttentionMetadata, _fill_kv_cache_impl, _get_fill_meta

logger = get_logger('lmdeploy')

Expand Down Expand Up @@ -140,6 +142,8 @@ def __init__(

self.nsa_updater = NSAIndicesUpdater.build()

self.mod_key = get_custom_op_manager().register_mod_instance(self)

def _get_flash_mla_sparse_fwd(self):
if self.flash_mla_sparse_fwd is not None:
return self.flash_mla_sparse_fwd
Expand Down Expand Up @@ -391,7 +395,7 @@ def _fill_kv_cache_impl(self,
"""Fill kv cache."""
is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
if not is_fp8_kvcache:
return super()._fill_kv_cache_impl(
return _fill_kv_cache_impl(
key,
value,
k_cache,
Expand All @@ -408,8 +412,7 @@ def _fill_kv_cache_impl(self,
assert quant_policy == 0

# fill seqlen args
fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(
key,
fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = _get_fill_meta(
attn_metadata,
max_q_seqlen,
)
Expand Down Expand Up @@ -517,44 +520,19 @@ def _forward_prefill(
else:
return self._prefill_triton(query, flatten_k, flatten_v, attn_metadata)

def forward(
def forward_impl(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
nsa_indices: torch.Tensor = None,
**kwargs,
k_scales_zeros: torch.Tensor | None = None,
v_scales_zeros: torch.Tensor | None = None,
nsa_indices: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass for MLA attention computation.

This method handles both prefill and decoding stages by:
1. Validating NSA requirements (FP8 KV cache)
2. Computing max query sequence length
3. Filling KV cache if new key/value are provided
4. Dispatching to appropriate stage-specific method

Architecture:
- Decoding: Uses flash_mla_with_kvcache with paged KV cache
- Prefill: Three paths based on availability and requirements
* Sparse (NSA + FP8): flash_mla_sparse_fwd
* FA3 optimized: flash_attn_varlen_func with split q_rope/q_nope
* Triton fallback: Custom triton kernel

Args:
query: Query tensor.
key: Key tensor (None for decoding-only).
value: Value tensor (None for decoding-only).
k_cache: Key cache tensor.
v_cache: Value cache tensor.
attn_metadata: Attention metadata containing stage info and indices.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
nsa_indices: Optional sparse attention indices.
"""Forward pass for MLA attention computation implementation.

Returns:
Attention output tensor.
Expand Down Expand Up @@ -593,3 +571,112 @@ def forward(
k_scales_zeros,
v_scales_zeros,
)

def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    k_scales_zeros: torch.Tensor | None = None,
    v_scales_zeros: torch.Tensor | None = None,
    nsa_indices: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Forward pass for MLA attention computation.

    Dispatches to one of two execution paths:
    - Under ``torch.compile`` tracing, routes through the registered
      ``flash_mla_attention_forward`` custom op so the graph sees a single
      traceable node. Note that ``attn_metadata`` is NOT forwarded on this
      path — the op re-reads it from the current step context.
    - Otherwise, calls :meth:`forward_impl` directly (eager path).

    Args:
        query: Query tensor.
        key: Key tensor (None for decoding-only).
        value: Value tensor (None for decoding-only).
        k_cache: Key cache tensor.
        v_cache: Value cache tensor.
        attn_metadata: Attention metadata containing stage info and indices.
        k_scales_zeros: Key quantization scales/zeros.
        v_scales_zeros: Value quantization scales/zeros.
        nsa_indices: Optional sparse attention indices.

    Returns:
        Attention output tensor.
    """
    # Eager path: pass everything, including the metadata object, straight
    # through to the implementation.
    if not torch.compiler.is_compiling():
        return self.forward_impl(
            query,
            key,
            value,
            k_cache,
            v_cache,
            attn_metadata=attn_metadata,
            k_scales_zeros=k_scales_zeros,
            v_scales_zeros=v_scales_zeros,
            nsa_indices=nsa_indices,
        )
    # Compile path: the custom op locates this instance again via `mod_key`
    # (registered in __init__ with the custom-op manager).
    return flash_mla_attention_forward(
        self.mod_key,
        query,
        key,
        value,
        k_cache,
        v_cache,
        k_scales_zeros=k_scales_zeros,
        v_scales_zeros=v_scales_zeros,
        nsa_indices=nsa_indices,
    )


@custom_op('lmdeploy::flash_mla_attention_forward',
           mutates_args=['k_cache', 'v_cache'],
           split_prefill=True,
           split_decoding=False)
def flash_mla_attention_forward(
    mod_key: int,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_scales_zeros: torch.Tensor | None = None,
    v_scales_zeros: torch.Tensor | None = None,
    nsa_indices: torch.Tensor | None = None,
) -> torch.Tensor:
    """Flash MLA attention forward custom op.

    Resolves the ``FlashMLAImpl`` instance registered under ``mod_key`` and
    dispatches to its ``forward_impl``. The attention metadata is read from
    the current step context instead of being passed as an op argument.
    """
    impl = get_custom_op_manager().get_mod_instance(mod_key)
    assert isinstance(impl, FlashMLAImpl)
    metadata: TritonAttentionMetadata = get_step_ctx_manager().current_context().attn_metadata
    # NOTE(review): the incoming `v_cache` is replaced by the leading
    # `_MLA_NOPE_SIZE` slice of `k_cache` — presumably MLA stores the value
    # (nope) part packed inside the key cache; confirm against cache layout.
    v_cache = k_cache[..., :impl._MLA_NOPE_SIZE]
    return impl.forward_impl(
        query,
        key,
        value,
        k_cache,
        v_cache,
        attn_metadata=metadata,
        k_scales_zeros=k_scales_zeros,
        v_scales_zeros=v_scales_zeros,
        nsa_indices=nsa_indices,
    )


@flash_mla_attention_forward.register_fake
def _(mod_key: int, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Fake (meta) implementation for shape/dtype inference under tracing."""
    # Output keeps query's leading dims; the last dim follows value's head size.
    return query.new_empty((*query.shape[:-1], value.size(-1)))
31 changes: 31 additions & 0 deletions lmdeploy/pytorch/backends/cuda/graph_runner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from lmdeploy.pytorch import envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig


def build_runner(model: nn.Module, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig,
                 device: torch.device):
    """Build the graph runner appropriate for *model*.

    Chooses ``TorchCompileRunner`` when torch.compile is forced via the
    ``force_torch_compile`` env flag, or when the model opts in through a
    ``use_torch_compile()`` hook; otherwise falls back to the CUDA-graph
    based runner.
    """
    # The environment override wins; otherwise the model itself decides
    # whether it supports torch.compile (models without the hook do not).
    if envs.force_torch_compile:
        use_compile = True
    else:
        use_compile = bool(getattr(model, 'use_torch_compile', lambda: False)())  # type: ignore[attr-defined]

    runner_kwargs = dict(model_config=model_config,
                         cache_config=cache_config,
                         backend_config=backend_config,
                         device=device)
    # Import lazily so only the selected backend module is loaded.
    if use_compile:
        from .compile_runner import TorchCompileRunner  # noqa: F401
        return TorchCompileRunner(model, **runner_kwargs)
    from .cudagraph_runner import CUDAGraphRunner  # noqa: F401
    return CUDAGraphRunner(model, **runner_kwargs)
Loading
Loading