diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 58d093cf9b..2d64c88d70 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools -from typing import Any, Dict, List, Tuple +from typing import Any import torch from torch.profiler import record_function @@ -66,7 +66,7 @@ def __init__( max_tokens: int, num_blocks: int, is_decoding: bool, - pool: Tuple[int, int], + pool: tuple[int, int], model_config: ModelConfig, device: torch.device, decode_query_len: int = 1, @@ -89,6 +89,8 @@ def __init__( mla_index_topk=getattr(self.model_config, 'mla_index_topk', None), decode_query_len=decode_query_len, use_fa3_decoding=model_config.model_paradigm == 'ar_spec', + is_ssm=len(model_config.states_shapes) > 0, + use_mrope=model_config.use_mrope, ) self.device = device self.max_batches = max_batches @@ -153,7 +155,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.enable_graph = self.check_enable_graph() self.graph_pool_handle = torch.cuda.graph_pool_handle() - self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() + self._runner_map: dict[Any, CUDASingleGraphRunner] = dict() self.has_try_compile_model: bool = False # strategy factory @@ -187,7 +189,7 @@ def _get_capture_tokens(self, batch_size: int): return size assert False, f'Unsupported batch_size={batch_size}' - def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, + def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list, attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs): """Get graph key.""" context = self.ctx_mgr.current_context() @@ -261,7 +263,7 @@ def __call__(self, **kwargs): @record_function('prepare_inputs_for_generation') def prepare_inputs_for_generation( 
self, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], inputs_embeds: torch.Tensor = None, context: StepContext = None, ): @@ -303,6 +305,6 @@ def update_inputs(self, inputs): dp_meta.sync_tp_size(tp_size) return inputs - def get_capture_batch_sizes(self) -> List[int]: + def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index d80045ab92..2e668ef935 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -341,6 +341,9 @@ class ModelConfig: # quant config quant_config: 'QuantizationConfig' = None + # flag marking whether this model uses mrope + use_mrope: bool = False + def get_head_size(self): """Get head size.""" return self.head_dim diff --git a/lmdeploy/pytorch/configurations/glm4.py b/lmdeploy/pytorch/configurations/glm4.py index 80deb6831c..26bb26386c 100644 --- a/lmdeploy/pytorch/configurations/glm4.py +++ b/lmdeploy/pytorch/configurations/glm4.py @@ -18,11 +18,13 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, hf_config.scoring_func = 'sigmoid' if not hasattr(hf_config, 'moe_layer_freq'): hf_config.moe_layer_freq = 1 - return super().build(hf_config, - model_path=model_path, - is_draft_model=is_draft_model, - spec_method=spec_method, - **kwargs) + cfg = super().build(hf_config, + model_path=model_path, + is_draft_model=is_draft_model, + spec_method=spec_method, + **kwargs) + cfg.use_mrope = True + return cfg class Glm4MoeModelConfigBuilder(DefaultModelConfigBuilder): @@ -58,6 +60,7 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, is_draft_model=is_draft_model, spec_method=spec_method, **kwargs) + cfg.use_mrope = True cfg.model_paradigm = model_paradigm cfg.num_layers = num_layers return cfg diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py 
b/lmdeploy/pytorch/configurations/qwen3_5.py index 9c1b0111fa..0bb87545bb 100644 --- a/lmdeploy/pytorch/configurations/qwen3_5.py +++ b/lmdeploy/pytorch/configurations/qwen3_5.py @@ -45,4 +45,6 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): dtype = torch.bfloat16 cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)] cfg.check_env_func = _check_env_qwen3_next + + cfg.use_mrope = True return cfg diff --git a/lmdeploy/pytorch/configurations/qwen3_vl.py b/lmdeploy/pytorch/configurations/qwen3_vl.py index 6b78efcd0b..212d34c721 100644 --- a/lmdeploy/pytorch/configurations/qwen3_vl.py +++ b/lmdeploy/pytorch/configurations/qwen3_vl.py @@ -22,4 +22,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs): cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs) setattr(hf_config, 'dtype', hf_config.text_config.dtype) cfg.hf_config = hf_config + cfg.use_mrope = True return cfg diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 71f061f5ef..de160a0276 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -3,7 +3,7 @@ import gc import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any import numpy as np import torch @@ -28,7 +28,7 @@ logger = get_logger('lmdeploy') -SeqList = List[SchedulerSequence] +SeqList = list[SchedulerSequence] @dataclass @@ -37,7 +37,7 @@ class InferOutput: session_id: int resp: Response - token_ids: Union[np.ndarray, List[int]] + token_ids: np.ndarray | list[int] meta: Any = None finish: bool = False logits: torch.Tensor = None @@ -45,7 +45,7 @@ class InferOutput: # send cache blocks back for migration in Disaggregated LLM Serving # when Prefill Engine is Done. 
- cache_block_ids: List[int] = None + cache_block_ids: list[int] = None # for logging req_metrics: RequestMetrics = None @@ -54,10 +54,13 @@ class InferOutput: routed_experts: torch.Tensor = None -def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any): +def _build_seq_meta(model_config: ModelConfig, cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta - seq_meta = SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy) + seq_meta = SequenceMeta(cache_config.block_size, + strategy=seq_strategy, + sampling_strategy=sampling_strategy, + use_mrope=model_config.use_mrope) return seq_meta @@ -156,7 +159,8 @@ def __init__( self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.seq_meta = _build_seq_meta(cache_config, + self.seq_meta = _build_seq_meta(model_config=self.model_config, + cache_config=cache_config, seq_strategy=self.seq_strategy, sampling_strategy=self.sampling_strategy) self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) @@ -229,7 +233,7 @@ def from_pretrained(cls, speculative_config=speculative_config, ) - def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig): + def _download_adapters(self, adapters: dict[str, str], engine_config: PytorchEngineConfig): """Download adapters.""" download_dir = engine_config.download_dir revision = engine_config.revision @@ -274,7 +278,7 @@ def _get_max_session_len(self): session_len = min(max_tokens, session_len) return session_len - def _on_add_session(self, reqs: List[Request], **kwargs): + def _on_add_session(self, reqs: list[Request], **kwargs): """On add session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -286,7 +290,7 @@ def _on_add_session(self, reqs: List[Request], 
**kwargs): if resp: self._response(req.resp, resp_type) - def _on_stop_session(self, reqs: List[Request], **kwargs): + def _on_stop_session(self, reqs: list[Request], **kwargs): """On stop session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -305,7 +309,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs): if resp: self._response(req.resp, resp_type) - def _on_end_session(self, reqs: List[Request], **kwargs): + def _on_end_session(self, reqs: list[Request], **kwargs): """On end session callback.""" for req in reqs: session_id = req.data['session_id'] @@ -321,7 +325,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs): if resp: self._response(req.resp, resp_type) - def _on_add_message(self, reqs: List[Request], **kwargs): + def _on_add_message(self, reqs: list[Request], **kwargs): """On add message callback.""" valid_reqs = [] for req in reqs: @@ -359,7 +363,7 @@ def _on_add_message(self, reqs: List[Request], **kwargs): if len(valid_reqs) > 0: self._add_message(valid_reqs) - def _add_message(self, reqs: List[Request]): + def _add_message(self, reqs: list[Request]): def __update_max_new_tokens(msg): """Update max new tokens.""" @@ -440,7 +444,7 @@ def sleep(self, level: int = 1): """Sleep.""" self.executor.sleep(level) - def wakeup(self, tags: Optional[List[str]] = None): + def wakeup(self, tags: list[str] | None = None): """Wakeup.""" self.executor.wakeup(tags) diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 506a372250..dff099d892 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import numpy as np import torch @@ -54,10 +54,12 @@ class InputsMakerConfig: dp: int = 1 spec_decoding: bool = False enable_chunked_prefill: bool = False + 
use_mrope: bool = False @staticmethod def from_engine(engine: 'Engine'): cache_config = engine.cache_config + model_config = engine.model_config return InputsMakerConfig( spec_decoding=engine.specdecode_config is not None, max_batches=cache_config.max_batches, @@ -66,6 +68,7 @@ def from_engine(engine: 'Engine'): is_ssm=len(cache_config.states_shapes) > 0, dp=engine.dist_config.dp, enable_chunked_prefill=engine.misc_config.enable_chunked_prefill, + use_mrope=model_config.use_mrope, ) @@ -219,8 +222,8 @@ def __init__( # running seqs # mark the seqs that have been sent to executor - self.running_seqs: List['SchedulerSequence'] = [] - self.to_evict_seqs: List['SchedulerSequence'] = [] + self.running_seqs: list['SchedulerSequence'] = [] + self.to_evict_seqs: list['SchedulerSequence'] = [] # long context chunker self.long_context_chunker = LongContextChunker(config.max_prefill_token_num) @@ -379,6 +382,11 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool): state_offsets = torch.tensor([msg.logical_state for msg in messages]) model_inputs.state_offsets = state_offsets + if self.config.use_mrope: + mrope_pos_ids = [msg.mrope_pos_ids for msg in messages] + mrope_pos_ids = torch.as_tensor(np.concatenate(mrope_pos_ids)).T + model_inputs.mrope_pos_ids = mrope_pos_ids + return model_inputs @torch.inference_mode() @@ -386,7 +394,7 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool): def create_model_inputs_long_context(self, seq: 'SchedulerSequence', chunk_size: int, - multimodals: Optional['MultiModalInputs'] = None): + multimodals: 'MultiModalInputs|None' = None): """Create model inputs for long context messages.""" token_ids = seq.token_ids[:chunk_size] input_ids = torch.as_tensor(token_ids)[None] @@ -436,6 +444,12 @@ def create_model_inputs_long_context(self, if self.config.is_ssm: model_inputs.state_offsets = torch.tensor([seq.logical_state]) + # mrope + if self.config.use_mrope: + mrope_pos_ids = seq.mrope_pos_ids[:chunk_size] + 
mrope_pos_ids = torch.as_tensor(mrope_pos_ids).T + model_inputs.mrope_pos_ids = mrope_pos_ids + return model_inputs @torch.inference_mode() @@ -453,8 +467,8 @@ def create_model_inputs_delta(self): valid_mask = np.array(valid_mask) indices_cpu = np.arange(0, batch_size)[valid_mask] - valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] - invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] + valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] + invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] if len(valid_seqs) == 0: return None, valid_seqs, invalid_seqs @@ -498,8 +512,8 @@ def create_model_inputs_delta_valid_only(self): valid_mask = np.array(valid_mask, dtype=bool) indices_cpu = np.arange(0, batch_size)[valid_mask] - valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] - invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] + valid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu] + invalid_seqs: list['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]] num_decode_tokens = self.engine_strategy.get_num_decode_tokens() max_q_seqlen = num_decode_tokens @@ -523,7 +537,7 @@ def create_model_inputs_delta_valid_only(self): return output, valid_seqs, invalid_seqs - def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]): + def update_running_seqs(self, running: 'SeqList', inputs: 'ModelInputs|None'): """Update running seqs.""" if self.config.role == EngineRole.Prefill: # p node will not update running seqs diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c2919962..9dc452e507 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ 
b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -415,6 +415,9 @@ def __init__( # long context self._prev_chunk_output: Dict = None + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config) + @contextmanager def all_context(self): device_mgr = get_device_manager() @@ -427,10 +430,11 @@ def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheCo self.cache_config = cache_config self.spec_agent.set_cache_config(spec_cache_config) - def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig | None = None): """Set model config.""" self.model_config = model_config self.spec_agent.set_model_config(spec_model_config) + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(model_config) def get_free_mem(self): """Gather available memory.""" @@ -461,7 +465,8 @@ def warmup(self): inputs = self.inputs_strategy.make_dummy(max_batches, is_decoding=False, device='cuda', - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) @@ -480,7 +485,8 @@ def warmup(self): inputs = self.inputs_strategy.make_dummy(num_tokens, is_decoding=True, device='cuda', - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) diff --git a/lmdeploy/pytorch/engine/model_agent/inputs_maker.py b/lmdeploy/pytorch/engine/model_agent/inputs_maker.py index ba1dfb7a63..d3cb10bf78 100644 --- a/lmdeploy/pytorch/engine/model_agent/inputs_maker.py +++ b/lmdeploy/pytorch/engine/model_agent/inputs_maker.py @@ -43,6 +43,9 @@ def __init__(self, model_agent: 'BaseModelAgent'): self._ready_event = torch.cuda.Event() 
self._ready_event.record() + # other + self.make_dummy_meta = model_agent.make_dummy_meta + def _make_dummy_forward_inputs(self): """Make dummy forward inputs.""" is_decoding = self.cache_config.role != EngineRole.Prefill @@ -52,7 +55,8 @@ def _make_dummy_forward_inputs(self): model_inputs = self.inputs_strategy.make_dummy(batch_size, is_decoding, device=self.device, - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + meta=self.make_dummy_meta) forward_inputs = dict(inputs=model_inputs, ) return forward_inputs diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index c020403fa8..be8aa2b61a 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -176,6 +176,7 @@ class SequenceMeta: block_size: int strategy: 'SequenceStrategy' = None sampling_strategy: 'SamplingStrategy' = None + use_mrope: bool = False class SequenceManager: @@ -532,6 +533,25 @@ def clone(self): return ret +class HistoryMropePosIds(_HistoryDataBase): + """History mrope position ids.""" + ALLOC_SIZE = 64 + + def __init__(self, pos_ids: np.ndarray | None = None, dtype: np.dtype = np.int64): + super().__init__(pos_ids, dtype) + + def _create_empty_array(self, dtype): + """Create empty array. + + Override in subclass for different shapes. 
+ """ + return np.empty((self.ALLOC_SIZE, 3), dtype=dtype) + + def _get_pad_width(self, reserve_size: int): + """Get pad width for multi-dimensional array.""" + return ((0, reserve_size), (0, 0)) + + class HistoryMultiModals: def __init__(self, multimodals: MultiModalInputs = None): @@ -617,6 +637,9 @@ class SchedulerSequence: # logits all_logits: HistoryLogits = field(default_factory=HistoryLogits) + # mrope + history_mrope_pos_ids: HistoryMropePosIds = field(default_factory=HistoryMropePosIds) + def __post_init__(self): """Post init.""" self._seq_meta: SequenceMeta = self.session.seq_meta @@ -756,6 +779,13 @@ def logits(self): """Get logits.""" return self.all_logits.get_logits() + @property + def mrope_pos_ids(self): + """Get mrope pos ids.""" + start = self.num_history_ids + end = start + self._num_token_ids + return self.history_mrope_pos_ids[start:end] + def append_logits(self, logits: Union[Tensor, np.ndarray]): """Append logits.""" if not self.return_logits: @@ -797,6 +827,58 @@ def _update_multimodals(self, multimodals: MultiModalInputs): multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) self.history_multimodals.add_inputs(multimodals) + def _update_mrope_pos_ids(self): + """Update mrope pos ids.""" + if not self._seq_meta.use_mrope: + return + + num_rope_pos = len(self.history_mrope_pos_ids) + num_appends = self.num_all_ids - num_rope_pos + + if num_appends == 0: + return + + if num_rope_pos == 0: + next_pos = 0 + else: + next_pos = self.history_mrope_pos_ids[-1].max() + 1 + + multimodals = self.history_multimodals.get_datas(num_rope_pos, self.num_all_ids) + if multimodals is None or len(multimodals) == 0: + if num_appends == 1: + pos_ids = np.array([[next_pos] * 3], dtype=np.int64) + else: + pos_ids = np.arange(next_pos, next_pos + num_appends, dtype=np.int64) + pos_ids = pos_ids[:, None].repeat(3, axis=1) + else: + pos_ids = [] + assert len(multimodals) == 1 + modal_datas = list(multimodals.values())[0] + mm_offset = 
next_pos + for modal_data in modal_datas: + mm_start = modal_data.start + mm_offset + + # tokens + if next_pos < mm_start: + text_pos_ids = np.arange(next_pos, mm_start, dtype=np.int64) + pos_ids.append(text_pos_ids[:, None].repeat(3, axis=1)) + + # imgs + mm_pos_ids = modal_data.mrope_pos_ids + assert mm_pos_ids is not None, ( + 'MROPE position ids is required for multimodal inputs when use_mrope is True.') + new_pos = mm_pos_ids[-1].max() + 1 + next_pos = mm_start + new_pos + mm_offset = mm_offset + new_pos - mm_pos_ids.shape[0] + pos_ids.append(mm_pos_ids + mm_start) + + # add final text part + text_pos_ids = np.arange(next_pos, num_appends + mm_offset, dtype=np.int64) + pos_ids.append(text_pos_ids[:, None].repeat(3, axis=1)) + pos_ids = np.concatenate(pos_ids, axis=0) + + self.history_mrope_pos_ids.append(pos_ids) + def update_token_ids(self, token_ids: Tensor, multimodals: MultiModalInputs = None, diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 538b9c6f3a..d1519a46ad 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch @@ -11,7 +11,7 @@ import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig, QuantizationConfig -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.utils import CtxMgrBase, singleton if TYPE_CHECKING: @@ -20,11 +20,11 @@ @dataclass class DPMeta: - tp_sizes: List[int] = None - moe_tp_sizes: List[int] = None + tp_sizes: list[int] = None + moe_tp_sizes: list[int] = None @staticmethod - def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist.DistContext, layer_type: str): + def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: list[int], dist_ctx: dist.DistContext, layer_type: str): """Gather tp size.""" attn_tp = dist_ctx.dist_config.attn_tp if tp > 1 and tp != attn_tp: @@ -38,7 +38,7 @@ def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist return tp_sizes @classmethod - def build(cls, seqlen: int, num_tokens: List[int]): + def build(cls, seqlen: int, num_tokens: list[int]): """Get dp meta.""" dist_ctx = dist.get_dist_manager().current_context() dist_config = dist_ctx.dist_config @@ -63,10 +63,10 @@ def sync_tp_size(self, tp_size: int): class VisionModelInputs: """Vision model inputs.""" history_lengths: torch.LongTensor = None - input_embeddings: List[List[torch.Tensor]] = None - input_embedding_ranges: List[torch.LongTensor] = None + input_embeddings: list[list[torch.Tensor]] = None + input_embedding_ranges: list[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None - input_multimodals: List[MultiModalTensor] = None + input_multimodals: list[MultiModalData] = None def to_device(self, 
device: str, non_blocking: bool = False): """To device.""" @@ -125,7 +125,7 @@ def get_inputs(self, history_lengths: torch.Tensor, seq_lengths: torch.Tensor): class ModelInputsDelta: """Delta of ModelInputs.""" # valid indices - indices: Optional[torch.Tensor] + indices: torch.Tensor | None # new block offsets block_offsets: torch.Tensor # cpu copy of indices @@ -135,7 +135,7 @@ class ModelInputsDelta: sum_kv_seqlen: int is_decoding: bool = True # sliding window - num_ignored_history: Optional[torch.Tensor] = None + num_ignored_history: torch.Tensor | None = None @property def seq_length(self): @@ -182,18 +182,21 @@ class ModelInputs: max_q_seqlen: int max_kv_seqlen: int sum_kv_seqlen: int - local_adapter_ids: torch.Tensor = None - vision_inputs: VisionModelInputs = None - model_metas: List[Dict[str, Any]] = None - dp_meta: 'DPMeta' = None + local_adapter_ids: torch.Tensor | None = None + vision_inputs: VisionModelInputs | None = None + model_metas: list[dict[str, Any]] | None = None + dp_meta: DPMeta | None = None enable_microbatch: bool = False is_dummy: bool = False - state_offsets: torch.Tensor = None - target_hidden_states: torch.Tensor = None - target_position_ids: torch.Tensor = None + state_offsets: torch.Tensor | None = None + target_hidden_states: torch.Tensor | None = None + target_position_ids: torch.Tensor | None = None is_chunk: bool = False is_first_chunk: bool = True + # mrope, shape(3, sum_seqlens) + mrope_pos_ids: torch.Tensor | None = None + def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): """Update input ids.""" assert self.is_decoding @@ -205,6 +208,9 @@ def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): if input_ids.dim() == 1: input_ids = input_ids[None, :] self.input_ids = input_ids + + if self.mrope_pos_ids is not None: + self.mrope_pos_ids = self.mrope_pos_ids + step_seqlens[None] return self @torch.inference_mode() @@ -222,7 +228,7 @@ def to_device(self, device: str, non_blocking: bool = 
False): return ModelInputs(**out_dict) - def build_dp_meta(self, num_tokens: List[int]): + def build_dp_meta(self, num_tokens: list[int]): """Build dp meta.""" self.dp_meta = DPMeta.build(self.input_ids.numel(), num_tokens) @@ -248,28 +254,31 @@ class StepContext: q_seqlens: torch.LongTensor kv_seqlens: torch.IntTensor q_start_loc: torch.LongTensor - kv_caches: List + kv_caches: list is_decoding: bool sum_kv_seqlen: int - max_kv_seqlen: int = None - local_adapter_ids: torch.LongTensor = None - input_embeddings: torch.Tensor = None - input_embedding_indexing: torch.Tensor = None - input_multimodals: List[MultiModalTensor] = None - vision_inputs: VisionModelInputs = None + max_kv_seqlen: int | None = None + local_adapter_ids: torch.LongTensor | None = None + input_embeddings: torch.Tensor | None = None + input_embedding_indexing: torch.Tensor | None = None + input_multimodals: list[MultiModalData] | None = None + vision_inputs: VisionModelInputs | None = None attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 - model_metas: List[Dict[str, Any]] = None - dp_meta: DPMeta = None + model_metas: list[dict[str, Any]] | None = None + dp_meta: DPMeta | None = None enable_microbatch: bool = False # for draft model - target_hidden_states: torch.Tensor = None + target_hidden_states: torch.Tensor | None = None # states for ssm - state_caches: List = None - state_offsets: torch.LongTensor = None + state_caches: list | None = None + state_offsets: torch.LongTensor | None = None + + # mrope + mrope_position_ids: torch.Tensor | None = None - _outputs: Dict = field(default_factory=dict) + _outputs: dict = field(default_factory=dict) @classmethod def new( @@ -277,8 +286,8 @@ def new( inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List = None, - state_caches: List = None, + kv_caches: list | None = None, + state_caches: list | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build step context. 
@@ -334,6 +343,7 @@ def new( state_caches=state_caches, state_offsets=inputs.state_offsets, target_hidden_states=inputs.target_hidden_states, + mrope_position_ids=inputs.mrope_pos_ids, ) ret = get_backend().update_step_context(ret) @@ -408,8 +418,8 @@ def build_context( inputs: ModelInputs, model_config: ModelConfig, cache_config: CacheConfig, - kv_caches: List = None, - state_caches: List = None, + kv_caches: list | None = None, + state_caches: list | None = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): """Build context.""" diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 56e3169bb7..b690217564 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -9,7 +9,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding, build_rotary_params) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj, @@ -866,10 +866,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index ad8adc9739..9b8d0b5472 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.distributed 
import get_tp_world_rank from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) @@ -901,10 +901,10 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/deepseek_vl2.py b/lmdeploy/pytorch/models/deepseek_vl2.py index 290b9a4fc0..b778c6ebeb 100644 --- a/lmdeploy/pytorch/models/deepseek_vl2.py +++ b/lmdeploy/pytorch/models/deepseek_vl2.py @@ -11,7 +11,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .deepseek_v2 import DeepseekV2ForCausalLM @@ -440,13 +440,13 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict( - image_token_id=image_token_id, - 
images_spatial_crop=images_spatial_crop, - )) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_token_id=image_token_id, + images_spatial_crop=images_spatial_crop, + )) input_imgs.append(mm_data) diff --git a/lmdeploy/pytorch/models/gemma3_vl.py b/lmdeploy/pytorch/models/gemma3_vl.py index 8f4ea8e972..cff9615df2 100644 --- a/lmdeploy/pytorch/models/gemma3_vl.py +++ b/lmdeploy/pytorch/models/gemma3_vl.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config @@ -108,10 +108,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/glm4_1v.py b/lmdeploy/pytorch/models/glm4_1v.py index 9b89164bef..dbefc845d5 100644 --- a/lmdeploy/pytorch/models/glm4_1v.py +++ b/lmdeploy/pytorch/models/glm4_1v.py @@ -2,22 +2,22 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.engine.input_process import 
BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .glm4 import Glm4DecoderLayer -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .qwen2_vl import Qwen2VLInputProcessor as Glm4vInputProcessor +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -717,162 +717,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = 
mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - 
images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor - - -class Glm4vInputProcessor(BaseModelInputProcessor): - """Glm4v input processor.""" - - def __init__(self, config: PretrainedConfig) -> None: - self.config = config - - def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, - **kwargs) -> PreprocessInputResult: - """Prepare multimodal input.""" - if input_multimodals is None or len(input_multimodals) == 0: - return input_ids, input_multimodals - - input_imgs = [] - for input_mm in input_multimodals: - pixel_values = input_mm['pixel_values'] - 
image_grid_thw = input_mm['image_grid_thw'] - offset = input_mm['offset'] - start = offset - image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() - - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) - input_imgs.append(mm_data) - - result = PreprocessInputResult( - input_ids=input_ids, - input_multimodals=dict(image=input_imgs), - ) - return result diff --git a/lmdeploy/pytorch/models/interns1_pro.py b/lmdeploy/pytorch/models/interns1_pro.py index 51ed9deaf6..77f7b57f93 100644 --- a/lmdeploy/pytorch/models/interns1_pro.py +++ b/lmdeploy/pytorch/models/interns1_pro.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .interns1_pro_ts import InternS1ProTimeSeriesModel @@ -383,10 +383,10 @@ def preprocess_input(self, ts_sr = input_mm['ts_sr'] num_pad = input_mm['num_ts_tokens'] - mm_data = MultiModalTensor(data=ts_values, - start=offset, - end=offset + num_pad, - meta=dict(ts_token_id=ts_token_id, ts_lens=ts_lens, ts_sr=ts_sr)) + mm_data = MultiModalData(data=ts_values, + start=offset, + end=offset + num_pad, + meta=dict(ts_token_id=ts_token_id, ts_lens=ts_lens, ts_sr=ts_sr)) else: pixel_values = input_mm['pixel_values'].to(self.dtype) image_grid_thw = input_mm['image_grid_thw'] @@ -397,10 +397,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, 
image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 5b6c261dd2..43b80644f6 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -992,10 +992,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py index 7cd4cd940c..4ea2eb2f45 100644 --- a/lmdeploy/pytorch/models/internvl3_hf.py +++ b/lmdeploy/pytorch/models/internvl3_hf.py @@ -13,7 +13,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.utils.micro_batch import 
enable_micro_batch, split_batch -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -736,10 +736,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/llama4.py b/lmdeploy/pytorch/models/llama4.py index 4b3c2196bc..e7711b83d3 100644 --- a/lmdeploy/pytorch/models/llama4.py +++ b/lmdeploy/pytorch/models/llama4.py @@ -8,7 +8,7 @@ import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) @@ -1033,10 +1033,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + 
meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index e87242df4c..4004441050 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -11,7 +11,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -555,10 +555,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( @@ -834,10 +834,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py index c6804d5586..aff3f78935 100644 --- a/lmdeploy/pytorch/models/phi3_v.py +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, 
PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .phi3 import Phi3ForCausalLM, Phi3Model @@ -379,10 +379,10 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=offset, - end=offset + num_pad, - meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) + mm_data = MultiModalData(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py index 9c19a7de21..7a3e3b73d6 100644 --- a/lmdeploy/pytorch/models/qwen2_5_vl.py +++ b/lmdeploy/pytorch/models/qwen2_5_vl.py @@ -2,23 +2,23 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.models.qwen2_vl import Qwen2Model -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.models.qwen2_vl import Qwen2VLInputProcessor as Qwen2_5_VLInputProcessor from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul from 
lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import add_prefix -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -411,7 +411,11 @@ def __init__(self, # build model self.model = Qwen2Model(text_config, dtype=dtype, device=device) # build lm_head - self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device) + self.lm_head = self.build_lm_head(text_config.hidden_size, + text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, @@ -564,165 +568,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - 
new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if 
model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor - - -InputMultiModalType = List[Dict[str, Any]] - - -class Qwen2_5_VLInputProcessor(BaseModelInputProcessor): - """Qwen2 input processor.""" - - def __init__(self, config: PretrainedConfig) -> None: - self.config = config - - def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, - **kwargs) -> PreprocessInputResult: - """Prepare multimodal input.""" - if input_multimodals is None or len(input_multimodals) == 0: - return input_ids, input_multimodals - - input_imgs = [] - for input_mm in input_multimodals: - pixel_values = 
input_mm['pixel_values'] - image_grid_thw = input_mm['image_grid_thw'] - offset = input_mm['offset'] - start = offset - image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() - - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) - input_imgs.append(mm_data) - - result = PreprocessInputResult( - input_ids=input_ids, - input_multimodals=dict(image=input_imgs), - ) - return result diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 605f8ded76..78fdae83d1 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,21 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple +import numpy as np import torch from torch import nn from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, LayerNorm, RMSNorm, SiluAndMul, build_rotary_embedding_from_config) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, build_embedding, vlm_model @@ -632,7 +633,11 @@ def __init__(self, # build 
model self.model = Qwen2Model(text_config, dtype=dtype, device=device) # build lm_head - self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device) + self.lm_head = self.build_lm_head(text_config.hidden_size, + text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, @@ -767,124 +772,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return 
[dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= 
num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor @@ -899,6 +786,23 @@ class Qwen2VLInputProcessor(BaseModelInputProcessor): def __init__(self, config: PretrainedConfig) -> None: self.config = config + @staticmethod + def _get_multimodal_pos_ids(grid_thw: Sequence[int]) -> np.ndarray: + """Get mrope ids.""" + t, h, w = grid_thw + h = h // 2 + w = w // 2 + stride = np.array([h * w, w, 1])[None] + size = np.array([t, h, w])[None] + pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1) + pos_ids = pos_ids // stride % size + return pos_ids + + @staticmethod + def make_mrope(grid_thw: torch.Tensor, ): + img_pos_ids = Qwen2VLInputProcessor._get_multimodal_pos_ids(grid_thw[0].tolist()) + return img_pos_ids + def preprocess_input(self, input_ids: List[int], input_multimodals: List[Dict[str, Any]] = None, @@ -918,10 +822,13 @@ def preprocess_input(self, if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, 
image_token_id=image_token_id)) + mrope_pos_ids = self.make_mrope(image_grid_thw) + + mm_data = MultiModalData(data=pixel_values, + start=start, + end=start + num_pad, + mrope_pos_ids=mrope_pos_ids, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/qwen3_5.py b/lmdeploy/pytorch/models/qwen3_5.py index 59f373f075..1bb186c06d 100644 --- a/lmdeploy/pytorch/models/qwen3_5.py +++ b/lmdeploy/pytorch/models/qwen3_5.py @@ -25,7 +25,7 @@ from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3_5InputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3_5VisionAttention -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -1254,134 +1254,6 @@ def __skip_layers(name): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - - max_batchs = graph_meta.max_batchs - device = graph_meta.device - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device) - input_buffers['state_ids'] = state_ids - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - input_buffers = graph_meta.input_buffers - new_inputs = super().fill_buffers_cudagraph(graph_meta, *args, **kwargs) - state_ids = kwargs['state_ids'] - 
input_buffers['state_ids'].fill_(-1) - input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) - new_inputs['state_ids'] = input_buffers['state_ids'] - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return [dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas_cpu = torch.tensor(mrope_deltas, device='cpu') - if (mrope_deltas_cpu == mrope_deltas_cpu[0]).all(): - mrope_deltas = position_ids.new_full((len(mrope_deltas), ), mrope_deltas[0]) - else: - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h 
* w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end - img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, past_key_values: List[List[torch.Tensor]], inputs_embeds: torch.Tensor | None, - context: StepContext): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def 
get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py index 4c56c01aa3..7117b2ad38 100644 --- a/lmdeploy/pytorch/models/qwen3_next.py +++ b/lmdeploy/pytorch/models/qwen3_next.py @@ -18,7 +18,7 @@ from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, build_embedding @@ -694,29 +694,6 @@ def prepare_inputs_for_generation( state_ids=context.state_offsets, ) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_batchs = graph_meta.max_batchs - device = graph_meta.device - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device) - input_buffers['state_ids'] = state_ids - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - input_buffers = graph_meta.input_buffers - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - state_ids = kwargs['state_ids'] - input_buffers['state_ids'].fill_(-1) - input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) - new_inputs['state_ids'] = input_buffers['state_ids'] - - return new_inputs - def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], expert_params_mapping: List): """Load weight experts.""" diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index a6f694c6f2..323d757e95 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ 
b/lmdeploy/pytorch/models/qwen3_vl.py @@ -21,7 +21,7 @@ from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3VLInputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention from .qwen3 import Qwen3model -from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixinV1, vlm_model @@ -718,124 +718,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] load_weight(param, loaded_weight) - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Make cudagraph buffers from forward inputs.""" - max_tokens = graph_meta.max_tokens - - input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) - - return input_buffers - - def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): - """Fill cudagraph buffers from forward inputs.""" - - new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) - - input_ids = kwargs.get('input_ids') - num_tokens = input_ids.size(-1) - new_batch_size = graph_meta.max_batchs - - is_decoding = graph_meta.is_decoding - input_buffers = graph_meta.input_buffers - mrope_position_ids = kwargs.get('mrope_position_ids', None) - if mrope_position_ids is not None: - input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids - if is_decoding: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] - else: - new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] - - return new_inputs - - def _get_model_metas(self, context: StepContext): - """Get model metas.""" - model_metas = context.model_metas - if model_metas is None: - batch_size = context.q_seqlens.numel() - return 
[dict(mrope_delta=0)] * batch_size - return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] - - def _update_model_meta_decoding(self, context: StepContext): - """Update model meta for decoding.""" - model_metas = self._get_model_metas(context) - position_ids = context.position_ids - - mrope_deltas = [meta['mrope_delta'] for meta in model_metas] - mrope_deltas = position_ids.new_tensor(mrope_deltas) - mrope_position_ids = position_ids + mrope_deltas[None] - mrope_position_ids = mrope_position_ids.expand(3, -1) - - context.mrope_position_ids = mrope_position_ids - return model_metas - - def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): - """Get mrope ids.""" - t, h, w = grid_thw - h //= 2 - w //= 2 - stride = torch.tensor([h * w, w, 1], device=device)[:, None] - size = torch.tensor([t, h, w], device=device)[:, None] - pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) - pos_ids = pos_ids // stride % size - return pos_ids - - def _update_model_meta_prefilling(self, context: StepContext): - """Update model meta for prefilling.""" - model_metas = self._get_model_metas(context) - input_multimodals = context.input_multimodals - if input_multimodals is None: - input_multimodals = [None] * len(model_metas) - position_ids = context.position_ids - batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) - mrope_position_ids = [] - new_model_metas = [] - for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): - images = [] - if input_mm is not None: - images = input_mm.get('image', []) - if model_meta is None or 'mrope_delta' not in model_meta: - mrope_delta = 0 - else: - mrope_delta = model_meta['mrope_delta'] - - pos_start = pos_ids[0].item() - mrope_pos_ids = pos_ids + mrope_delta - mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() - for img in images: - grid_thw = img.meta['grid_thw'][0].tolist() - _, h, w = grid_thw - h //= 2 - w //= 2 - num_pad = img.end 
- img.start - max(h, w) - mrope_delta -= num_pad - fill_start = img.start - pos_start - fill_end = img.end - pos_start - img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) - img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] - mrope_pos_ids[:, fill_end:] -= num_pad - mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids - - mrope_position_ids.append(mrope_pos_ids) - new_model_metas.append(dict(mrope_delta=mrope_delta)) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=1) - context.mrope_position_ids = mrope_position_ids - - return new_model_metas - - def update_model_metas(self, - past_key_values: List[List[torch.Tensor]], - inputs_embeds: Optional[torch.Tensor] = None, - context: StepContext = None): - """Update model meta.""" - if context.is_decoding: - return self._update_model_meta_decoding(context) - else: - return self._update_model_meta_prefilling(context) - def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 2b5a4dc8ad..f4c7394cd4 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any import torch from torch import Tensor @@ -8,7 +8,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager -BuffType = Dict[str, Tensor] +BuffType = dict[str, Tensor] def _get_meta_flashattn( @@ -21,9 +21,9 @@ def _get_meta_flashattn( cache_seqlens: torch.Tensor, qkv_dtype=torch.bfloat16, headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k_new: torch.Tensor | None = None, + page_size: int | None = None, causal=True, window_size=(-1, -1), # -1 means infinite context window num_splits=0, @@ -77,9 +77,11 @@ class CudaGraphMeta: vocab_size: int = 1 use_mla_fp8_cache: bool = False use_flash_mla: bool = False - mla_index_topk: Optional[int] = None + mla_index_topk: int | None = None decode_query_len: int = 1 use_fa3_decoding: bool = False + is_ssm: bool = False + use_mrope: bool = False class CudaGraphMixin: @@ -89,7 +91,7 @@ def support_cuda_graph( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - past_key_values: List[List[torch.Tensor]], + past_key_values: list[list[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, **kwargs, @@ -102,7 +104,7 @@ def make_output_buffers(self, output): if isinstance(output, torch.Tensor): output_buffers = dict(hidden_states=output) else: - assert isinstance(output, Dict) + assert isinstance(output, dict) output_buffers = output return output_buffers @@ -138,7 +140,8 @@ def update_meta_flashattn(self, graph_meta: CudaGraphMeta, block_size: int, max_ ) return scheduler_metadata - def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: List, **kwargs) -> BuffType: + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: list[list[torch.Tensor]], + **kwargs) -> 
BuffType: """Make cudagraph buffers from forward inputs.""" max_batches = graph_meta.max_batchs max_tokens = graph_meta.max_tokens @@ -190,12 +193,21 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_valu max_seqlen_k=decode_query_len, cache_seqlens=input_buffers['kv_seqlens']) + # mrope + if graph_meta.use_mrope: + input_buffers['mrope_position_ids'] = torch.zeros(3, max_tokens, dtype=torch.int64, device=device) + + # ssm + if graph_meta.is_ssm: + state_ids = torch.full((max_batches, ), -1, dtype=torch.int64, device=device) + input_buffers['state_ids'] = state_ids + return input_buffers @record_function('fill_buffers_cudagraph') def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor, - past_key_values: List, attn_metadata: Any, inputs_embeds: Tensor, - **kwargs) -> Dict[str, Tensor]: + past_key_values: list[list[torch.Tensor]], attn_metadata: Any, inputs_embeds: Tensor, + **kwargs) -> dict[str, Tensor]: """Fill cudagraph buffers from forward inputs.""" block_offsets: Tensor = attn_metadata.block_offsets @@ -269,6 +281,7 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p past_key_values=past_key_values, attn_metadata=attn_metadata, ) + new_inputs.update(kwargs) new_inputs['input_ids'] = input_buffers['input_ids'] new_inputs['position_ids'] = input_buffers['position_ids'] @@ -276,7 +289,20 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p if inputs_embeds is not None: new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] - new_inputs.update(kwargs) + # mrope + if graph_meta.use_mrope: + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] + + # ssm + if graph_meta.is_ssm: + state_ids = kwargs['state_ids'] + 
input_buffers['state_ids'].fill_(-1) + input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids) + new_inputs['state_ids'] = input_buffers['state_ids'] + return new_inputs def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepContext): @@ -293,7 +319,15 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte context.kv_seqlens = input_buffers['kv_seqlens'] context.q_start_loc = input_buffers['q_start_loc'] - def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs): + # mrope + if graph_meta.use_mrope: + context.mrope_position_ids = input_buffers['mrope_position_ids'] + + # ssm + if graph_meta.is_ssm: + context.state_offsets = input_buffers['state_ids'] + + def get_outputs_cudagraph(self, output_buffers: dict[str, torch.Tensor], input_ids: Tensor, **kwargs): """Get outputs from buffers.""" num_tokens = input_ids.size(-1) outputs = dict() diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py index c3e8c6a16f..fc2d5890d9 100644 --- a/lmdeploy/pytorch/multimodal/__init__.py +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .data_type import MultiModalData, MultiModalTensor +from .data_type import MultiModalData -__all__ = ['MultiModalData', 'MultiModalTensor'] +__all__ = ['MultiModalData'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index dd3ec9a37d..34c545ce0e 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -1,26 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass, fields -from typing import Any, Dict, List, Union +from typing import Any +import numpy as np from torch import Tensor - -class MultiModalData: - pass - - -MultiModalDataList = List[MultiModalData] - -NestedTensor = Union[Tensor, List[Tensor]] +NestedTensor = Tensor | list[Tensor] @dataclass -class MultiModalTensor: +class MultiModalData: data: NestedTensor start: int - end: int = None - encoder_len: int = None - meta: Dict[str, Any] = None + end: int | None = None + encoder_len: int | None = None + meta: dict[str, Any] | None = None + + # for qwen-vl + mrope_pos_ids: np.ndarray | None = None def __post_init__(self): if self.end is None: @@ -53,7 +50,7 @@ def to_device(self, device: str, non_blocking: bool = False): new_meta[k] = v out_dict['meta'] = new_meta - return MultiModalTensor(**out_dict) + return MultiModalData(**out_dict) -MultiModalInputs = Dict[str, List[MultiModalTensor]] +MultiModalInputs = dict[str, list[MultiModalData]] diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py deleted file mode 100644 index 19211a381f..0000000000 --- a/lmdeploy/pytorch/multimodal/image_type.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from dataclasses import dataclass -from typing import Any, ClassVar, Dict - -from PIL import Image - -from .data_type import MultiModalData - - -@dataclass -class ImageData(MultiModalData): - data: Image - loc: int - meta: Dict[str, Any] = None - type: ClassVar[str] = 'image' diff --git a/lmdeploy/pytorch/nn/gated_delta.py b/lmdeploy/pytorch/nn/gated_delta.py index c61dcab6b5..01038f889f 100644 --- a/lmdeploy/pytorch/nn/gated_delta.py +++ b/lmdeploy/pytorch/nn/gated_delta.py @@ -76,19 +76,16 @@ def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, state_ids = gated_delta_meta.state_ids assert x.dim() == 3 - x = x.transpose(-2, -1) if weight.dim() == 3: assert weight.size(1) == 1 weight = weight[:, 0] # fill conv state - # TODO: find efficient way to fill conv state without gather + scatter - final_state = conv_state.index_select(0, state_ids) - batch_size = conv_state.size(0) - conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1) - torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=final_state) + final_state = x[0, conv_idx].transpose(-2, -1) conv_state = conv_state.index_copy_(0, state_ids, final_state) + # note that we have not set init states + x = x.transpose(-2, -1) out = self.causal_conv1d_fn( x, weight, diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 9208b7cdf2..c9611a0cba 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -4,7 +4,6 @@ from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass -from typing import Dict, List from torch.profiler import record_function @@ -20,8 +19,8 @@ logger = get_logger('lmdeploy') -MapType = Dict[int, int] -SeqList = List[SchedulerSequence] +MapType = dict[int, int] +SeqList = list[SchedulerSequence] @dataclass @@ -46,14 +45,14 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - seq_meta: SequenceMeta = None, + 
seq_meta: SequenceMeta | None = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.sessions: Dict[int, SchedulerSession] = OrderedDict() + self.sessions: dict[int, SchedulerSession] = OrderedDict() # For Disaggregation - self.locked_sessions: Dict[int, SchedulerSession] = OrderedDict() + self.locked_sessions: dict[int, SchedulerSession] = OrderedDict() self.block_manager = build_block_manager(cache_config) self.block_trie = BlockTrie(self.cache_config, self.block_manager) diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 51739d05d5..af99df1d63 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -45,6 +45,9 @@ def __init__( self.cache_config = specdecode_config.cache_config self.num_spec_tokens = specdecode_config.num_speculative_tokens + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(self.model_config) + def set_cache_config(self, cache_config: CacheConfig): """Set all cache config.""" self.cache_config = cache_config @@ -52,6 +55,9 @@ def set_cache_config(self, cache_config: CacheConfig): def set_model_config(self, model_config: ModelConfig): """Set model config.""" self.model_config = model_config + if model_config is not None: + # make dummy meta + self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(self.model_config) def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None): """Build draft model.""" @@ -194,7 +200,8 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): device='cuda', vocab_size=self.model_config.vocab_size, target_hidden_size=target_hidden_size, - target_dtype=self.model_config.dtype) + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) @@ -203,26 +210,24 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): for batch_size in 
capture_batch_sizes: # decode with num_spec_tokens + 1 per seq - inputs = self.inputs_strategy.make_dummy( - batch_size, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size, - max_q_seqlen=self.num_spec_tokens + 1, - target_hidden_size=target_hidden_size, - target_dtype=self.model_config.dtype, - ) + inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=self.num_spec_tokens + 1, + target_hidden_size=target_hidden_size, + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) # decode 1 tokens per sequence - inputs = self.inputs_strategy.make_dummy( - batch_size, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size, - max_q_seqlen=1, - target_hidden_size=self.model_config.hidden_size, - target_dtype=self.model_config.dtype, - ) + inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=1, + target_hidden_size=self.model_config.hidden_size, + target_dtype=self.model_config.dtype, + meta=self.make_dummy_meta) self._forward_impl(inputs) def reset_graph_runner(self): diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 9c7abb5887..5678a11aab 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -25,6 +25,12 @@ def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, state_offsets = inputs.state_offsets if state_offsets is not None: state_offsets = state_offsets.clone() + + # mrope + mrope_pos_ids = inputs.mrope_pos_ids + if mrope_pos_ids is not None: + index = inputs.seq_length.cumsum(0) - 1 + mrope_pos_ids = mrope_pos_ids[:, index] + 1 return ModelInputs( input_ids=input_ids, seq_length=torch.full_like(inputs.seq_length, max_q_seqlen), @@ -38,6 +44,7 @@ 
def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, local_adapter_ids=inputs.local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py index 7c1910311a..1fcff049e0 100644 --- a/lmdeploy/pytorch/strategies/ar/model_inputs.py +++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs def merge_model_inputs(inputs: ModelInputs, other: ModelInputs) -> ModelInputs: @@ -51,6 +51,12 @@ def __try_pad_block_offsets(block_offsets: torch.Tensor, target_size: int): if inputs.state_offsets is not None: state_offsets = torch.cat([inputs.state_offsets, other.state_offsets], dim=0) + # mrope + mrope_pos_ids = None + if inputs.mrope_pos_ids is not None: + assert other.mrope_pos_ids is not None + mrope_pos_ids = torch.cat([inputs.mrope_pos_ids, other.mrope_pos_ids], dim=1) + return ModelInputs( input_ids=input_ids, seq_length=seq_length, @@ -64,6 +70,7 @@ def __try_pad_block_offsets(block_offsets: torch.Tensor, target_size: int): local_adapter_ids=local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) @@ -74,14 +81,16 @@ def make_dummy(self, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Create dummy model inputs.""" return make_dummy_inputs(batch_size, max_q_seqlen=1, is_decoding=is_decoding, device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) @record_function('ModelInputs.merge') def merge(self, inputs: ModelInputs, other: 
ModelInputs) -> ModelInputs: @@ -140,6 +149,11 @@ def index_select(inputs: ModelInputs, if target_position_ids is not None: target_position_ids = target_position_ids[indices] + # mrope + mrope_pos_ids = inputs.mrope_pos_ids + if mrope_pos_ids is not None: + mrope_pos_ids = mrope_pos_ids[:, indices] + # return new inputs return ModelInputs( input_ids=input_ids, @@ -156,6 +170,7 @@ def index_select(inputs: ModelInputs, state_offsets=state_offsets, target_hidden_states=target_hidden_states, target_position_ids=target_position_ids, + mrope_pos_ids=mrope_pos_ids, ) @record_function('ModelInputs.update_inputs') diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index b9b277f961..39754da69f 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -57,6 +57,8 @@ def update_token_ids(self, if model_meta is not None: self.model_meta = model_meta + self._update_mrope_pos_ids() + def set_step(self, step: int): """Set step.""" num_all_ids = self.num_all_ids diff --git a/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py index b8ffc94352..aecf6caf05 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py @@ -5,7 +5,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta from ..ar.model_inputs import merge_model_inputs -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs class ARSpecModelInputsStrategy(ModelInputsStrategy): @@ -23,6 +23,7 @@ def make_dummy( max_q_seqlen: int = 1, target_hidden_size: int = None, target_dtype: torch.dtype = torch.bfloat16, + meta: MakeDummyMeta | None = None, ) -> ModelInputs: """Create dummy model inputs.""" inputs = make_dummy_inputs(batch_size, @@ -30,7 +31,8 @@ def make_dummy( is_decoding=is_decoding, 
device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) if target_hidden_size is not None: inputs.target_hidden_states = torch.randn((1, batch_size * max_q_seqlen, target_hidden_size), dtype=target_dtype, diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py index 7089bce3d0..828310dd3b 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/sequence.py +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -136,6 +136,8 @@ def update_token_ids(self, if model_meta is not None: self.model_meta = model_meta + self._update_mrope_pos_ids() + class ARSpecSequenceStrategy(ARSequenceStrategy): diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py index 8c18420e0b..1f06772735 100644 --- a/lmdeploy/pytorch/strategies/base/model_inputs.py +++ b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -1,20 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import ABC, abstractmethod +from dataclasses import dataclass import torch from torch.profiler import record_function +from lmdeploy.pytorch.config import ModelConfig from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta +@dataclass +class MakeDummyMeta: + """Make dummy meta for model inputs strategy.""" + # Add any fields needed for making dummy inputs + use_ssm: bool = False + use_mrope: bool = False + + @record_function('make_dummy_input') def make_dummy_inputs(batch_size: int, max_q_seqlen: int, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1): + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Make dummy inputs global implement.""" + if meta is None: + meta = MakeDummyMeta() num_tokens = batch_size * max_q_seqlen max_kv_seqlen = max_q_seqlen input_ids = torch.randint(0, vocab_size, ( @@ -26,7 +39,14 @@ def make_dummy_inputs(batch_size: int, block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) - state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device) + + state_offsets = None + if meta.use_ssm: + state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device) + + mrope_pos_ids = None + if meta.use_mrope: + mrope_pos_ids = torch.zeros(3, num_tokens, dtype=torch.long, device=device) return ModelInputs( input_ids=input_ids, @@ -41,18 +61,27 @@ def make_dummy_inputs(batch_size: int, local_adapter_ids=local_adapter_ids, is_dummy=True, state_offsets=state_offsets, + mrope_pos_ids=mrope_pos_ids, ) class ModelInputsStrategy(ABC): + def create_make_dummy_meta(self, model_config: ModelConfig): + """Create make dummy meta.""" + return MakeDummyMeta( + use_ssm=len(model_config.states_shapes) > 0, + use_mrope=model_config.use_mrope, + ) + 
@abstractmethod def make_dummy(self, batch_size: int, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + vocab_size: int = 1, + meta: MakeDummyMeta | None = None) -> ModelInputs: """Create dummy model inputs.""" pass diff --git a/lmdeploy/pytorch/strategies/dllm/model_inputs.py b/lmdeploy/pytorch/strategies/dllm/model_inputs.py index 151b952d0a..6cf291360f 100644 --- a/lmdeploy/pytorch/strategies/dllm/model_inputs.py +++ b/lmdeploy/pytorch/strategies/dllm/model_inputs.py @@ -2,7 +2,7 @@ from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta from ..ar.model_inputs import merge_model_inputs -from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs +from ..base.model_inputs import MakeDummyMeta, ModelInputsStrategy, make_dummy_inputs class DLLMModelInputsStrategy(ModelInputsStrategy): @@ -10,19 +10,23 @@ class DLLMModelInputsStrategy(ModelInputsStrategy): def __init__(self, block_size: int): self.block_size = block_size - def make_dummy(self, - batch_size: int, - is_decoding: bool, - device: str = 'cpu', - dummy_block_id: int = 0, - vocab_size: int = 1) -> ModelInputs: + def make_dummy( + self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1, + meta: MakeDummyMeta | None = None, + ) -> ModelInputs: """Create dummy model inputs.""" return make_dummy_inputs(batch_size, max_q_seqlen=self.block_size, is_decoding=is_decoding, device=device, dummy_block_id=dummy_block_id, - vocab_size=vocab_size) + vocab_size=vocab_size, + meta=meta) def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs: """Merge model inputs."""