Skip to content

Commit 38971a0

Browse files
authored
Patch jsonargparse for Python >= 3.12.8 (#20479)
* Patch argparse _parse_known_args * Add patch to test * Avoid importing lightning in assistant * Fix return type
1 parent c09fc66 commit 38971a0

File tree

8 files changed

+55
-21
lines changed

8 files changed

+55
-21
lines changed

.actions/assistant.py

+15
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:
483483

484484

485485
if __name__ == "__main__":
486+
import sys
487+
486488
import jsonargparse
489+
from jsonargparse import ArgumentParser
490+
491+
def patch_jsonargparse_python_3_12_8():
492+
if sys.version_info < (3, 12, 8):
493+
return
494+
495+
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
496+
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
497+
return namespace, args
498+
499+
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
500+
501+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
487502

488503
jsonargparse.CLI(AssistantCLI, as_positional=False)

.github/workflows/ci-tests-fabric.yml

+9-9
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ jobs:
4949
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5050
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5151
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
52-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
53-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
54-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
55-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
56-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
57-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
52+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
53+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
54+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
55+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
56+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
57+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5858
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
59-
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
60-
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
61-
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" }
59+
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
60+
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
61+
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" }
6262
# "oldest" versions tests, only on minimum Python
6363
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6464
- {

.github/workflows/ci-tests-pytorch.yml

+9-9
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ jobs:
5353
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5454
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
5555
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
56-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
57-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
58-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
59-
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
60-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
61-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
56+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
57+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
58+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
59+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
60+
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
61+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
6262
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
63-
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
64-
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
65-
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" }
63+
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
64+
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
65+
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" }
6666
# "oldest" versions tests, only on minimum Python
6767
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6868
- {

examples/fabric/tensor_parallel/train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.fabric.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
def train():
1314
strategy = ModelParallelStrategy(

examples/pytorch/tensor_parallel/train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4-
from data import RandomTokenDataset
54
from lightning.pytorch.strategies import ModelParallelStrategy
65
from model import ModelArgs, Transformer
76
from parallelism import parallelize
87
from torch.distributed.tensor.parallel import loss_parallel
98
from torch.utils.data import DataLoader
109

10+
from data import RandomTokenDataset
11+
1112

1213
class Llama3(L.LightningModule):
1314
def __init__(self):

src/lightning/pytorch/cli.py

+14
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@
3737

3838
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")
3939

40+
41+
def patch_jsonargparse_python_3_12_8() -> None:
42+
if sys.version_info < (3, 12, 8):
43+
return
44+
45+
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
46+
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
47+
return namespace, args
48+
49+
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
50+
51+
4052
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
4153
import docstring_parser
4254
from jsonargparse import (
@@ -48,6 +60,8 @@
4860
set_config_read_mode,
4961
)
5062

63+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
64+
5165
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
5266
set_config_read_mode(fsspec_enabled=True)
5367
else:

tests/parity_fabric/test_parity_ddp.py

+3
Original file line numberDiff line numberDiff line change
@@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float
162162

163163
if __name__ == "__main__":
164164
from jsonargparse import CLI
165+
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8
166+
167+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
165168

166169
CLI(run_parity_test)

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
import pytest
3030
import torch
3131
import yaml
32-
from jsonargparse import ArgumentParser
3332
from lightning.fabric.utilities.cloud_io import _load as pl_load
3433
from lightning.pytorch import Trainer, seed_everything
3534
from lightning.pytorch.callbacks import ModelCheckpoint
35+
from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser
3636
from lightning.pytorch.demos.boring_classes import BoringModel
3737
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
3838
from lightning.pytorch.utilities.exceptions import MisconfigurationException

0 commit comments

Comments
 (0)