[Fix][Feat] Fix worker sorting with external pg bundles & Support persistent buffer for update_params #4397
Conversation
Pull request overview
This PR improves RLHF/distributed weight-update workflows by (1) preserving Ray worker/bundle ordering when users provide external placement group bundle indices and (2) reducing IPC cloning overhead for frequent update_params calls via a persistent IPC tensor/event, supported by a preallocated flattened-tensor bucket path.
Changes:
- Ray: skip IP-based worker sorting and preserve user-provided bundle index order when `LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES` is set.
- PyTorch agent: add persistent IPC tensor/event handling to avoid per-update cloning when possible.
- Utils: extend `FlattenedTensorBucket` and its serialization to support optional/preallocated flattened buffers (zero/low-copy paths).
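For readers unfamiliar with the flattened-bucket idea, here is a minimal, self-contained sketch of the concept; plain Python lists stand in for tensors, and the names `FlattenedMeta`, `flatten_named`, and `reconstruct` are illustrative, not lmdeploy's actual API:

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class FlattenedMeta:
    """Where one named tensor lives inside the flat buffer."""
    name: str
    start: int
    numel: int

def flatten_named(named: List[Tuple[str, List[float]]]) -> Tuple[List[float], List[FlattenedMeta]]:
    """Concatenate tensors into one flat buffer and record per-tensor offsets."""
    flat: List[float] = []
    metas: List[FlattenedMeta] = []
    for name, values in named:
        metas.append(FlattenedMeta(name, start=len(flat), numel=len(values)))
        flat.extend(values)
    return flat, metas

def reconstruct(flat: List[float], metas: List[FlattenedMeta]) -> Dict[str, List[float]]:
    """Slice each tensor back out of the flat buffer (views, in the real code)."""
    return {m.name: flat[m.start:m.start + m.numel] for m in metas}

flat, metas = flatten_named([('w', [1.0, 2.0]), ('b', [3.0])])
assert reconstruct(flat, metas) == {'w': [1.0, 2.0], 'b': [3.0]}
```

Shipping only the metadata plus one persistent flat buffer (shared via CUDA IPC) is what lets frequent `update_params` calls skip per-call cloning.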
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `lmdeploy/utils.py` | Makes flattened bucket serialization more flexible and adds a preallocated buffer path for concatenation. |
| `lmdeploy/pytorch/engine/model_agent/agent.py` | Adds persistent IPC tensor/event support in `update_params` and cleanup during sleep/finalization. |
| `lmdeploy/pytorch/engine/executor/ray_executor.py` | Adjusts worker sorting and bundle index selection for externally provided Ray PG bundle indices. |
Comments suppressed due to low confidence (1)
lmdeploy/pytorch/engine/model_agent/agent.py:1183
- The comment says `request.serialized_named_tensors is now a dict ...`, but the code still treats `request.serialized_named_tensors` as base64-encoded bytes, and the deserialized object (`weights`) is the dict. Please reword this to avoid misleading API users (especially since `UpdateParamsRequest.serialized_named_tensors` is typed as `Union[str, List[str], Dict]`).
```python
# request.serialized_named_tensors is now a dict with following keys:
# - metadata: List[FlattenedTensorMetadata]
# - flattened_tensor: the flattened tensor for weights, optional
# - event_ipc_handle: the ipc handle of the event
#   that used to sync stream across processes, optional
serialized_data = request.serialized_named_tensors
if isinstance(serialized_data, list):
    serialized_data = serialized_data[self.dist_ctx.tp_group.rank]
model = self.patched_model.get_model()
weights = ForkingPickler.loads(pybase64.b64decode(serialized_data))
```
lmdeploy/utils.py (Outdated)
```python
assert flattened_tensor.numel() >= current_idx, \
    'Provided flattened tensor numel is smaller than ' + \
    f'required numel: {flattened_tensor.numel()} < {current_idx}'
assert sum([t.numel() for t in flattened_tensor_list]) == current_idx
```
These new branches add non-trivial behavior (optional flattened_tensor in serialization and the preallocated-buffer out= cat path), but there are no unit tests covering them. Since this module already has tests (tests/test_lmdeploy/test_utils.py), adding focused tests around this initialization path would help prevent regressions.
Suggested change:

```python
assert sum([t.numel() for t in flattened_tensor_list]) == current_idx
# Validate that the provided preallocated buffer is compatible with inputs.
# All named_tensors are already verified to share the same dtype.
first_tensor = named_tensors[0][1] if named_tensors else None
if first_tensor is not None:
    if flattened_tensor.dtype != first_tensor.dtype:
        raise ValueError(
            f'flattened_tensor dtype {flattened_tensor.dtype} does not match '
            f'input tensors dtype {first_tensor.dtype}')
    if flattened_tensor.device != first_tensor.device:
        raise ValueError(
            f'flattened_tensor device {flattened_tensor.device} does not match '
            f'input tensors device {first_tensor.device}')
    if not flattened_tensor.is_contiguous():
        raise ValueError('flattened_tensor must be contiguous when used as an output buffer')
```
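A focused test along the lines the review suggests could look like this sketch. A minimal stand-in type replaces `torch.Tensor` so the validation logic is shown in isolation; `FakeTensor` and `check_prealloc_buffer` are illustrative names, not part of lmdeploy or its test suite:

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Minimal stand-in for torch.Tensor: just the attributes the checks need."""
    numel_: int
    dtype: str = 'float16'
    device: str = 'cuda:0'
    contiguous: bool = True

    def numel(self) -> int:
        return self.numel_

def check_prealloc_buffer(buf: FakeTensor, ref: FakeTensor, required_numel: int) -> None:
    """Mirror of the suggested ValueError-based checks for a preallocated buffer."""
    if buf.numel() < required_numel:
        raise ValueError(f'buffer too small: {buf.numel()} < {required_numel}')
    if buf.dtype != ref.dtype:
        raise ValueError(f'dtype mismatch: {buf.dtype} != {ref.dtype}')
    if buf.device != ref.device:
        raise ValueError(f'device mismatch: {buf.device} != {ref.device}')
    if not buf.contiguous:
        raise ValueError('buffer must be contiguous')

ref = FakeTensor(numel_=4)
check_prealloc_buffer(FakeTensor(numel_=8), ref, required_numel=4)  # compatible: no error

try:
    check_prealloc_buffer(FakeTensor(numel_=2), ref, required_numel=4)
except ValueError as err:
    assert 'too small' in str(err)
```

In a real test against `FlattenedTensorBucket`, the same cases (undersized, wrong-dtype, wrong-device, non-contiguous buffers) would each be asserted to raise with a clear message.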
```python
if not _envs.ray_external_pg_bundles:
    for bundle_id, bundle in enumerate(placement_group.bundle_specs):
        if bundle.get(device_str, 0):
            bundle_indices.append(bundle_id)
else:
    # use external specified bundle indices,keep the order as well
    bundle_indices = _envs.ray_external_pg_bundles.copy()
attn_tp = self.dist_config.attn_tp
bundle_indices = bundle_indices[:attn_tp]
```
When _envs.ray_external_pg_bundles is set, bundle_indices is taken verbatim with no validation. It would be safer to validate indices are in-range for placement_group.bundle_specs, that each selected bundle actually has the required device_str resource, and that len(bundle_indices) >= attn_tp (or raise a clear error). Otherwise Ray may fail later with a less actionable scheduling error or create fewer workers than required.
```python
# use external specified bundle indices,keep the order as well
bundle_indices = _envs.ray_external_pg_bundles.copy()
```
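One way to add the validation suggested above, sketched with plain Python data (`bundle_specs` as a list of resource dicts; `validate_external_bundles` is a hypothetical helper, not existing lmdeploy code):

```python
from typing import Dict, List

def validate_external_bundles(bundle_indices: List[int],
                              bundle_specs: List[Dict[str, float]],
                              device_str: str,
                              attn_tp: int) -> List[int]:
    """Check user-supplied placement-group bundle indices before handing them to Ray."""
    for idx in bundle_indices:
        if not 0 <= idx < len(bundle_specs):
            raise ValueError(f'bundle index {idx} out of range [0, {len(bundle_specs)})')
        if not bundle_specs[idx].get(device_str, 0):
            raise ValueError(f'bundle {idx} has no "{device_str}" resource')
    if len(bundle_indices) < attn_tp:
        raise ValueError(f'need at least attn_tp={attn_tp} bundle indices, '
                         f'got {len(bundle_indices)}')
    # Preserve the user-provided order, truncated to attn_tp entries.
    return bundle_indices[:attn_tp]

specs = [{'GPU': 1}, {'GPU': 1}, {'CPU': 1}]
assert validate_external_bundles([1, 0], specs, 'GPU', attn_tp=2) == [1, 0]
```

Failing early here turns an opaque Ray scheduling error into a message that names the bad index or the missing resource.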
PR description mentions an env var RAY_EXTERNAL_PG_BUNDLES, but the code reads LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES via _envs.ray_external_pg_bundles. Please align the PR description / documentation (or add an alias) to avoid users setting the wrong variable and not getting the intended behavior.
@copilot I updated the PR description, please review it again.
```python
if 'flattened_tensor' in weights:
    self._update_params_ipc_tensor = _construct(weights['flattened_tensor'],
                                                require_clone=weights.get('require_clone', True))
if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'):
    self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle(
        device=torch.cuda.current_device(),
        handle=weights['event_ipc_handle'],
    )
flattened_tensor: torch.Tensor = self._update_params_ipc_tensor
if self._update_params_ipc_event is not None:
    self._update_params_ipc_event.wait()
bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
weights = bucket.reconstruct_tensors()
```
When request.load_format == 'flattened_bucket' and metadata is non-empty, this code will pass flattened_tensor=None into FlattenedTensorBucket if (a) the request omits flattened_tensor and (b) _update_params_ipc_tensor has not been initialized from a prior call. That will raise in FlattenedTensorBucket.__init__ and makes the “flattened_tensor optional” behavior fragile; consider explicitly erroring with a clear message unless a cached _update_params_ipc_tensor already exists (or require flattened_tensor on the first call).
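A fail-fast resolution step along the lines of this comment might look like the following sketch; `resolve_flattened_tensor` is a hypothetical helper, and plain objects stand in for `torch.Tensor`:

```python
from typing import Optional

def resolve_flattened_tensor(request_tensor: Optional[object],
                             cached_ipc_tensor: Optional[object]) -> object:
    """Prefer the tensor shipped in the request; otherwise fall back to the
    persistent IPC tensor cached by a previous update_params call."""
    if request_tensor is not None:
        return request_tensor
    if cached_ipc_tensor is not None:
        return cached_ipc_tensor
    raise RuntimeError(
        'flattened_bucket update received no flattened_tensor and no persistent '
        'IPC tensor is cached; the first update_params call must include '
        'flattened_tensor so the shared buffer can be established')

assert resolve_flattened_tensor('from_request', None) == 'from_request'
assert resolve_flattened_tensor(None, 'cached') == 'cached'
```

This turns the implicit `None` path into an explicit error that tells the caller exactly which precondition was violated.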
```python
if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'):
    self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle(
        device=torch.cuda.current_device(),
        handle=weights['event_ipc_handle'],
    )
flattened_tensor: torch.Tensor = self._update_params_ipc_tensor
if self._update_params_ipc_event is not None:
    self._update_params_ipc_event.wait()
```
If the producer supplies event_ipc_handle but the local PyTorch build lacks torch.cuda.Event.from_ipc_handle, this silently skips cross-process stream synchronization and proceeds to read from the IPC tensor. That can lead to stale/partially-written weights; consider failing fast (or forcing require_clone=True / torch.cuda.synchronize() fallback) when an event handle is provided but cannot be imported.
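The fallback policy suggested here can be sketched as a small decision function (all names are illustrative, and the "safe fallback" options follow the review's own suggestions rather than a verified lmdeploy behavior):

```python
from enum import Enum

class SyncStrategy(Enum):
    IPC_EVENT = 'ipc_event'  # wait on the producer's imported CUDA event
    FULL_SYNC = 'full_sync'  # conservative fallback suggested in the review
    NONE = 'none'            # no handle supplied, nothing to wait for

def pick_sync_strategy(has_event_handle: bool,
                       event_import_supported: bool,
                       allow_fallback: bool = True) -> SyncStrategy:
    """Decide how to synchronize before reading the shared IPC tensor."""
    if not has_event_handle:
        return SyncStrategy.NONE
    if event_import_supported:
        return SyncStrategy.IPC_EVENT
    if allow_fallback:
        # Cannot import the event on this PyTorch build: fall back to a
        # heavier synchronization (or force require_clone=True) per the review.
        return SyncStrategy.FULL_SYNC
    raise RuntimeError('event_ipc_handle provided but torch.cuda.Event.from_ipc_handle '
                       'is unavailable in this PyTorch build')

assert pick_sync_strategy(True, False) is SyncStrategy.FULL_SYNC
```

The key point is that "handle provided but unusable" is handled explicitly instead of silently degenerating into no synchronization at all.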
lmdeploy/utils.py (Outdated)
```python
assert len(flattened_tensor.shape) == 1, 'flattened_tensor must be 1-D tensor'
assert flattened_tensor.numel() >= current_idx, \
    'Provided flattened tensor numel is smaller than ' + \
    f'required numel: {flattened_tensor.numel()} < {current_idx}'
assert sum([t.numel() for t in flattened_tensor_list]) == current_idx
```
In the preallocated-buffer path, input validation uses assert and only checks 1-D and numel(). Since this is user-provided input (and asserts can be stripped with -O), it would be safer to raise ValueError with a clear message and also validate dtype/device (and ideally contiguity) match the source tensors to avoid cryptic torch.cat(..., out=...) runtime errors.
Suggested change (replacing the assert-based checks quoted above):

```python
# Validate user-provided preallocated buffer.
if flattened_tensor.dim() != 1:
    raise ValueError(
        f'flattened_tensor must be a 1-D tensor, but got shape {tuple(flattened_tensor.shape)}')
if flattened_tensor.numel() < current_idx:
    raise ValueError(
        'Provided flattened tensor numel is smaller than required numel: '
        f'{flattened_tensor.numel()} < {current_idx}')
# Ensure dtype and device are compatible with source tensors.
reference_tensor = named_tensors[0][1]
if flattened_tensor.dtype != reference_tensor.dtype:
    raise ValueError(
        f'flattened_tensor dtype {flattened_tensor.dtype} does not match source '
        f'tensors dtype {reference_tensor.dtype}')
if flattened_tensor.device != reference_tensor.device:
    raise ValueError(
        f'flattened_tensor device {flattened_tensor.device} does not match source '
        f'tensors device {reference_tensor.device}')
if not flattened_tensor.is_contiguous():
    raise ValueError('flattened_tensor must be contiguous')
total_numel = sum(t.numel() for t in flattened_tensor_list)
if total_numel != current_idx:
    raise ValueError(
        'Mismatch between computed and expected flattened size: '
        f'{total_numel} != {current_idx}')
```
…e require_clone when IPC event unavailable
Motivation
This PR addresses two issues related to distributed training parameter update in RLHF scenarios:
Worker Sorting Issue: When using external placement group bundle indices (e.g., in specific Ray cluster configurations), workers might be incorrectly sorted by IP address, which broke the intended worker-to-bundle mapping.

Memory Copy Overhead: The `update_params` interface used for RL training required cloning tensors for IPC (Inter-Process Communication) on every parameter update, causing unnecessary memory overhead and synchronization costs.

Modification
ray_executor.py:
- Update `_sort_workers` to skip IP-based sorting when external bundle indices are specified via the `LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES` env var
- Adjust the `_valid_bundle_id` method

agent.py:
- Add a persistent IPC tensor (`_update_params_ipc_tensor`) and CUDA event (`_update_params_ipc_event`) for efficient parameter updates

utils.py:
- Extend `FlattenedTensorBucket` to support an optional pre-allocated flattened tensor buffer
- Make the `flattened_tensor` field optional in serialization to support zero-copy scenarios

BC-breaking (Optional)
No BC-breaking changes. The modifications are backward compatible.
Use cases (Optional)
This PR optimizes RL training workflows where:
Checklist