14 changes: 8 additions & 6 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Any, Dict, List, Tuple
from typing import Any

import torch
from torch.profiler import record_function
@@ -66,7 +66,7 @@ def __init__(
max_tokens: int,
num_blocks: int,
is_decoding: bool,
pool: Tuple[int, int],
pool: tuple[int, int],
model_config: ModelConfig,
device: torch.device,
decode_query_len: int = 1,
@@ -89,6 +89,8 @@ def __init__(
mla_index_topk=getattr(self.model_config, 'mla_index_topk', None),
decode_query_len=decode_query_len,
use_fa3_decoding=model_config.model_paradigm == 'ar_spec',
is_ssm=len(model_config.states_shapes) > 0,
use_mrope=model_config.use_mrope,
)
self.device = device
self.max_batches = max_batches
@@ -153,7 +155,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
self.enable_graph = self.check_enable_graph()

self.graph_pool_handle = torch.cuda.graph_pool_handle()
self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
self._runner_map: dict[Any, CUDASingleGraphRunner] = dict()
self.has_try_compile_model: bool = False

# strategy factory
@@ -187,7 +189,7 @@ def _get_capture_tokens(self, batch_size: int):
return size
assert False, f'Unsupported batch_size={batch_size}'

def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list,
attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):
"""Get graph key."""
context = self.ctx_mgr.current_context()
@@ -261,7 +263,7 @@ def __call__(self, **kwargs):
@record_function('prepare_inputs_for_generation')
def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
past_key_values: list[list[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
@@ -303,6 +305,6 @@ def update_inputs(self, inputs):
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
def get_capture_batch_sizes(self) -> list[int]:
"""Capture batch sizes."""
return _get_capture_batch_size_impl(self.cache_config.max_batches)
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/config.py
@@ -341,6 +341,9 @@ class ModelConfig:
# quant config
quant_config: 'QuantizationConfig' = None

# flag marking whether this model uses mrope
use_mrope: bool = False

def get_head_size(self):
"""Get head size."""
return self.head_dim
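The new `use_mrope` flag defaults to `False` and is flipped by the config builders of models that use multimodal rotary position embeddings (mRoPE), as the builder changes in `glm4.py`, `qwen3_5.py` and `qwen3_vl.py` below show. A minimal sketch of that opt-in pattern; the builder class is hypothetical and the import path is assumed:

```python
# Hypothetical builder illustrating the opt-in pattern used in the diffs below.
# `MyVLModelConfigBuilder` is not a real class; the base class and import path
# are assumed from the surrounding changes.
from lmdeploy.pytorch.configurations.default import DefaultModelConfigBuilder


class MyVLModelConfigBuilder(DefaultModelConfigBuilder):

    @classmethod
    def build(cls, hf_config, model_path: str = None, **kwargs):
        cfg = super().build(hf_config, model_path=model_path, **kwargs)
        # mark the model as an mRoPE user so the engine builds and
        # forwards mrope position ids with every batch
        cfg.use_mrope = True
        return cfg
```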
13 changes: 8 additions & 5 deletions lmdeploy/pytorch/configurations/glm4.py
@@ -18,11 +18,13 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
hf_config.scoring_func = 'sigmoid'
if not hasattr(hf_config, 'moe_layer_freq'):
hf_config.moe_layer_freq = 1
return super().build(hf_config,
model_path=model_path,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg = super().build(hf_config,
model_path=model_path,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg.use_mrope = True
return cfg


class Glm4MoeModelConfigBuilder(DefaultModelConfigBuilder):
@@ -58,6 +60,7 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
is_draft_model=is_draft_model,
spec_method=spec_method,
**kwargs)
cfg.use_mrope = True
cfg.model_paradigm = model_paradigm
cfg.num_layers = num_layers
return cfg
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/configurations/qwen3_5.py
@@ -45,4 +45,6 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
dtype = torch.bfloat16
cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
cfg.check_env_func = _check_env_qwen3_next

cfg.use_mrope = True
return cfg
1 change: 1 addition & 0 deletions lmdeploy/pytorch/configurations/qwen3_vl.py
@@ -22,4 +22,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)
setattr(hf_config, 'dtype', hf_config.text_config.dtype)
cfg.hf_config = hf_config
cfg.use_mrope = True
return cfg
32 changes: 18 additions & 14 deletions lmdeploy/pytorch/engine/engine.py
@@ -3,7 +3,7 @@
import gc
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any

import numpy as np
import torch
@@ -28,7 +28,7 @@

logger = get_logger('lmdeploy')

SeqList = List[SchedulerSequence]
SeqList = list[SchedulerSequence]


@dataclass
@@ -37,15 +37,15 @@ class InferOutput:

session_id: int
resp: Response
token_ids: Union[np.ndarray, List[int]]
token_ids: np.ndarray | list[int]
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
logprobs: torch.Tensor = None

# send cache blocks back for migration in Disaggregated LLM Serving
# when Prefill Engine is Done.
cache_block_ids: List[int] = None
cache_block_ids: list[int] = None

# for logging
req_metrics: RequestMetrics = None
@@ -54,10 +54,13 @@ class InferOutput:
routed_experts: torch.Tensor = None


def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):
def _build_seq_meta(model_config: ModelConfig, cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):
from lmdeploy.pytorch.messages import SequenceMeta

seq_meta = SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy)
seq_meta = SequenceMeta(cache_config.block_size,
strategy=seq_strategy,
sampling_strategy=sampling_strategy,
use_mrope=model_config.use_mrope)
return seq_meta


@@ -156,7 +159,8 @@ def __init__(
self.input_processor = self.executor.get_input_processor()
cache_config = self.executor.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
self.seq_meta = _build_seq_meta(cache_config,
self.seq_meta = _build_seq_meta(model_config=self.model_config,
cache_config=cache_config,
seq_strategy=self.seq_strategy,
sampling_strategy=self.sampling_strategy)
self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta)
@@ -229,7 +233,7 @@ def from_pretrained(cls,
speculative_config=speculative_config,
)

def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig):
def _download_adapters(self, adapters: dict[str, str], engine_config: PytorchEngineConfig):
"""Download adapters."""
download_dir = engine_config.download_dir
revision = engine_config.revision
@@ -274,7 +278,7 @@ def _get_max_session_len(self):
session_len = min(max_tokens, session_len)
return session_len

def _on_add_session(self, reqs: List[Request], **kwargs):
def _on_add_session(self, reqs: list[Request], **kwargs):
"""On add session callback."""
for req in reqs:
session_id = req.data['session_id']
@@ -286,7 +290,7 @@ def _on_add_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_stop_session(self, reqs: List[Request], **kwargs):
def _on_stop_session(self, reqs: list[Request], **kwargs):
"""On stop session callback."""
for req in reqs:
session_id = req.data['session_id']
@@ -305,7 +309,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_end_session(self, reqs: List[Request], **kwargs):
def _on_end_session(self, reqs: list[Request], **kwargs):
"""On end session callback."""
for req in reqs:
session_id = req.data['session_id']
@@ -321,7 +325,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
if resp:
self._response(req.resp, resp_type)

def _on_add_message(self, reqs: List[Request], **kwargs):
def _on_add_message(self, reqs: list[Request], **kwargs):
"""On add message callback."""
valid_reqs = []
for req in reqs:
@@ -359,7 +363,7 @@ def _on_add_message(self, reqs: List[Request], **kwargs):
if len(valid_reqs) > 0:
self._add_message(valid_reqs)

def _add_message(self, reqs: List[Request]):
def _add_message(self, reqs: list[Request]):

def __update_max_new_tokens(msg):
"""Update max new tokens."""
@@ -440,7 +444,7 @@ def sleep(self, level: int = 1):
"""Sleep."""
self.executor.sleep(level)

def wakeup(self, tags: Optional[List[str]] = None):
def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.executor.wakeup(tags)

32 changes: 23 additions & 9 deletions lmdeploy/pytorch/engine/inputs_maker.py
@@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING

import numpy as np
import torch
@@ -54,10 +54,12 @@ class InputsMakerConfig:
dp: int = 1
spec_decoding: bool = False
enable_chunked_prefill: bool = False
use_mrope: bool = False

@staticmethod
def from_engine(engine: 'Engine'):
cache_config = engine.cache_config
model_config = engine.model_config
return InputsMakerConfig(
spec_decoding=engine.specdecode_config is not None,
max_batches=cache_config.max_batches,
@@ -66,6 +68,7 @@ def from_engine(engine: 'Engine'):
is_ssm=len(cache_config.states_shapes) > 0,
dp=engine.dist_config.dp,
enable_chunked_prefill=engine.misc_config.enable_chunked_prefill,
use_mrope=model_config.use_mrope,
)


@@ -219,8 +222,8 @@ def __init__(

# running seqs
# mark the seqs that have been sent to executor
self.running_seqs: List['SchedulerSequence'] = []
self.to_evict_seqs: List['SchedulerSequence'] = []
self.running_seqs: list['SchedulerSequence'] = []
self.to_evict_seqs: list['SchedulerSequence'] = []

# long context chunker
self.long_context_chunker = LongContextChunker(config.max_prefill_token_num)
@@ -379,14 +382,19 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
state_offsets = torch.tensor([msg.logical_state for msg in messages])
model_inputs.state_offsets = state_offsets

if self.config.use_mrope:
mrope_pos_ids = [msg.mrope_pos_ids for msg in messages]
mrope_pos_ids = torch.as_tensor(np.concatenate(mrope_pos_ids)).T
model_inputs.mrope_pos_ids = mrope_pos_ids

return model_inputs

@torch.inference_mode()
@record_function('create_model_inputs_long_context')
def create_model_inputs_long_context(self,
seq: 'SchedulerSequence',
chunk_size: int,
multimodals: Optional['MultiModalInputs'] = None):
multimodals: 'MultiModalInputs|None' = None):
"""Create model inputs for long context messages."""
token_ids = seq.token_ids[:chunk_size]
input_ids = torch.as_tensor(token_ids)[None]
@@ -436,6 +444,12 @@ def create_model_inputs_long_context(self,
if self.config.is_ssm:
model_inputs.state_offsets = torch.tensor([seq.logical_state])

# mrope
if self.config.use_mrope:
mrope_pos_ids = seq.mrope_pos_ids[:chunk_size]
mrope_pos_ids = torch.as_tensor(mrope_pos_ids).T
model_inputs.mrope_pos_ids = mrope_pos_ids

return model_inputs

@torch.inference_mode()
Expand All @@ -453,8 +467,8 @@ def create_model_inputs_delta(self):

valid_mask = np.array(valid_mask)
indices_cpu = np.arange(0, batch_size)[valid_mask]
valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
if len(valid_seqs) == 0:
return None, valid_seqs, invalid_seqs

@@ -498,8 +512,8 @@ def create_model_inputs_delta_valid_only(self):

valid_mask = np.array(valid_mask, dtype=bool)
indices_cpu = np.arange(0, batch_size)[valid_mask]
valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]
valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]
invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]

num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
max_q_seqlen = num_decode_tokens
@@ -523,7 +537,7 @@

return output, valid_seqs, invalid_seqs

def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]):
def update_running_seqs(self, running: 'SeqList', inputs: 'ModelInputs|None'):
"""Update running seqs."""
if self.config.role == EngineRole.Prefill:
# p node will not update running seqs
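The per-sequence `mrope_pos_ids` handled above are assumed to be stored as `(num_tokens, 3)` arrays of (temporal, height, width) positions, which `create_model_inputs` concatenates across the batch and transposes to `(3, total_tokens)`. A small self-contained sketch of that layout and batching step; the helper is illustrative only, and text-only tokens simply repeat the same position in all three components:

```python
# Illustrative sketch of the assumed mrope_pos_ids layout; the helper below is
# hypothetical and only mirrors the concatenate-then-transpose step done in
# create_model_inputs.
import numpy as np
import torch


def make_text_only_mrope_pos_ids(num_tokens: int, start: int = 0) -> np.ndarray:
    """Text-only span: all three mRoPE components advance together."""
    pos = np.arange(start, start + num_tokens, dtype=np.int64)
    return np.stack([pos, pos, pos], axis=-1)  # shape: (num_tokens, 3)


# two sequences of 4 and 2 tokens, batched the same way as in the diff
per_seq = [make_text_only_mrope_pos_ids(4), make_text_only_mrope_pos_ids(2)]
mrope_pos_ids = torch.as_tensor(np.concatenate(per_seq)).T
print(mrope_pos_ids.shape)  # torch.Size([3, 6])
```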
12 changes: 9 additions & 3 deletions lmdeploy/pytorch/engine/model_agent/agent.py
@@ -415,6 +415,9 @@ def __init__(
# long context
self._prev_chunk_output: Dict = None

# make dummy meta
self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config)

@contextmanager
def all_context(self):
device_mgr = get_device_manager()
@@ -427,10 +430,11 @@ def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheCo
self.cache_config = cache_config
self.spec_agent.set_cache_config(spec_cache_config)

def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):
def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig | None = None):
"""Set model config."""
self.model_config = model_config
self.spec_agent.set_model_config(spec_model_config)
self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config)

def get_free_mem(self):
"""Gather available memory."""
@@ -461,7 +465,8 @@ def warmup(self):
inputs = self.inputs_strategy.make_dummy(max_batches,
is_decoding=False,
device='cuda',
vocab_size=self.model_config.vocab_size)
vocab_size=self.model_config.vocab_size,
meta=self.make_dummy_meta)
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
@@ -480,7 +485,8 @@ def warmup(self):
inputs = self.inputs_strategy.make_dummy(num_tokens,
is_decoding=True,
device='cuda',
vocab_size=self.model_config.vocab_size)
vocab_size=self.model_config.vocab_size,
meta=self.make_dummy_meta)
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
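The new `make_dummy_meta`, built from the model config, is threaded into `inputs_strategy.make_dummy` during warmup, presumably so dummy inputs can carry the extra fields (such as mrope position ids) that the model config now declares. A rough sketch of what such a meta factory could look like; the class name, fields and factory signature are all hypothetical:

```python
# Hypothetical sketch only: the real create_make_dummy_meta lives in the
# inputs strategy, and its actual fields are not shown in this diff.
from dataclasses import dataclass


@dataclass
class DummyInputsMeta:
    use_mrope: bool = False      # dummy inputs should include mrope position ids
    num_state_layers: int = 0    # SSM-style models also need dummy states


def create_make_dummy_meta(model_config) -> DummyInputsMeta:
    return DummyInputsMeta(
        use_mrope=model_config.use_mrope,
        num_state_layers=len(model_config.states_shapes),
    )
```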