Skip to content

Commit 6042e0c

Browse files
authored
feat: enhance custom encoders to accept _parent and _sort_keys parameters (#436)
Resolve #429 Signed-off-by: Frost Ming <[email protected]>
1 parent 424dd0d commit 6042e0c

File tree

4 files changed

+157
-9
lines changed

4 files changed

+157
-9
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Added
6+
7+
- Custom encoders can now receive `_parent` and `_sort_keys` parameters to enable proper encoding of nested structures. ([#429](https://github.com/python-poetry/tomlkit/issues/429))
8+
59
## [0.13.3] - 2025-06-05
610

711
### Added

tests/test_items.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,120 @@ def encode_decimal(obj):
986986
api.unregister_encoder(encode_decimal)
987987

988988

989+
def test_custom_encoders_with_parent_and_sort_keys():
990+
"""Test that custom encoders can receive _parent and _sort_keys parameters."""
991+
import decimal
992+
993+
parent_captured = None
994+
sort_keys_captured = None
995+
996+
@api.register_encoder
997+
def encode_decimal_with_context(obj, _parent=None, _sort_keys=False):
998+
nonlocal parent_captured, sort_keys_captured
999+
if isinstance(obj, decimal.Decimal):
1000+
parent_captured = _parent
1001+
sort_keys_captured = _sort_keys
1002+
return api.float_(str(obj))
1003+
raise TypeError
1004+
1005+
# Test with default parameters
1006+
result = api.item(decimal.Decimal("1.23"))
1007+
assert result.as_string() == "1.23"
1008+
assert parent_captured is None
1009+
assert sort_keys_captured is False
1010+
1011+
# Test with custom parent and sort_keys
1012+
parent_captured = None
1013+
sort_keys_captured = None
1014+
table = api.table()
1015+
result = item(decimal.Decimal("4.56"), _parent=table, _sort_keys=True)
1016+
assert result.as_string() == "4.56"
1017+
assert parent_captured is table
1018+
assert sort_keys_captured is True
1019+
1020+
api.unregister_encoder(encode_decimal_with_context)
1021+
1022+
1023+
def test_custom_encoders_backward_compatibility():
1024+
"""Test that old-style custom encoders still work without modification."""
1025+
import decimal
1026+
1027+
@api.register_encoder
1028+
def encode_decimal_old_style(obj):
1029+
# Old style encoder - only accepts obj parameter
1030+
if isinstance(obj, decimal.Decimal):
1031+
return api.float_(str(obj))
1032+
raise TypeError
1033+
1034+
# Should work exactly as before
1035+
result = api.item(decimal.Decimal("2.34"))
1036+
assert result.as_string() == "2.34"
1037+
1038+
# Should work when called from item() with extra parameters
1039+
table = api.table()
1040+
result = item(decimal.Decimal("5.67"), _parent=table, _sort_keys=True)
1041+
assert result.as_string() == "5.67"
1042+
1043+
api.unregister_encoder(encode_decimal_old_style)
1044+
1045+
1046+
def test_custom_encoders_with_kwargs():
1047+
"""Test that custom encoders can use **kwargs to accept additional parameters."""
1048+
import decimal
1049+
1050+
kwargs_captured = None
1051+
1052+
@api.register_encoder
1053+
def encode_decimal_with_kwargs(obj, **kwargs):
1054+
nonlocal kwargs_captured
1055+
if isinstance(obj, decimal.Decimal):
1056+
kwargs_captured = kwargs
1057+
return api.float_(str(obj))
1058+
raise TypeError
1059+
1060+
# Test with parent and sort_keys passed as kwargs
1061+
table = api.table()
1062+
result = item(decimal.Decimal("7.89"), _parent=table, _sort_keys=True)
1063+
assert result.as_string() == "7.89"
1064+
assert kwargs_captured == {"_parent": table, "_sort_keys": True}
1065+
1066+
api.unregister_encoder(encode_decimal_with_kwargs)
1067+
1068+
1069+
def test_custom_encoders_for_complex_objects():
1070+
"""Test custom encoders that need to encode nested structures."""
1071+
1072+
class CustomDict:
1073+
def __init__(self, data):
1074+
self.data = data
1075+
1076+
@api.register_encoder
1077+
def encode_custom_dict(obj, _parent=None, _sort_keys=False):
1078+
if isinstance(obj, CustomDict):
1079+
# Create a table and use item() to convert nested values
1080+
table = api.table()
1081+
for key, value in obj.data.items():
1082+
# Pass along _parent and _sort_keys when converting nested values
1083+
table[key] = item(value, _parent=table, _sort_keys=_sort_keys)
1084+
return table
1085+
raise TypeError
1086+
1087+
# Test with nested structure
1088+
custom_obj = CustomDict({"a": 1, "b": {"c": 2, "d": 3}})
1089+
result = item(custom_obj, _sort_keys=True)
1090+
1091+
# Should properly format as a table with sorted keys
1092+
expected = """a = 1
1093+
1094+
[b]
1095+
c = 2
1096+
d = 3
1097+
"""
1098+
assert result.as_string() == expected
1099+
1100+
api.unregister_encoder(encode_custom_dict)
1101+
1102+
9891103
def test_no_extra_minus_sign():
9901104
doc = parse("a = -1")
9911105
assert doc.as_string() == "a = -1"

tomlkit/api.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from collections.abc import Mapping
77
from typing import IO
8+
from typing import TYPE_CHECKING
89
from typing import Iterable
910
from typing import TypeVar
1011

@@ -19,7 +20,6 @@
1920
from tomlkit.items import Date
2021
from tomlkit.items import DateTime
2122
from tomlkit.items import DottedKey
22-
from tomlkit.items import Encoder
2323
from tomlkit.items import Float
2424
from tomlkit.items import InlineTable
2525
from tomlkit.items import Integer
@@ -37,6 +37,12 @@
3737
from tomlkit.toml_document import TOMLDocument
3838

3939

40+
if TYPE_CHECKING:
41+
from tomlkit.items import Encoder
42+
43+
E = TypeVar("E", bound=Encoder)
44+
45+
4046
def loads(string: str | bytes) -> TOMLDocument:
4147
"""
4248
Parses a string into a TOMLDocument.
@@ -294,13 +300,22 @@ def comment(string: str) -> Comment:
294300
return Comment(Trivia(comment_ws=" ", comment="# " + string))
295301

296302

297-
E = TypeVar("E", bound=Encoder)
298-
299-
300303
def register_encoder(encoder: E) -> E:
301304
"""Add a custom encoder, which should be a function that will be called
302-
if the value can't otherwise be converted. It should takes a single value
303-
and return a TOMLKit item or raise a ``ConvertError``.
305+
if the value can't otherwise be converted.
306+
307+
The encoder should return a TOMLKit item or raise a ``ConvertError``.
308+
309+
Example:
310+
@register_encoder
311+
def encode_custom_dict(obj, _parent=None, _sort_keys=False):
312+
if isinstance(obj, CustomDict):
313+
tbl = table()
314+
for key, value in obj.items():
315+
# Pass along parameters when encoding nested values
316+
tbl[key] = item(value, _parent=tbl, _sort_keys=_sort_keys)
317+
return tbl
318+
raise ConvertError("Not a CustomDict")
304319
"""
305320
CUSTOM_ENCODERS.append(encoder)
306321
return encoder

tomlkit/items.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import abc
44
import copy
55
import dataclasses
6+
import inspect
67
import math
78
import re
89
import string
@@ -15,7 +16,6 @@
1516
from enum import Enum
1617
from typing import TYPE_CHECKING
1718
from typing import Any
18-
from typing import Callable
1919
from typing import Collection
2020
from typing import Iterable
2121
from typing import Iterator
@@ -38,11 +38,17 @@
3838

3939

4040
if TYPE_CHECKING:
41+
from typing import Protocol
42+
4143
from tomlkit import container
4244

45+
class Encoder(Protocol):
46+
def __call__(
47+
self, __value: Any, _parent: Item | None = None, _sort_keys: bool = False
48+
) -> Item: ...
49+
4350

4451
ItemT = TypeVar("ItemT", bound="Item")
45-
Encoder = Callable[[Any], "Item"]
4652
CUSTOM_ENCODERS: list[Encoder] = []
4753
AT = TypeVar("AT", bound="AbstractTable")
4854

@@ -199,7 +205,16 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I
199205
else:
200206
for encoder in CUSTOM_ENCODERS:
201207
try:
202-
rv = encoder(value)
208+
# Check if encoder accepts keyword arguments for backward compatibility
209+
sig = inspect.signature(encoder)
210+
if "_parent" in sig.parameters or any(
211+
p.kind == p.VAR_KEYWORD for p in sig.parameters.values()
212+
):
213+
# New style encoder that can accept additional parameters
214+
rv = encoder(value, _parent=_parent, _sort_keys=_sort_keys)
215+
else:
216+
# Old style encoder that only accepts value
217+
rv = encoder(value)
203218
except ConvertError:
204219
pass
205220
else:

0 commit comments

Comments
 (0)