From 1bc2ce758ea1432036cf2f592b7c05690738377c Mon Sep 17 00:00:00 2001 From: mieshkiwrk Date: Tue, 10 Sep 2024 10:06:11 +0200 Subject: [PATCH 01/14] Add compile_fn for Trainer --- src/lightning/pytorch/trainer/trainer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 406f686efe732..2d5b680d58789 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -25,7 +25,7 @@ import os from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable from weakref import proxy import torch @@ -127,6 +127,7 @@ def __init__( sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, + compile_fn: Optional[Callable] = None ) -> None: r"""Customize every aspect of training via flags. @@ -468,6 +469,8 @@ def __init__( self.should_stop = False self.state = TrainerState() + self.compile_fn = compile_fn + # configure profiler setup._init_profiler(self, profiler) @@ -956,6 +959,10 @@ def _run( # strategy will configure model and move it to the device self.strategy.setup(self) + # compile if compile_fn provided after configured strategy + if self.compile_fn is not None: + self.strategy.model = self.compile_fn(self.strategy.model) + # hook if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_start") From e26132a0329a0d7cb0d060f0abe318d84a0ea0c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 08:17:36 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 2d5b680d58789..bab0dce90beb3 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -25,7 +25,7 @@ import os from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union from weakref import proxy import torch @@ -127,7 +127,7 @@ def __init__( sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, - compile_fn: Optional[Callable] = None + compile_fn: Optional[Callable] = None, ) -> None: r"""Customize every aspect of training via flags. 
@@ -470,7 +470,7 @@ def __init__( self.state = TrainerState() self.compile_fn = compile_fn - + # configure profiler setup._init_profiler(self, profiler) @@ -962,7 +962,7 @@ def _run( # compile if compile_fn provided after configured strategy if self.compile_fn is not None: self.strategy.model = self.compile_fn(self.strategy.model) - + # hook if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_start") From 925c3763e9f8b8de4b624b98190f8386a74cef23 Mon Sep 17 00:00:00 2001 From: mieshkiwrk Date: Tue, 10 Sep 2024 10:26:05 +0200 Subject: [PATCH 03/14] Add parameter description --- src/lightning/pytorch/trainer/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index bab0dce90beb3..bbd31a94ee00a 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -290,6 +290,9 @@ def __init__( Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + compile_fn: Provide torch.compile function to be applied after configuring strategy + Default: ``None``. + Raises: TypeError: If ``gradient_clip_val`` is not an int or float. From 86d2c702ea530c0543225c76a6ea8feb9e0f3411 Mon Sep 17 00:00:00 2001 From: Mieszko Dziadowiec Date: Wed, 27 Nov 2024 13:07:14 +0100 Subject: [PATCH 04/14] Test reapply_compile for trainer --- src/lightning/pytorch/trainer/trainer.py | 31 ++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index bbd31a94ee00a..1b3326771f4d4 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -25,7 +25,7 @@ import os from contextlib import contextmanager from datetime import timedelta -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Union from weakref import proxy import torch @@ -79,6 +79,10 @@ LRSchedulerConfig, ) from lightning.pytorch.utilities.warnings import PossibleUserWarning +from lightning.fabric.wrappers import ( + _unwrap_compiled, + _to_compiled +) log = logging.getLogger(__name__) @@ -127,7 +131,7 @@ def __init__( sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, - compile_fn: Optional[Callable] = None, + reapply_compile = False ) -> None: r"""Customize every aspect of training via flags. @@ -290,9 +294,6 @@ def __init__( Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' - compile_fn: Provide torch.compile function to be applied after configuring strategy - Default: ``None``. - Raises: TypeError: If ``gradient_clip_val`` is not an int or float. @@ -307,6 +308,8 @@ def __init__( if default_root_dir is not None: default_root_dir = os.fspath(default_root_dir) + self._reapply_compile = reapply_compile + self.barebones = barebones if barebones: # opt-outs @@ -472,8 +475,6 @@ def __init__( self.should_stop = False self.state = TrainerState() - self.compile_fn = compile_fn - # configure profiler setup._init_profiler(self, profiler) @@ -535,19 +536,20 @@ def fit( For more information about multiple dataloaders, see this :ref:`section `. 
""" - model = _maybe_unwrap_optimized(model) + model, compile_kwargs = _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None) self.strategy._lightning_module = model _verify_strategy_supports_compile(model, self.strategy) self.state.fn = TrainerFn.FITTING self.state.status = TrainerStatus.RUNNING self.training = True call._call_and_handle_interrupt( - self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) def _fit_impl( self, model: "pl.LightningModule", + compile_kwargs, train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -577,7 +579,7 @@ def _fit_impl( model_provided=True, model_connected=self.lightning_module is not None, ) - self._run(model, ckpt_path=ckpt_path) + self._run(model, compile_kwargs, ckpt_path=ckpt_path) assert self.state.stopped self.training = False @@ -908,7 +910,7 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None + self, model: "pl.LightningModule", compile_kwargs, ckpt_path: Optional[_PATH] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -962,10 +964,9 @@ def _run( # strategy will configure model and move it to the device self.strategy.setup(self) - # compile if compile_fn provided after configured strategy - if self.compile_fn is not None: - self.strategy.model = self.compile_fn(self.strategy.model) - + if compile_kwargs is not None: + self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs) + # hook if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_start") From 8db1a6fd7b51a9475e4514a1327b44e9f0e6ea71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:09:00 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/trainer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 1b3326771f4d4..ceaa36d151520 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -35,6 +35,7 @@ from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.types import _PATH +from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar from lightning.pytorch.core.datamodule import LightningDataModule @@ -79,10 +80,6 @@ LRSchedulerConfig, ) from lightning.pytorch.utilities.warnings import PossibleUserWarning -from lightning.fabric.wrappers import ( - _unwrap_compiled, - _to_compiled -) log = logging.getLogger(__name__) @@ -131,7 +128,7 @@ def __init__( sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, - reapply_compile = False + reapply_compile=False, ) -> None: r"""Customize every aspect of 
training via flags. @@ -309,7 +306,7 @@ def __init__( default_root_dir = os.fspath(default_root_dir) self._reapply_compile = reapply_compile - + self.barebones = barebones if barebones: # opt-outs @@ -536,7 +533,9 @@ def fit( For more information about multiple dataloaders, see this :ref:`section `. """ - model, compile_kwargs = _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None) + model, compile_kwargs = ( + _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None) + ) self.strategy._lightning_module = model _verify_strategy_supports_compile(model, self.strategy) self.state.fn = TrainerFn.FITTING @@ -966,7 +965,7 @@ def _run( if compile_kwargs is not None: self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs) - + # hook if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_start") From f0f0a57536d344f153dc3203400c8d42a1e9c59f Mon Sep 17 00:00:00 2001 From: Mieszko Dziadowiec Date: Mon, 2 Dec 2024 11:49:35 +0100 Subject: [PATCH 06/14] Remove reapply_compile flag --- src/lightning/pytorch/trainer/trainer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index ceaa36d151520..c16a9e24a42c7 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -127,8 +127,7 @@ def __init__( plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, - default_root_dir: Optional[_PATH] = None, - reapply_compile=False, + default_root_dir: Optional[_PATH] = None ) -> None: r"""Customize every aspect of training via flags. @@ -305,8 +304,6 @@ def __init__( if default_root_dir is not None: default_root_dir = os.fspath(default_root_dir) - self._reapply_compile = reapply_compile - self.barebones = barebones if barebones: # opt-outs @@ -533,8 +530,9 @@ def fit( For more information about multiple dataloaders, see this :ref:`section `. 
""" + # when provided compiled model, unwrap and re-do after applied strategy model, compile_kwargs = ( - _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None) + _unwrap_compiled(model) if isinstance(model, torch._dynamo.OptimizedModule) else (_maybe_unwrap_optimized(model), None) ) self.strategy._lightning_module = model _verify_strategy_supports_compile(model, self.strategy) @@ -548,7 +546,7 @@ def fit( def _fit_impl( self, model: "pl.LightningModule", - compile_kwargs, + compile_kwargs: Optional[Dict[str, Any]] = None, train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -909,7 +907,7 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", compile_kwargs, ckpt_path: Optional[_PATH] = None + self, model: "pl.LightningModule", compile_kwargs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[_PATH] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -963,6 +961,7 @@ def _run( # strategy will configure model and move it to the device self.strategy.setup(self) + # when provided compiled model, unwrap is done in fit method, re-apply compile after applying strategy if compile_kwargs is not None: self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs) From 2c498f5023c9cec2dcf89cc94e0814d33fe58088 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 10:49:56 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index c16a9e24a42c7..1302d30eb8080 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -127,7 +127,7 @@ def __init__( plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, - default_root_dir: Optional[_PATH] = None + default_root_dir: Optional[_PATH] = None, ) -> None: r"""Customize every aspect of training via flags. 
@@ -532,7 +532,9 @@ def fit(
         """
         # when provided compiled model, unwrap and re-do after applied strategy
         model, compile_kwargs = (
-            _unwrap_compiled(model) if isinstance(model, torch._dynamo.OptimizedModule) else (_maybe_unwrap_optimized(model), None)
+            _unwrap_compiled(model)
+            if isinstance(model, torch._dynamo.OptimizedModule)
+            else (_maybe_unwrap_optimized(model), None)
         )
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
@@ -907,7 +909,10 @@ def _predict_impl(
         return results
 
     def _run(
-        self, model: "pl.LightningModule", compile_kwargs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[_PATH] = None
+        self,
+        model: "pl.LightningModule",
+        compile_kwargs: Optional[Dict[str, Any]] = None,
+        ckpt_path: Optional[_PATH] = None,
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(

From bbcaad15eb29771a85587d2c08d201cbab9fdcdd Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Mon, 2 Dec 2024 12:12:22 +0100
Subject: [PATCH 08/14] Dict -> dict

---
 src/lightning/pytorch/trainer/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 1302d30eb8080..8935abe6d2088 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -548,7 +548,7 @@ def fit(
     def _fit_impl(
         self,
         model: "pl.LightningModule",
-        compile_kwargs: Optional[Dict[str, Any]] = None,
+        compile_kwargs: Optional[dict[str, Any]] = None,
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
@@ -911,7 +911,7 @@ def _predict_impl(
     def _run(
         self,
         model: "pl.LightningModule",
-        compile_kwargs: Optional[Dict[str, Any]] = None,
+        compile_kwargs: Optional[dict[str, Any]] = None,
         ckpt_path: Optional[_PATH] = None,
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:

From 809c6c43f5b28ec3877f540336c221c5429c98a9 Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Fri, 10 Jan 2025 11:49:44 +0100
Subject: [PATCH 09/14] Test trainer rewrap compiled module over DDP strategy

---
 .../strategies/test_ddp_integration.py | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 836072d36be83..676234a07701b 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -18,6 +18,7 @@
 import lightning.pytorch as pl
 import pytest
 import torch
+from torch._dynamo import OptimizedModule
 from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment
 from lightning.fabric.utilities.distributed import _distributed_is_initialized
 from lightning.pytorch import Trainer
@@ -448,3 +449,31 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile(tmp_path):
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
+    # Assert we called compile again with the same arguments, but on the DDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cpu")).sum().backward()
+    assert True

From c74bdab5d4c91f2a9f62316892217af143d9573b Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Fri, 10 Jan 2025 12:15:23 +0100
Subject: [PATCH 10/14] Run DDP test_reapply_compile on gpu

---
 tests/tests_pytorch/strategies/test_ddp_integration.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 118da4e7e9bb5..47e869f425b42 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -451,12 +451,12 @@ def creates_processes_externally(self):
     trainer.fit(model)
 
 
-@RunIf(dynamo=True)
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
 @mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
 @mock.patch.dict(os.environ, {})
 def test_reapply_compile(tmp_path):
     """Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
-    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp", max_steps=2, logger=False)
+    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
 
     model = BoringModel()
@@ -475,5 +475,5 @@ def test_reapply_compile(tmp_path):
 
     # Smoke-testing forward to ensure we don't get compilation errors
     for _ in range(3):
-        trainer_model(torch.randn(2, 32, device="cpu")).sum().backward()
+        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
     assert True

From 9bc5774586750d9b528d20ea70e889e2b99a0f03 Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Fri, 10 Jan 2025 12:16:32 +0100
Subject: [PATCH 11/14] Add test for reapply_compile with FSDP on gpu

---
 tests/tests_pytorch/strategies/test_fsdp.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f3e88ca356764..14a2e76b55ef7 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch._dynamo import OptimizedModule
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
 from torchmetrics import Accuracy
@@ -971,3 +972,30 @@ def configure_optimizers(self):
         max_steps=4,
     )
     trainer.fit(model, ckpt_path=checkpoint_path_full)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
+    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
+    # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()

From 87c13776937c540b3e190a9859bbb87ece6598be Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Fri, 10 Jan 2025 12:17:11 +0100
Subject: [PATCH 12/14] Update test_ddp_integration.py

---
 tests/tests_pytorch/strategies/test_ddp_integration.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 47e869f425b42..23e83e95b1d7b 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -476,4 +476,3 @@ def test_reapply_compile(tmp_path):
     # Smoke-testing forward to ensure we don't get compilation errors
     for _ in range(3):
         trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
-    assert True

From b17a3dc17159226f4430d99d18dd1fda974cd8a4 Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Fri, 10 Jan 2025 15:34:49 +0100
Subject: [PATCH 13/14] Remove unused tmp_path argument

---
 tests/tests_pytorch/strategies/test_ddp_integration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 23e83e95b1d7b..bd57674782674 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -454,7 +454,7 @@ def creates_processes_externally(self):
 @RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
 @mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
 @mock.patch.dict(os.environ, {})
-def test_reapply_compile(tmp_path):
+def test_reapply_compile():
     """Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
     trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
 

From 8e73a21e6307cfa73fa3c62de8f259e2c3159d4f Mon Sep 17 00:00:00 2001
From: Mieszko Dziadowiec
Date: Wed, 5 Feb 2025 10:09:27 +0100
Subject: [PATCH 14/14] test_trainer_compiled_model change

Don't use compiler_ctx in case of OptimizedModule; unwrapping via Fabric
doesn't fill these fields.
---
 tests/tests_pytorch/utilities/test_compile.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py
index a053c847dfd6c..995484f13f30d 100644
--- a/tests/tests_pytorch/utilities/test_compile.py
+++ b/tests/tests_pytorch/utilities/test_compile.py
@@ -46,18 +46,14 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
 
     model = BoringModel()
     compiled_model = torch.compile(model)
-    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference
 
     # can train with compiled model
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(compiled_model)
-    assert trainer.model._compiler_ctx["compiler"] == "dynamo"
+    assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)
 
     # the compiled model can be uncompiled
     to_uncompiled_model = to_uncompiled(compiled_model)
-    assert model._compiler_ctx is None
-    assert compiled_model._compiler_ctx is None
-    assert to_uncompiled_model._compiler_ctx is None
 
     # the compiled model needs to be passed
     with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
@@ -66,7 +62,7 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
     # the uncompiled model can be fitted
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(model)
-    assert trainer.model._compiler_ctx is None
+    assert not isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)
 
     # some strategies do not support it
     if RequirementCache("deepspeed"):
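
Taken together, the series makes Trainer.fit accept a torch.compile'd LightningModule, unwrap it before the strategy wraps the module, and re-apply torch.compile over the DDP/FSDP wrapper. The sketch below is illustrative only and not part of the patches: it approximates the Fabric helpers _unwrap_compiled/_to_compiled referenced in the diffs with plain PyTorch, and the Passthrough class stands in for the strategy wrapping (DistributedDataParallel/FullyShardedDataParallel in the tests) so it runs in a single CPU process.

import torch
from torch import nn
from torch._dynamo import OptimizedModule

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))
compile_kwargs = {"mode": "reduce-overhead"}
compiled = torch.compile(model, **compile_kwargs)

# 1. Unwrap: OptimizedModule keeps the module originally passed to torch.compile in `_orig_mod`.
inner = compiled._orig_mod if isinstance(compiled, OptimizedModule) else compiled


# 2. The strategy would wrap `inner` here (e.g. DistributedDataParallel(inner) after process-group
#    setup); a plain passthrough wrapper keeps the sketch runnable in a single CPU process.
class Passthrough(nn.Module):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)


wrapped = Passthrough(inner)

# 3. Re-apply compile with the original kwargs so dynamo traces through the wrapper,
#    instead of the wrapper hiding an already-compiled module inside it.
final = torch.compile(wrapped, **compile_kwargs)
assert isinstance(final, OptimizedModule)
print(final(torch.randn(4, 32)).shape)  # torch.Size([4, 2])

In the Trainer patches the same three steps happen in fit() (unwrap and capture compile_kwargs), Strategy.setup() (wrap), and _run() (re-compile via _to_compiled).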