Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 73 additions & 43 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,44 @@ def _construct(item, require_clone: bool = True):
ipc_tensor = func(*args)
return ipc_tensor.clone() if require_clone else ipc_tensor

def _deserialize_weights(serialized_data):
    """Decode the base64+pickled weight payload into (name, tensor) pairs.

    Two load formats are supported:
    - 'flattened_bucket': reconstruct individual tensors from a single
      flattened IPC tensor plus per-tensor metadata. The IPC tensor and the
      optional CUDA IPC event are cached on ``self`` so follow-up calls may
      omit ``flattened_tensor``.
    - default: materialize each entry independently via ``_construct``.

    Returns an empty list when the bucket metadata is empty.
    """
    raw = ForkingPickler.loads(pybase64.b64decode(serialized_data))
    if request.load_format == 'flattened_bucket':
        metadata: list[FlattenedTensorMetadata] = raw['metadata']
        if not metadata:
            return []
        # NOTE: all payload lookups below must use `raw`; the enclosing
        # scope's `weights` is not assigned until after this helper returns.
        if 'flattened_tensor' in raw:
            # Determine if clone is required
            require_clone = raw.get('require_clone', True)
            if 'event_ipc_handle' in raw and not hasattr(torch.cuda.Event, 'from_ipc_handle'):
                # Force clone when IPC event is provided but cannot be used
                require_clone = True
            self._update_params_ipc_tensor = _construct(raw['flattened_tensor'],
                                                        require_clone=require_clone)
        elif self._update_params_ipc_tensor is None:
            raise ValueError(
                'flattened_tensor is not provided in weights and no cached ipc tensor is available. '
                'Please provide flattened_tensor on the first update_params call.')
        if 'event_ipc_handle' in raw and hasattr(torch.cuda.Event, 'from_ipc_handle'):
            self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle(
                device=torch.cuda.current_device(),
                handle=raw['event_ipc_handle'],
            )
        flattened_tensor: torch.Tensor = self._update_params_ipc_tensor
        if self._update_params_ipc_event is not None:
            # Ensure the producer side has finished writing before we read.
            self._update_params_ipc_event.wait()
        bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
        return list(bucket.reconstruct_tensors())
    return [(k, _construct(v)) for k, v in raw]

def _split_main_and_draft(weights):
    """Split (name, weight) pairs into main-model and draft-model lists.

    Draft (MTP) weights are identified by the 'mtp.' name prefix. The split
    is only performed for the qwen3_5_mtp speculative method; in all other
    cases every weight is routed to the main model and the draft list is
    empty.
    """
    # TODO(zhouxinyu): support split and update weights for other mtp methods
    if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp':
        return weights, []
    main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')]
    # NOTE(review): draft names keep the 'mtp.' prefix here; confirm the
    # spec model's parameter names include that prefix, otherwise it must
    # be stripped before load_weights indexes its params_dict.
    draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')]
    return main, draft

with self.all_context():
# After deserialization, weights is a dict with following keys:
# - metadata: List[FlattenedTensorMetadata]
Expand All @@ -1186,52 +1224,33 @@ def _construct(item, require_clone: bool = True):
serialized_data = request.serialized_named_tensors
if isinstance(serialized_data, list):
serialized_data = serialized_data[self.dist_ctx.tp_group.rank]

model = self.patched_model.get_model()
weights = ForkingPickler.loads(pybase64.b64decode(serialized_data))
if request.load_format == 'flattened_bucket':
metadata: list[FlattenedTensorMetadata] = weights['metadata']
if metadata:
if 'flattened_tensor' in weights:
# Determine if clone is required
require_clone = weights.get('require_clone', True)
if 'event_ipc_handle' in weights and not hasattr(torch.cuda.Event, 'from_ipc_handle'):
# Force clone when IPC event is provided but cannot be used
require_clone = True
self._update_params_ipc_tensor = _construct(weights['flattened_tensor'],
require_clone=require_clone)
elif self._update_params_ipc_tensor is None:
raise ValueError(
'flattened_tensor is not provided in weights and no cached ipc tensor is available. '
'Please provide flattened_tensor on the first update_params call.')
if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'):
self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle(
device=torch.cuda.current_device(),
handle=weights['event_ipc_handle'],
)
flattened_tensor: torch.Tensor = self._update_params_ipc_tensor
if self._update_params_ipc_event is not None:
self._update_params_ipc_event.wait()
bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
weights = bucket.reconstruct_tensors()
else:
# empty data
weights = []
else:
weights = [(k, _construct(v)) for k, v in weights]
spec_model = self.spec_agent.get_model()

weights = ModelWeightLoader._rename_weights_iterator(weights, model)
model.load_weights(weights)
if self._update_params_ipc_event is not None:
self._update_params_ipc_event.record()
weights = _deserialize_weights(serialized_data)
main_weights, draft_weights = _split_main_and_draft(weights)

for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]:
if m is None or not w:
continue

w = list(ModelWeightLoader._rename_weights_iterator(w, m))
logger.info(f'Update_params: {tag}_num_tensors={len(w)}')
m.load_weights(iter(w))

if self._update_params_ipc_event is not None:
self._update_params_ipc_event.record()

if request.finished:
for _, mod in model.named_modules():
if not hasattr(mod, 'update_weights'):
continue
mod.update_weights()
torch.cuda.synchronize()
self._update_params_ipc_event = None
self._update_params_ipc_tensor = None
for m in filter(None, [model, spec_model]):
for _, mod in m.named_modules():
if hasattr(mod, 'update_weights'):
mod.update_weights()

torch.cuda.synchronize()
self._update_params_ipc_event = None
self._update_params_ipc_tensor = None

torch.cuda.empty_cache()

Expand All @@ -1241,11 +1260,17 @@ async def sleep(self, level: int = 1):
self.state.is_sleeping = True
if self.dist_config.dp > 1:
await self.state.to_sleep.wait()
device = 'cpu' if level == 1 else 'meta'
self.cache_engine = None
self.state_cache_engine = None
self.reset_graph_runner()
device = 'cpu' if level == 1 else 'meta'
self.patched_model.get_model().to(device=device, non_blocking=True)

spec_model = self.spec_agent.get_model()
if spec_model is not None:
self.spec_agent.cache_engine = None
spec_model.to(device=device, non_blocking=True)

torch.cuda.synchronize()
# force clean _update_params_ipc tensor and event after all gpu jobs done
self._update_params_ipc_tensor = None
Expand All @@ -1258,11 +1283,16 @@ def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
if tags is None:
tags = ['weights', 'kv_cache']

if 'weights' in tags:
device = next(self.patched_model.get_model().parameters()).device
assert device.type in ['cpu', 'meta']
spec_model = self.spec_agent.get_model()

if device.type == 'cpu':
self.patched_model.get_model().to(torch.cuda.current_device())
if spec_model is not None:
spec_model.to(torch.cuda.current_device())
else:
# user should update weights after wakeup
old_empty_init = self.misc_config.empty_init
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def update_params(self, request: Any):
"""Update params."""
return self._collective_rpc('update_params', request)

async def get_schedule_metrics(self):
    """Get schedule metrics.

    Issued as an async collective RPC so the caller's event loop is not
    blocked while the engine process responds.
    """
    return await self._collective_rpc_async('get_schedule_metrics')

def p2p_initialize(self, conn_request: DistServeInitRequest):
"""Init rdma link."""
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/spec_decode/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ def update_main_model_outputs(self, output: dict[str, torch.Tensor], model_input
# replace with aux
output['hidden_states'] = output.pop('aux_hidden_states')
return hidden_states, output

def get_model(self):
    """Get model.

    The base speculative-decoding agent owns no model of its own, so this
    default implementation returns None; subclasses override it.
    """
    return None
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/spec_decode/spec_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig):
self._forward_impl(inputs)

def reset_graph_runner(self):
    """Reset graph runner.

    Delegates to the proposer's model when it exists and exposes reset();
    safe to call before the draft model has been built.
    """
    if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'):
        self.proposer.model.reset()

def get_model(self):
    """Get model.

    Returns the underlying model of the draft proposer, or None when the
    proposer model has not been built yet — mirroring the None guard used
    by reset_graph_runner and the base-class default, so callers can
    uniformly check for None.
    """
    model = self.proposer.model
    return None if model is None else model.get_model()
7 changes: 5 additions & 2 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@ def _build_stat_loggers(self):
# set stats loggers of metrics processor
metrics_processor.stat_loggers = self.stat_loggers

async def get_schedule_metrics(self):
    """Get schedule metrics, awaiting when the engine returns a coroutine.

    Supports both synchronous (in-process) and asynchronous (mp-engine)
    implementations of the underlying engine's get_schedule_metrics.
    """
    result = self.engine.get_schedule_metrics()
    if asyncio.iscoroutine(result):
        return await result
    return result

async def do_log_stats(self):
"""Loop through CLI logger and Prometheus logger and output the
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ async def _force_log():
await asyncio.sleep(log_interval)

# periodically update schedule metrics, as they change less frequently than iteration stats
schedule_metrics = async_engine.get_schedule_metrics()
schedule_metrics = await async_engine.get_schedule_metrics()
await metrics_processor.update_schedule_stats(schedule_metrics)

await async_engine.do_log_stats()
Expand Down
Loading