feat(conversion): support distributed adapter export#4221
Conversation
Light Code ReviewFindings
Suggested test cases No perf tests impacted. |
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
50e8cc2 to
121ef0c
Compare
|
/claude review |
|
|
||
| def _export_adapter_distributed(args: argparse.Namespace) -> None: | ||
| device = _configure_cuda_device() | ||
| ckpt_path = Path(args.lora_checkpoint).expanduser().resolve() |
There was a problem hiding this comment.
Bug (minor): The CPU path in export_adapter_ckpt (auto_bridge.py:1433-1434) validates that the checkpoint path exists before doing any work:
if not ckpt_path.exists():
raise FileNotFoundError(f"PEFT checkpoint not found: {ckpt_path}")This distributed path skips that check, so a typo'd path would surface as a cryptic dist_checkpointing.load error instead of a clear FileNotFoundError.
| def _export_adapter_distributed(args: argparse.Namespace) -> None: | |
| device = _configure_cuda_device() | |
| ckpt_path = Path(args.lora_checkpoint).expanduser().resolve() | |
| def _export_adapter_distributed(args: argparse.Namespace) -> None: | |
| device = _configure_cuda_device() | |
| ckpt_path = Path(args.lora_checkpoint).expanduser().resolve() | |
| if not ckpt_path.exists(): | |
| raise FileNotFoundError(f"PEFT checkpoint not found: {ckpt_path}") | |
| config = AutoConfig.from_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code) |
|
|
||
| sharded_state_dict = _generate_model_state_dict(model, {}) | ||
| sharded_state_dict = apply_peft_adapter_filter_to_state_dict(sharded_state_dict, lora) | ||
| loaded_sd = dist_checkpointing.load(sharded_state_dict, str(ckpt_path)) | ||
| model_key = _get_loaded_model_key(loaded_sd, ckpt_path) | ||
| model[0].load_state_dict(loaded_sd[model_key], strict=False) |
There was a problem hiding this comment.
Bug: When PP > 1 with virtual pipeline parallelism, provide_distributed_model returns multiple model chunks. _get_loaded_model_key finds only the first key (e.g. "model0"), and model[0].load_state_dict(...) only loads into chunk 0 — subsequent chunks stay uninitialized.
If PP > 1 export isn't expected to be used yet, consider adding a guard:
| sharded_state_dict = _generate_model_state_dict(model, {}) | |
| sharded_state_dict = apply_peft_adapter_filter_to_state_dict(sharded_state_dict, lora) | |
| loaded_sd = dist_checkpointing.load(sharded_state_dict, str(ckpt_path)) | |
| model_key = _get_loaded_model_key(loaded_sd, ckpt_path) | |
| model[0].load_state_dict(loaded_sd[model_key], strict=False) | |
| sharded_state_dict = _generate_model_state_dict(model, {}) | |
| sharded_state_dict = apply_peft_adapter_filter_to_state_dict(sharded_state_dict, lora) | |
| loaded_sd = dist_checkpointing.load(sharded_state_dict, str(ckpt_path)) | |
| model_key = _get_loaded_model_key(loaded_sd, ckpt_path) | |
| for chunk in model: | |
| chunk.load_state_dict(loaded_sd[model_key], strict=False) |
Or if only one chunk is valid, assert len(model) == 1 to catch misuse early.
|
test |
Code ReviewFindings1. Missing checkpoint path validation in distributed export ( The CPU path in 2. When 3. Duplicated dtype alias map
Test coverage gaps
Suggested test casesNo perf tests impacted. |
Signed-off-by: Chen Cui <chcui@nvidia.com>
Summary
Blast Radius / Test Assessment
Validation
uv run --no-sync pre-commit run --files examples/conversion/adapter/export_adapter.py src/megatron/bridge/models/conversion/auto_bridge.py src/megatron/bridge/models/conversion/model_bridge.py src/megatron/bridge/models/conversion/peft_bridge.py tests/unit_tests/models/test_adapter_export.py tests/unit_tests/models/test_model_bridge_lora.py examples/conversion/adapter/README.mduv run --no-sync python -m py_compile examples/conversion/adapter/export_adapter.py src/megatron/bridge/models/conversion/auto_bridge.py src/megatron/bridge/models/conversion/model_bridge.py src/megatron/bridge/models/conversion/peft_bridge.py tests/unit_tests/models/test_adapter_export.py tests/unit_tests/models/test_model_bridge_lora.pyFocused pytest was attempted with
uv run python -m pytest tests/unit_tests/models/test_adapter_export.py tests/unit_tests/models/test_model_bridge_lora.py -q, but local dependency resolution failed before test collection becausenvidia-resiliency-ext==0.6.0has no compatible wheel for this host platform.