diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml
index 15dfc4a1f9064..59b485946721e 100644
--- a/.lightning/workflows/pytorch.yml
+++ b/.lightning/workflows/pytorch.yml
@@ -121,7 +121,12 @@
   run: |
     echo "Install package"
     extra=$(python -c "print({'lightning': 'pytorch-'}.get('${PACKAGE_NAME}', ''))")
-    uv pip install -e ".[${extra}dev]" --upgrade
+
+    # Use find-links to prefer CUDA-specific packages from PyTorch index
+    uv pip install -e ".[${extra}dev]" --upgrade \
+      --find-links="https://download.pytorch.org/whl/${UV_TORCH_BACKEND}" \
+      --find-links="https://download.pytorch.org/whl/${UV_TORCH_BACKEND}/torch-tensorrt"
+
     uv pip list
     echo "Ensure only a single package is installed"
     if [ "${PACKAGE_NAME}" == "pytorch" ]; then
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index d187769117be6..bcf69b2503e5d 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -1,7 +1,7 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
-torch >=2.1.0, <2.9.0
+torch >=2.1.0, <2.10.0
 fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 65a3f7484fb7d..634e7da0ffb34 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -1,7 +1,7 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
-torch >=2.1.0, <2.9.0
+torch >=2.1.0, <2.10.0
 tqdm >=4.57.0, <4.68.0
 PyYAML >5.4, <6.1.0
 fsspec[http] >=2022.5.0, <2025.11.0
diff --git a/requirements/typing.txt b/requirements/typing.txt
index dc848c55e583d..8c5ad38fb7825 100644
--- a/requirements/typing.txt
+++ b/requirements/typing.txt
@@ -1,5 +1,5 @@
 mypy==1.18.2
-torch==2.8.0
+torch==2.9.0
 
 types-Markdown
 types-PyYAML
diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py
index 9c1b0a2a00572..cd2e05309e087 100644
--- a/src/lightning/fabric/utilities/spike.py
+++ b/src/lightning/fabric/utilities/spike.py
@@ -126,10 +126,10 @@ def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
         raise TrainingSpikeException(batch_idx=batch_idx)
 
     def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
-        return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))
+        return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))  # type: ignore
 
     def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
-        return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))
+        return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))  # type: ignore
 
     def _is_better(self, diff_val: torch.Tensor) -> bool:
         if self.mode == "min":
diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py
index ab99457eee2d1..f7305fbda8f90 100644
--- a/src/lightning/pytorch/utilities/imports.py
+++ b/src/lightning/pytorch/utilities/imports.py
@@ -29,6 +29,7 @@
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
 _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")
 _TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")
+_TORCH_GREATER_EQUAL_2_8 = compare_version("torch", operator.ge, "2.8.0")
 
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 0564632b315a7..323cbecc0ee4e 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -117,6 +117,7 @@ def thread_police_duuu_daaa_duuu_daaa():
             sys.version_info >= (3, 9)
             and isinstance(thread, _ExecutorManagerThread)
             or "ThreadPoolExecutor-" in thread.name
+            or thread.name == "InductorSubproc"  # torch.compile
         ):
             # probably `torch.compile`, can't narrow it down further
             continue
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 53dd20e7e4399..518670c8d2483 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -173,6 +173,7 @@ def thread_police_duuu_daaa_duuu_daaa():
             sys.version_info >= (3, 9)
             and isinstance(thread, _ExecutorManagerThread)
             or "ThreadPoolExecutor-" in thread.name
+            or thread.name == "InductorSubproc"  # torch.compile
         ):
             # probably `torch.compile`, can't narrow it down further
             continue
diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py
index 76204784cce0a..cf2ff39c698c7 100644
--- a/tests/tests_pytorch/helpers/runif.py
+++ b/tests/tests_pytorch/helpers/runif.py
@@ -14,7 +14,7 @@
 import pytest
 
 from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_8, _TORCH_EQUAL_2_9
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_8
 from lightning.pytorch.utilities.testing import _runif_reasons
 
 
@@ -27,6 +27,6 @@ def RunIf(**kwargs):
 _xfail_gloo_windows = pytest.mark.xfail(
     RuntimeError,
     strict=True,
-    condition=(_IS_WINDOWS and (_TORCH_EQUAL_2_8 or _TORCH_EQUAL_2_9)),
+    condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_8),
     reason="makeDeviceForHostname(): unsupported gloo device",
 )
diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py
index 630e59f711348..1ab8948d4482e 100644
--- a/tests/tests_pytorch/models/test_torch_tensorrt.py
+++ b/tests/tests_pytorch/models/test_torch_tensorrt.py
@@ -10,6 +10,7 @@
 from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_9
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -110,7 +111,14 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path):
     [
         ("default", torch.fx.GraphModule),
         ("dynamo", torch.fx.GraphModule),
-        ("ts", torch.jit.ScriptModule),
+        pytest.param(
+            "ts",
+            torch.jit.ScriptModule,
+            marks=pytest.mark.skipif(
+                _TORCH_EQUAL_2_9,
+                reason="TorchScript IR crashes with torch_tensorrt on PyTorch 2.9",
+            ),
+        ),
     ],
 )
 @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
@@ -128,7 +136,17 @@ def test_tensorrt_save_ir_type(ir, export_type):
 )
 @pytest.mark.parametrize(
     "ir",
-    ["default", "dynamo", "ts"],
+    [
+        "default",
+        "dynamo",
+        pytest.param(
+            "ts",
+            marks=pytest.mark.skipif(
+                _TORCH_EQUAL_2_9,
+                reason="TorchScript IR crashes with torch_tensorrt on PyTorch 2.9",
+            ),
+        ),
+    ],
 )
 @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
 def test_tensorrt_export_reload(output_format, ir, tmp_path):
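The two test_torch_tensorrt.py hunks above attach the skip to a single parametrize entry rather than to the whole test. A minimal standalone sketch of that pattern (the flag, test name, and assertion here are hypothetical stand-ins, not from this PR; the PR uses _TORCH_EQUAL_2_9 as the condition):

import pytest

TORCH_IS_2_9 = False  # hypothetical stand-in for bool(_TORCH_EQUAL_2_9)


@pytest.mark.parametrize(
    "ir",
    [
        "default",
        "dynamo",
        # Only this case is skipped when the condition holds;
        # "default" and "dynamo" keep running.
        pytest.param(
            "ts",
            marks=pytest.mark.skipif(TORCH_IS_2_9, reason="ts IR unsupported here"),
        ),
    ],
)
def test_ir_values(ir):  # hypothetical test
    assert ir in {"default", "dynamo", "ts"}
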
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index ea3c31a370fce..76860fd82733f 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -55,7 +55,7 @@
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_EQUAL_2_8
+from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_2_8
 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -1730,7 +1730,12 @@ def test_exception_when_lightning_module_is_not_set_on_trainer(fn):
 
 
 @RunIf(min_cuda_gpus=1)
 # FixMe: the memory raises to 1024 from expected 512
-@pytest.mark.xfail(AssertionError, strict=True, condition=_TORCH_EQUAL_2_8, reason="temporarily disabled for torch 2.8")
+@pytest.mark.xfail(
+    AssertionError,
+    strict=True,
+    condition=_TORCH_GREATER_EQUAL_2_8,
+    reason="temporarily disabled for torch >= 2.8",
+)
 def test_multiple_trainer_constant_memory_allocated(tmp_path):
     """This tests ensures calling the trainer several times reset the memory back to 0."""
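A note on the version-guard swap running through imports.py, runif.py, and test_trainer.py: a RequirementCache pinned to one minor release ("torch>=2.8.0,<2.9.0") goes stale as soon as the next release ships, so a condition like `_TORCH_EQUAL_2_8 or _TORCH_EQUAL_2_9` would silently stop firing under torch 2.10, whereas the open-ended compare_version guard keeps matching. A minimal standalone sketch of the difference, assuming lightning_utilities and torch >= 2.10 are installed (both helpers are the real lightning_utilities APIs the diff uses):

import operator

from lightning_utilities.core.imports import RequirementCache, compare_version

_TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")
_TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")
_TORCH_GREATER_EQUAL_2_8 = compare_version("torch", operator.ge, "2.8.0")

# With torch 2.10 installed, both pinned flags are False, so the old
# xfail/skip conditions no longer trigger; the >= 2.8 guard stays True.
print(bool(_TORCH_EQUAL_2_8 or _TORCH_EQUAL_2_9))  # False
print(_TORCH_GREATER_EQUAL_2_8)  # True for 2.8 and every later release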