Skip to content

Commit 326e80b

Browse files
committed
feat: support __replace__ for Version
Co-authored-by: Damian Shaw <[email protected]> Signed-off-by: Henry Schreiner <[email protected]>
1 parent fd469c3 commit 326e80b

File tree

2 files changed

+304
-1
lines changed

2 files changed

+304
-1
lines changed

src/packaging/version.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99

1010
from __future__ import annotations
1111

12+
import copy
1213
import re
1314
import sys
14-
from typing import Any, Callable, SupportsInt, Tuple, Union
15+
import typing
16+
from typing import Any, Callable, Literal, SupportsInt, Tuple, TypedDict, Union
1517

1618
from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType
1719

20+
if typing.TYPE_CHECKING:
21+
from typing_extensions import Self, Unpack
22+
1823
__all__ = ["VERSION_PATTERN", "InvalidVersion", "Version", "parse"]
1924

2025
LocalType = Tuple[Union[int, str], ...]
@@ -35,6 +40,15 @@
3540
VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]
3641

3742

43+
class _VersionReplace(TypedDict, total=False):
44+
epoch: int | None
45+
release: tuple[int, ...] | None
46+
pre: tuple[Literal["a", "b", "rc"], int] | None
47+
post: int | None
48+
dev: int | None
49+
local: str | None
50+
51+
3852
def parse(version: str) -> Version:
3953
"""Parse the given version string.
4054
@@ -164,6 +178,10 @@ def __ne__(self, other: object) -> bool:
164178
"""
165179

166180

181+
# Validation pattern for local version in replace()
182+
_LOCAL_PATTERN = re.compile(r"[a-z0-9]+(?:[._-][a-z0-9]+)*", re.IGNORECASE)
183+
184+
167185
class Version(_BaseVersion):
168186
"""This class abstracts handling of a project's versions.
169187
@@ -227,6 +245,100 @@ def __init__(self, version: str) -> None:
227245
# Key which will be used for sorting
228246
self._key_cache = None
229247

248+
def __replace__(self, **kwargs: Unpack[_VersionReplace]) -> Self:
249+
new_version = self.__class__.__new__(self.__class__)
250+
new_version._key_cache = None
251+
if "epoch" in kwargs:
252+
epoch = kwargs["epoch"] or 0
253+
if isinstance(epoch, int) and epoch >= 0: # type: ignore[redundant-expr]
254+
new_version._epoch = epoch
255+
else:
256+
msg = f"epoch must be non-negative integer, got {epoch}"
257+
raise InvalidVersion(msg)
258+
else:
259+
new_version._epoch = self._epoch
260+
261+
if "release" in kwargs:
262+
release = (0,) if kwargs["release"] is None else kwargs["release"]
263+
if (
264+
isinstance(release, tuple) # type: ignore[redundant-expr]
265+
and len(release) > 0
266+
and all(isinstance(i, int) and i >= 0 for i in release) # type: ignore[redundant-expr]
267+
):
268+
new_version._release = release
269+
else:
270+
msg = (
271+
"release must be a non-empty tuple of non-negative integers,"
272+
f" got {release}"
273+
)
274+
raise InvalidVersion(msg)
275+
else:
276+
new_version._release = self._release
277+
278+
if "pre" in kwargs:
279+
pre = kwargs["pre"]
280+
if pre is None or (
281+
(
282+
isinstance(pre, tuple) # type: ignore[redundant-expr]
283+
and len(pre) == 2 # type: ignore[redundant-expr]
284+
and pre[0] in ("a", "b", "rc")
285+
and isinstance(pre[1], int)
286+
)
287+
and pre[1] >= 0
288+
):
289+
new_version._pre = pre
290+
else:
291+
msg = (
292+
"pre must be a tuple of ('a'|'b'|'rc', non-negative int),"
293+
f" got {pre}"
294+
)
295+
raise InvalidVersion(msg)
296+
else:
297+
new_version._pre = self._pre
298+
299+
if "post" in kwargs:
300+
post = kwargs["post"]
301+
if post is None:
302+
new_version._post = None
303+
elif isinstance(post, int) and post >= 0: # type: ignore[redundant-expr]
304+
new_version._post = ("post", post)
305+
else:
306+
msg = f"post must be non-negative integer, got {post}"
307+
raise InvalidVersion(msg)
308+
else:
309+
new_version._post = self._post
310+
311+
if "dev" in kwargs:
312+
dev = kwargs["dev"]
313+
if dev is None:
314+
new_version._dev = None
315+
elif isinstance(dev, int) and dev >= 0: # type: ignore[redundant-expr]
316+
new_version._dev = ("dev", dev)
317+
else:
318+
msg = f"dev must be non-negative integer, got {dev}"
319+
raise InvalidVersion(msg)
320+
else:
321+
new_version._dev = self._dev
322+
323+
if "local" in kwargs:
324+
local = kwargs["local"]
325+
if local is None:
326+
new_version._local = None
327+
elif isinstance(local, str) and _LOCAL_PATTERN.fullmatch(local): # type: ignore[redundant-expr]
328+
new_version._local = _parse_local_version(local)
329+
else:
330+
msg = f"local must be a valid version string, got {local!r}"
331+
raise InvalidVersion(msg)
332+
else:
333+
new_version._local = self._local
334+
335+
return new_version
336+
337+
def replace(self, **kwargs: Unpack[_VersionReplace]) -> Self:
338+
if sys.version_info >= (3, 13):
339+
return copy.replace(self, **kwargs)
340+
return self.__replace__(**kwargs)
341+
230342
@property
231343
def _key(self) -> CmpKey:
232344
if self._key_cache is None:

tests/test_version.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,194 @@ def test_micro_version(self) -> None:
775775
assert Version("2.1.3").micro == 3
776776
assert Version("2.1").micro == 0
777777
assert Version("2").micro == 0
778+
779+
# Tests for replace() method
780+
def test_replace_no_args(self) -> None:
781+
"""replace() with no arguments should return an equivalent version"""
782+
v = Version("1.2.3a1.post2.dev3+local")
783+
v_replaced = v.replace()
784+
assert v == v_replaced
785+
assert str(v) == str(v_replaced)
786+
787+
def test_replace_epoch(self) -> None:
788+
v = Version("1.2.3")
789+
assert str(v.replace(epoch=2)) == "2!1.2.3"
790+
assert v.replace(epoch=0).epoch == 0
791+
792+
v_with_epoch = Version("1!1.2.3")
793+
assert str(v_with_epoch.replace(epoch=2)) == "2!1.2.3"
794+
assert str(v_with_epoch.replace(epoch=None)) == "1.2.3"
795+
796+
def test_replace_release_tuple(self) -> None:
797+
v = Version("1.2.3")
798+
assert str(v.replace(release=(2, 0, 0))) == "2.0.0"
799+
assert str(v.replace(release=(1,))) == "1"
800+
assert str(v.replace(release=(1, 2, 3, 4, 5))) == "1.2.3.4.5"
801+
802+
def test_replace_release_none(self) -> None:
803+
v = Version("1.2.3")
804+
assert str(v.replace(release=None)) == "0"
805+
806+
def test_replace_pre_alpha(self) -> None:
807+
v = Version("1.2.3")
808+
assert str(v.replace(pre=("a", 1))) == "1.2.3a1"
809+
assert str(v.replace(pre=("a", 0))) == "1.2.3a0"
810+
811+
def test_replace_pre_alpha_none(self) -> None:
812+
v = Version("1.2.3a1")
813+
assert str(v.replace(pre=None)) == "1.2.3"
814+
815+
def test_replace_pre_beta(self) -> None:
816+
v = Version("1.2.3")
817+
assert str(v.replace(pre=("b", 1))) == "1.2.3b1"
818+
assert str(v.replace(pre=("b", 0))) == "1.2.3b0"
819+
820+
def test_replace_pre_beta_none(self) -> None:
821+
v = Version("1.2.3b1")
822+
assert str(v.replace(pre=None)) == "1.2.3"
823+
824+
def test_replace_pre_rc(self) -> None:
825+
v = Version("1.2.3")
826+
assert str(v.replace(pre=("rc", 1))) == "1.2.3rc1"
827+
assert str(v.replace(pre=("rc", 0))) == "1.2.3rc0"
828+
829+
def test_replace_pre_rc_none(self) -> None:
830+
v = Version("1.2.3rc1")
831+
assert str(v.replace(pre=None)) == "1.2.3"
832+
833+
def test_replace_post(self) -> None:
834+
v = Version("1.2.3")
835+
assert str(v.replace(post=1)) == "1.2.3.post1"
836+
assert str(v.replace(post=0)) == "1.2.3.post0"
837+
838+
def test_replace_post_none(self) -> None:
839+
v = Version("1.2.3.post1")
840+
assert str(v.replace(post=None)) == "1.2.3"
841+
842+
def test_replace_dev(self) -> None:
843+
v = Version("1.2.3")
844+
assert str(v.replace(dev=1)) == "1.2.3.dev1"
845+
assert str(v.replace(dev=0)) == "1.2.3.dev0"
846+
847+
def test_replace_dev_none(self) -> None:
848+
v = Version("1.2.3.dev1")
849+
assert str(v.replace(dev=None)) == "1.2.3"
850+
851+
def test_replace_local_string(self) -> None:
852+
v = Version("1.2.3")
853+
assert str(v.replace(local="abc")) == "1.2.3+abc"
854+
assert str(v.replace(local="abc.123")) == "1.2.3+abc.123"
855+
assert str(v.replace(local="abc-123")) == "1.2.3+abc.123"
856+
857+
def test_replace_local_none(self) -> None:
858+
v = Version("1.2.3+local")
859+
assert str(v.replace(local=None)) == "1.2.3"
860+
861+
def test_replace_multiple_components(self) -> None:
862+
v = Version("1.2.3")
863+
assert str(v.replace(pre=("a", 1), post=1)) == "1.2.3a1.post1"
864+
assert str(v.replace(release=(2, 0, 0), pre=("b", 2), dev=1)) == "2.0.0b2.dev1"
865+
assert str(v.replace(epoch=1, release=(3, 0), local="abc")) == "1!3.0+abc"
866+
867+
def test_replace_clear_all_optional(self) -> None:
868+
v = Version("1!1.2.3a1.post2.dev3+local")
869+
cleared = v.replace(epoch=None, pre=None, post=None, dev=None, local=None)
870+
assert str(cleared) == "1.2.3"
871+
872+
def test_replace_preserves_comparison(self) -> None:
873+
v1 = Version("1.2.3")
874+
v2 = Version("1.2.4")
875+
876+
v1_new = v1.replace(release=(1, 2, 4))
877+
assert v1_new == v2
878+
assert v1 < v2
879+
assert v1_new >= v2
880+
881+
def test_replace_preserves_hash(self) -> None:
882+
v1 = Version("1.2.3")
883+
v2 = v1.replace(release=(1, 2, 3))
884+
assert hash(v1) == hash(v2)
885+
886+
v3 = v1.replace(release=(2, 0, 0))
887+
assert hash(v1) != hash(v3)
888+
889+
def test_replace_change_pre_type(self) -> None:
890+
"""Can change from one pre-release type to another"""
891+
v = Version("1.2.3a1")
892+
assert str(v.replace(pre=("b", 2))) == "1.2.3b2"
893+
assert str(v.replace(pre=("rc", 1))) == "1.2.3rc1"
894+
895+
v2 = Version("1.2.3rc5")
896+
assert str(v2.replace(pre=("a", 0))) == "1.2.3a0"
897+
898+
def test_replace_invalid_epoch_type(self) -> None:
899+
v = Version("1.2.3")
900+
with pytest.raises(InvalidVersion, match="epoch must be non-negative"):
901+
v.replace(epoch="1") # type: ignore[arg-type]
902+
903+
def test_replace_invalid_post_type(self) -> None:
904+
v = Version("1.2.3")
905+
with pytest.raises(InvalidVersion, match="post must be non-negative"):
906+
v.replace(post="1") # type: ignore[arg-type]
907+
908+
def test_replace_invalid_dev_type(self) -> None:
909+
v = Version("1.2.3")
910+
with pytest.raises(InvalidVersion, match="dev must be non-negative"):
911+
v.replace(dev="1") # type: ignore[arg-type]
912+
913+
def test_replace_invalid_epoch_negative(self) -> None:
914+
v = Version("1.2.3")
915+
with pytest.raises(InvalidVersion, match="epoch must be non-negative"):
916+
v.replace(epoch=-1)
917+
918+
def test_replace_invalid_release_empty(self) -> None:
919+
v = Version("1.2.3")
920+
with pytest.raises(InvalidVersion, match="release must be a non-empty tuple"):
921+
v.replace(release=())
922+
923+
def test_replace_invalid_release_tuple_content(self) -> None:
924+
v = Version("1.2.3")
925+
with pytest.raises(
926+
InvalidVersion, match="release must be a non-empty tuple of non-negative"
927+
):
928+
v.replace(release=(1, -2, 3))
929+
930+
def test_replace_invalid_pre_negative(self) -> None:
931+
v = Version("1.2.3")
932+
with pytest.raises(InvalidVersion, match="pre must be a tuple"):
933+
v.replace(pre=("a", -1))
934+
935+
def test_replace_invalid_pre_type(self) -> None:
936+
v = Version("1.2.3")
937+
with pytest.raises(InvalidVersion, match="pre must be a tuple"):
938+
v.replace(pre=("x", 1)) # type: ignore[arg-type]
939+
940+
def test_replace_invalid_pre_format(self) -> None:
941+
v = Version("1.2.3")
942+
with pytest.raises(InvalidVersion, match="pre must be a tuple"):
943+
v.replace(pre="a1") # type: ignore[arg-type]
944+
with pytest.raises(InvalidVersion, match="pre must be a tuple"):
945+
v.replace(pre=("a",)) # type: ignore[arg-type]
946+
with pytest.raises(InvalidVersion, match="pre must be a tuple"):
947+
v.replace(pre=("a", 1, 2)) # type: ignore[arg-type]
948+
949+
def test_replace_invalid_post_negative(self) -> None:
950+
v = Version("1.2.3")
951+
with pytest.raises(InvalidVersion, match="post must be non-negative"):
952+
v.replace(post=-1)
953+
954+
def test_replace_invalid_dev_negative(self) -> None:
955+
v = Version("1.2.3")
956+
with pytest.raises(InvalidVersion, match="dev must be non-negative"):
957+
v.replace(dev=-1)
958+
959+
def test_replace_invalid_local_string(self) -> None:
960+
v = Version("1.2.3")
961+
with pytest.raises(
962+
InvalidVersion, match="local must be a valid version string"
963+
):
964+
v.replace(local="abc+123")
965+
with pytest.raises(
966+
InvalidVersion, match="local must be a valid version string"
967+
):
968+
v.replace(local="+abc")

0 commit comments

Comments
 (0)