-
Notifications
You must be signed in to change notification settings - Fork 678
Draft model update params #4452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6f553a8
e944997
7a48257
b745d7f
36090b8
5e9f548
41d816a
7847174
5febdff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1177,6 +1177,44 @@ def _construct(item, require_clone: bool = True): | |||||||||||
| ipc_tensor = func(*args) | ||||||||||||
| return ipc_tensor.clone() if require_clone else ipc_tensor | ||||||||||||
|
|
||||||||||||
| def _deserialize_weights(serialized_data): | ||||||||||||
| raw = ForkingPickler.loads(pybase64.b64decode(serialized_data)) | ||||||||||||
| if request.load_format == 'flattened_bucket': | ||||||||||||
| metadata: list[FlattenedTensorMetadata] = raw['metadata'] | ||||||||||||
| if not metadata: | ||||||||||||
| return [] | ||||||||||||
| if 'flattened_tensor' in weights: | ||||||||||||
| # Determine if clone is required | ||||||||||||
| require_clone = weights.get('require_clone', True) | ||||||||||||
| if 'event_ipc_handle' in weights and not hasattr(torch.cuda.Event, 'from_ipc_handle'): | ||||||||||||
| # Force clone when IPC event is provided but cannot be used | ||||||||||||
| require_clone = True | ||||||||||||
| self._update_params_ipc_tensor = _construct(weights['flattened_tensor'], | ||||||||||||
| require_clone=require_clone) | ||||||||||||
| elif self._update_params_ipc_tensor is None: | ||||||||||||
| raise ValueError( | ||||||||||||
| 'flattened_tensor is not provided in weights and no cached ipc tensor is available. ' | ||||||||||||
| 'Please provide flattened_tensor on the first update_params call.') | ||||||||||||
| if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'): | ||||||||||||
| self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle( | ||||||||||||
| device=torch.cuda.current_device(), | ||||||||||||
| handle=weights['event_ipc_handle'], | ||||||||||||
| ) | ||||||||||||
| flattened_tensor: torch.Tensor = self._update_params_ipc_tensor | ||||||||||||
| if self._update_params_ipc_event is not None: | ||||||||||||
| self._update_params_ipc_event.wait() | ||||||||||||
| bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata) | ||||||||||||
| return list(bucket.reconstruct_tensors()) | ||||||||||||
| return [(k, _construct(v)) for k, v in raw] | ||||||||||||
|
|
||||||||||||
| def _split_main_and_draft(weights): | ||||||||||||
| # TODO, zhouxinyu, support split and update weights for other mtp methods | ||||||||||||
| if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp': | ||||||||||||
| return weights, [] | ||||||||||||
| main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')] | ||||||||||||
| draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')] | ||||||||||||
|
||||||||||||
| draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')] | |
| # For the draft (spec) model, strip the outer "mtp." prefix from parameter names | |
| draft = [(name[len('mtp.'):], weight) | |
| for name, weight in weights | |
| if name.startswith('mtp.')] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a TODO or a warning message here.