diff --git a/pyk/src/pyk/kast/outer.py b/pyk/src/pyk/kast/outer.py index a0c8539a16..f22194d104 100644 --- a/pyk/src/pyk/kast/outer.py +++ b/pyk/src/pyk/kast/outer.py @@ -17,6 +17,7 @@ from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt from .inner import ( KApply, + KAs, KInner, KLabel, KRewrite, @@ -1000,6 +1001,45 @@ def let(self, *, require: str | None = None) -> KRequire: return KRequire(require=require) +def _sort_contains(sort: KSort, param: KSort) -> bool: + """Return whether ``param`` appears anywhere in the sort tree of ``sort``.""" + return sort == param or any(_sort_contains(p, param) for p in sort.params) + + +def _match_sort_params( + parametric: KSort, + actual: KSort, + params: frozenset[KSort], + subsorts_fn: Callable[[KSort], frozenset[KSort]] | None = None, +) -> dict[KSort, list[KSort]]: + """Match ``parametric`` sort against ``actual``, collecting candidate bindings per sort param. + + Three matching strategies, mirroring Java ``AddSortInjections.match()``: + + 1. Direct: ``parametric`` is itself a sort param — bind it to ``actual``. + 2. Structural: same constructor head — recurse on sub-params. + 3. Subsort-aware: iterate subsorts ``s ≤ actual`` with same head as ``parametric``, + collecting additional candidates for LUB resolution. + """ + if parametric in params: + return {parametric: [actual]} + if parametric.name == actual.name and len(parametric.params) == len(actual.params): + result: dict[KSort, list[KSort]] = {} + for p_sub, a_sub in zip(parametric.params, actual.params, strict=True): + for k, vs in _match_sort_params(p_sub, a_sub, params, subsorts_fn).items(): + result.setdefault(k, []).extend(vs) + return result + if parametric.params and subsorts_fn is not None: + result = {} + for s in subsorts_fn(actual): + if s.name == parametric.name and len(s.params) == len(parametric.params): + for p_sub, a_sub in zip(parametric.params, s.params, strict=True): + for k, vs in _match_sort_params(p_sub, a_sub, params).items(): + result.setdefault(k, []).extend(vs) + return result + return {} + + @final @dataclass(frozen=True) class KDefinition(KOuter, WithKAtt, Iterable[KFlatModule]): @@ -1346,6 +1386,8 @@ def sort(self, kast: KInner) -> KSort | None: match kast: case KToken(_, sort) | KVariable(_, sort): return sort + case KAs(alias=KVariable(sort=sort)): + return sort case KRewrite(lhs, rhs): lhs_sort = self.sort(lhs) rhs_sort = self.sort(rhs) @@ -1355,8 +1397,11 @@ def sort(self, kast: KInner) -> KSort | None: case KSequence(_): return KSort('K') case KApply(label, _): - sort, _ = self.resolve_sorts(label) - return sort + try: + sort, _ = self.resolve_sorts(label) + return sort + except (KeyError, ValueError): + return None case _: return None @@ -1373,7 +1418,13 @@ def resolve_sorts(self, label: KLabel) -> tuple[KSort, tuple[KSort, ...]]: sorts = dict(zip(prod.params, label.params, strict=True)) def resolve(sort: KSort) -> KSort: - return sorts.get(sort, sort) + # Direct match: sort IS one of the sort parameters. + if sort in sorts: + return sorts[sort] + # Recursive substitution: sort params may appear nested (e.g. MInt{Width} → MInt{8}). + if sort.params: + return KSort(sort.name, tuple(resolve(p) for p in sort.params)) + return sort return resolve(prod.sort), tuple(resolve(sort) for sort in prod.argument_sorts) @@ -1499,31 +1550,137 @@ def transform( return Subst(subst)(new_term) + def infer_sort_params( + self, + prod: KProduction, + actual_sorts: tuple[KSort | None, ...], + expected_sort: KSort | None = None, + ) -> dict[KSort, KSort]: + """Infer sort parameter bindings for a parametric production application. + + Returns a (possibly partial) mapping from sort params to concrete sorts; + unbound parameters are absent from the result. + Mirrors ``AddSortInjections.substituteProd()`` in the Java frontend. + + ``actual_sorts`` must have the same length as ``prod.argument_sorts``. + ``None`` entries are skipped (unsortable arguments). + If ``expected_sort`` is given, parameters that appear only in the result sort + (not in any argument sort) are also inferred from it — this is the + ``matchExpected`` path in the Java algorithm. + """ + params = frozenset(prod.params) + candidates: dict[KSort, list[KSort]] = {} + + for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True): + if asort is None: + continue + for k, vs in _match_sort_params(psort, asort, params, self.subsorts).items(): + candidates.setdefault(k, []).extend(vs) + + if expected_sort is not None: + unbound_result_params = frozenset( + p + for p in params + if _sort_contains(prod.sort, p) + and not any(_sort_contains(asort, p) for asort in actual_sorts if asort is not None) + ) + if unbound_result_params: + for k, vs in _match_sort_params(prod.sort, expected_sort, unbound_result_params).items(): + candidates.setdefault(k, []).extend(vs) + + result: dict[KSort, KSort] = {} + for p in prod.params: + if p not in candidates: + continue + lub: KSort = candidates[p][0] + for s in candidates[p][1:]: + if lub == s: + continue + new_lub = self.least_common_supersort(lub, s) + if new_lub is None: + break + lub = new_lub + else: + result[p] = lub + + return result + # Best-effort addition of sort parameters to klabels, context insensitive def add_sort_params(self, kast: KInner) -> KInner: """Return a given term with the sort parameters on the `KLabel` filled in (which may be missing because of how the frontend works), best effort.""" + # ML predicate labels whose result sort (Sort2) is context-dependent and not inferable + # from the arguments alone. When Sort1 can be determined but Sort2 cannot, we fill Sort2 + # with the sentinel KSort('#SortParam') so that downstream Kore emission can introduce a + # universally-quantified sort variable (Q0) in the axiom. + _ML_PRED_RESULT_SORT_PARAM = KSort('#SortParam') # noqa: N806 + _ML_PRED_LABELS = frozenset({'#Equals', '#Ceil', '#Floor', '#In'}) # noqa: N806 def _add_sort_params(_k: KInner) -> KInner: - if type(_k) is KApply: - prod = self.symbols[_k.label.name] - if len(_k.label.params) == 0 and len(prod.params) > 0: - sort_dict: dict[KSort, KSort] = {} - for psort, asort in zip(prod.argument_sorts, map(self.sort, _k.args), strict=True): - if asort is None: - _LOGGER.warning( - f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}' - ) - return _k - if psort in prod.params: - if psort in sort_dict and sort_dict[psort] != asort: - _LOGGER.warning( - f'Failed to add sort parameter, sort mismatch between different occurances of sort parameter: {(prod, psort, sort_dict[psort], asort)}' - ) - return _k - elif psort not in sort_dict: - sort_dict[psort] = asort - if all(p in sort_dict for p in prod.params): - return _k.let(label=KLabel(_k.label.name, [sort_dict[p] for p in prod.params])) + if type(_k) is not KApply: + return _k + prod = self.symbols[_k.label.name] + if len(_k.label.params) != 0 or len(prod.params) == 0: + return _k + + actual_sorts = tuple(map(self.sort, _k.args)) + param_set = frozenset(prod.params) + + # Separate sentinel args from real args; bail out on genuinely unsortable ones. + # Sentinels (#SortParam) propagate from nested ML preds and are handled below. + inference_sorts: list[KSort | None] = [] + for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True): + if asort == _ML_PRED_RESULT_SORT_PARAM: + inference_sorts.append(None) # skip in inference, propagate as sentinel below + elif asort is None: + _LOGGER.warning( + f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}' + ) + return _k + else: + inference_sorts.append(asort) + + bindings = self.infer_sort_params(prod, tuple(inference_sorts)) + + # Sentinel propagation: if an arg carried the #SortParam sentinel (from a nested ML + # pred) and inference left that arg's param slot empty, fill it with the sentinel. + # Only direct-param positions (psort IS a param) propagate the sentinel; nested cases + # (psort = MInt{S}) do not, matching the current Java behaviour. + for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True): + if asort == _ML_PRED_RESULT_SORT_PARAM and psort in param_set and psort not in bindings: + bindings[psort] = _ML_PRED_RESULT_SORT_PARAM + + if all(p in bindings for p in prod.params): + return _k.let(label=KLabel(_k.label.name, [bindings[p] for p in prod.params])) + + # ML predicates have a context-dependent result sort (Sort2) that cannot be + # inferred from arguments. Fill it with the sentinel so that krule_to_kore can + # introduce a universally-quantified sort variable for the axiom. + if _k.label.name in _ML_PRED_LABELS: + unbound = [p for p in prod.params if p not in bindings] + # The single sentinel KSort('#SortParam') is only unambiguous when at most + # one parameter is unresolvable bottom-up. All current ML predicates + # (#Equals, #Ceil, #Floor, #In) have exactly two sort params {Sort1, + # Sort2}: Sort1 is always determined by the arguments, Sort2 (the result + # sort) is the one remaining unbound param. If more than one param is + # unbound, the sentinel scheme must be replaced with unique fresh params + # (e.g. KSort('#SortParam', (KSort('Q0'),)), KSort('#SortParam', (KSort('Q1'),)), ...) + # analogously to how Java's AddSortInjections generates #SortParam{Q0}, + # #SortParam{Q1}, etc. _ksort_to_kore would also need updating to emit + # these as sort variables rather than sort applications. + if len(unbound) > 1: + raise NotImplementedError( + f'ML predicate {_k.label.name!r} has {len(unbound)} unbound sort parameters ' + f'({unbound}); the single-sentinel scheme only handles at most one. ' + f'Implement unique fresh sentinels analogous to Java #SortParam{{Q0}}, ' + f'#SortParam{{Q1}}, ... and update _ksort_to_kore to emit them as sort variables.' + ) + filled = {p: bindings.get(p, _ML_PRED_RESULT_SORT_PARAM) for p in prod.params} + return _k.let(label=KLabel(_k.label.name, [filled[p] for p in prod.params])) + + unbound = [p for p in prod.params if p not in bindings] + _LOGGER.warning( + f'Failed to add sort parameter, could not infer sort params from arguments: {(prod, unbound)}' + ) return _k return bottom_up(_add_sort_params, kast) diff --git a/pyk/src/tests/unit/kast/test_definition.py b/pyk/src/tests/unit/kast/test_definition.py index 1b6f987be0..6868303c76 100644 --- a/pyk/src/tests/unit/kast/test_definition.py +++ b/pyk/src/tests/unit/kast/test_definition.py @@ -1,14 +1,24 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import pytest from pyk.kast.att import Atts, KAtt -from pyk.kast.inner import KApply, KSort, KVariable -from pyk.kast.outer import KDefinition, KFlatModule, KNonTerminal, KProduction, KTerminal +from pyk.kast.inner import KApply, KAs, KLabel, KSequence, KSort, KToken, KVariable +from pyk.kast.outer import ( + KDefinition, + KFlatModule, + KNonTerminal, + KProduction, + KTerminal, + _match_sort_params, + _sort_contains, +) if TYPE_CHECKING: + from collections.abc import Callable from typing import Final from pyk.kast.inner import KInner @@ -17,6 +27,13 @@ # --------------------------------------------------------------------------- # Minimal test definition # +# bar: syntax N ::= bar(N) -- result sort is the param directly +# foo: syntax MInt{N} ::= foo(MInt{N}) -- result/arg sorts nest the param +# baz: syntax MInt{N} ::= baz() -- no args; param bound only from expected sort +# #Equals: syntax S2 ::= #Equals{S1,S2}(S1, S1) -- ML pred, result sort context-dependent +# +# Subsort: syntax Int ::= MInt{Int} -- MInt{Int} <: Int (enables subsort-aware matching) +# # Cell map fragment: # AccountCellMap ::= AccountCellMap AccountCellMap [cellCollection, element(AccountCellMapItem), wrapElement()] # AccountCellMap ::= AccountCellMapItem(Int, AccountCell) @@ -25,9 +42,69 @@ # --------------------------------------------------------------------------- INT: Final = KSort('Int') +N: Final = KSort('N') +S1: Final = KSort('S1') +S2: Final = KSort('S2') +S3: Final = KSort('S3') +MINT_N: Final = KSort('MInt', (N,)) +MINT_INT: Final = KSort('MInt', (INT,)) +SORT_PARAM: Final = KSort('#SortParam') ACCOUNT_CELL_MAP: Final = KSort('AccountCellMap') ACCOUNT_CELL: Final = KSort('AccountCell') +_BAR_PROD: Final = KProduction( + sort=N, + items=[KTerminal('bar'), KTerminal('('), KNonTerminal(N), KTerminal(')')], + params=[N], + klabel='bar', +) + +_FOO_PROD: Final = KProduction( + sort=MINT_N, + items=[KTerminal('foo'), KTerminal('('), KNonTerminal(MINT_N), KTerminal(')')], + params=[N], + klabel='foo', +) + +_EQUALS_PROD: Final = KProduction( + sort=S2, + items=[KNonTerminal(S1), KNonTerminal(S1)], + params=[S1, S2], + klabel='#Equals', +) + +# Hypothetical 3-param #Equals to test the multi-unbound-param guard. +# S1 is inferred from arguments; S2 and S3 are both unbound, which the single-sentinel +# scheme cannot handle — add_sort_params must raise NotImplementedError. +_EQUALS3_PROD: Final = KProduction( + sort=S2, + items=[KNonTerminal(S1), KNonTerminal(S1)], + params=[S1, S2, S3], + klabel='#Equals', +) + +# User-defined label where S2 does not appear in any argument sort, so it remains +# unbound after argument processing. add_sort_params must emit a warning and +# return the term unchanged (best-effort). +_PAIR_PROD: Final = KProduction( + sort=KSort('Pair', (S1, S2)), + items=[KTerminal('pair'), KTerminal('('), KNonTerminal(S1), KTerminal(')')], + params=[S1, S2], + klabel='pair', +) + +# syntax MInt{N} ::= baz() — no argument sorts; param N only bound via expected_sort +_BAZ_PROD: Final = KProduction( + sort=MINT_N, + items=[KTerminal('baz'), KTerminal('('), KTerminal(')')], + params=[N], + klabel='baz', +) + +# syntax Int ::= MInt{Int} — subsort declaration: MInt{Int} <: Int +# Enables the subsort-aware matching path (Java AddSortInjections.match step 3). +_MINT_INT_SUBSORT: Final = KProduction(sort=INT, items=[KNonTerminal(MINT_INT)]) + _ACCT_MAP_CONCAT: Final = KProduction( sort=ACCOUNT_CELL_MAP, items=[KNonTerminal(ACCOUNT_CELL_MAP), KNonTerminal(ACCOUNT_CELL_MAP)], @@ -69,9 +146,268 @@ DEFN: Final = KDefinition( 'TEST', - [KFlatModule('TEST', [_ACCT_MAP_CONCAT, _ACCT_MAP_ITEM, _ACCOUNT_CELL, _GET_ENTRY])], + [ + KFlatModule( + 'TEST', + [ + _BAR_PROD, + _FOO_PROD, + _BAZ_PROD, + _EQUALS_PROD, + _MINT_INT_SUBSORT, + _ACCT_MAP_CONCAT, + _ACCT_MAP_ITEM, + _ACCOUNT_CELL, + _GET_ENTRY, + ], + ) + ], +) + +# Definition used only to verify the multi-unbound-param guard in add_sort_params. +DEFN3: Final = KDefinition('TEST3', [KFlatModule('TEST3', [_EQUALS3_PROD])]) + +# Definition used only to verify the unresolvable-user-label warning path. +DEFN_PAIR: Final = KDefinition('TEST_PAIR', [KFlatModule('TEST_PAIR', [_PAIR_PROD])]) + + +# --------------------------------------------------------------------------- +# KDefinition.sort +# --------------------------------------------------------------------------- + +SORT_DATA: Final = ( + # Basic leaf terms + ('ktoken', KToken('42', INT), INT), + ('kvariable_with_sort', KVariable('X', sort=INT), INT), + ('ksequence', KSequence([]), KSort('K')), + # KApply: result sort substituted directly from param + ('kapply_direct_result', KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), INT), + # KApply: result sort nests the param (MInt{N} with N→Int → MInt{Int}) + ('kapply_nested_result', KApply(KLabel('foo', [INT]), [KVariable('X', sort=MINT_INT)]), MINT_INT), + # KApply with unfilled sort params: sort() returns None rather than raising + ('kapply_unfilled_params', KApply(KLabel('foo'), [KVariable('X', sort=MINT_INT)]), None), + # KApply with unknown label: KeyError from symbols lookup → None + ('kapply_unknown_label', KApply(KLabel('nonexistent'), []), None), + # KAs: sort of the alias variable + ('kas_sorted_alias', KAs(KVariable('X', sort=MINT_INT), KVariable('Y', sort=MINT_INT)), MINT_INT), + # KAs whose alias has no sort annotation: returns None + ('kas_unsorted_alias', KAs(KVariable('X', sort=MINT_INT), KVariable('Y')), None), ) + +@pytest.mark.parametrize( + 'test_id,term,expected', + SORT_DATA, + ids=[test_id for test_id, *_ in SORT_DATA], +) +def test_sort(test_id: str, term: KInner, expected: KSort | None) -> None: + assert DEFN.sort(term) == expected + + +# --------------------------------------------------------------------------- +# KDefinition.resolve_sorts +# --------------------------------------------------------------------------- + +RESOLVE_SORTS_DATA: Final = ( + # Direct substitution: result sort IS the param (N → Int) + ('direct_bar', KLabel('bar', [INT]), INT, (INT,)), + # Recursive substitution: result/arg sort nests the param (MInt{N} with N → Int → MInt{Int}) + ('nested_foo', KLabel('foo', [INT]), MINT_INT, (MINT_INT,)), +) + + +@pytest.mark.parametrize( + 'test_id,label,expected_result,expected_args', + RESOLVE_SORTS_DATA, + ids=[test_id for test_id, *_ in RESOLVE_SORTS_DATA], +) +def test_resolve_sorts(test_id: str, label: KLabel, expected_result: KSort, expected_args: tuple[KSort, ...]) -> None: + result, args = DEFN.resolve_sorts(label) + assert result == expected_result + assert args == expected_args + + +# --------------------------------------------------------------------------- +# KDefinition.add_sort_params +# --------------------------------------------------------------------------- + +ADD_SORT_PARAMS_DATA: Final = ( + # Label already has params filled: leave unchanged + ( + 'already_filled', + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + ), + # Direct sort param: psort IS the param (N ~ Int → N=Int) + ( + 'direct_param', + KApply(KLabel('bar'), [KVariable('X', sort=INT)]), + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + ), + # Nested sort param: psort = MInt{N}, asort = MInt{Int} → N=Int via unification + ( + 'nested_param', + KApply(KLabel('foo'), [KVariable('X', sort=MINT_INT)]), + KApply(KLabel('foo', [INT]), [KVariable('X', sort=MINT_INT)]), + ), + # ML pred: S1 inferred from args, S2 (result sort) filled with #SortParam sentinel + ( + 'ml_pred_sentinel', + KApply('#Equals', [KVariable('X', sort=INT), KVariable('Y', sort=INT)]), + KApply(KLabel('#Equals', [INT, SORT_PARAM]), [KVariable('X', sort=INT), KVariable('Y', sort=INT)]), + ), + # Unsortable argument (no sort annotation): cannot fill params, term returned unchanged + ( + 'unsortable_arg_unchanged', + KApply(KLabel('foo'), [KVariable('X')]), + KApply(KLabel('foo'), [KVariable('X')]), + ), + # Subsort-aware: arg sort is Int, but MInt{Int} <: Int in DEFN, so N=Int via subsort match + # (this case would fail with structural-only unification since Int ≠ MInt{N}) + ( + 'subsort_aware', + KApply(KLabel('foo'), [KVariable('X', sort=INT)]), + KApply(KLabel('foo', [INT]), [KVariable('X', sort=INT)]), + ), +) + + +@pytest.mark.parametrize( + 'test_id,term,expected', + ADD_SORT_PARAMS_DATA, + ids=[test_id for test_id, *_ in ADD_SORT_PARAMS_DATA], +) +def test_add_sort_params(test_id: str, term: KInner, expected: KInner) -> None: + assert DEFN.add_sort_params(term) == expected + + +def test_add_sort_params_multi_unbound_raises() -> None: + # #Equals with 3 sort params: S1 is inferred from arguments, S2 and S3 are both unbound. + # The single-sentinel scheme cannot distinguish them, so NotImplementedError must be raised. + term = KApply('#Equals', [KVariable('X', sort=INT), KVariable('Y', sort=INT)]) + with pytest.raises(NotImplementedError, match='2 unbound sort parameters'): + DEFN3.add_sort_params(term) + + +def test_add_sort_params_user_label_unresolvable_warns(caplog: pytest.LogCaptureFixture) -> None: + # pair(S1, S2) has S2 absent from arguments — S2 is unbound after inference. + # add_sort_params emits a warning and returns the term unchanged (best-effort). + term = KApply(KLabel('pair'), [KVariable('X', sort=INT)]) + with caplog.at_level(logging.WARNING): + result = DEFN_PAIR.add_sort_params(term) + assert result == term + assert any('could not infer sort params' in record.message for record in caplog.records) + + +# --------------------------------------------------------------------------- +# KDefinition.infer_sort_params +# --------------------------------------------------------------------------- +# +# Tests the public method directly (not through add_sort_params), mirroring the +# Java AddSortInjections.substituteProd() test scenarios derived from the algorithm. + +INFER_SORT_PARAMS_DATA: Final[ + tuple[tuple[str, KProduction, tuple[KSort | None, ...], KSort | None, dict[KSort, KSort]], ...] +] = ( + # Direct param: psort IS the param (N → Int) + ('direct_param', _BAR_PROD, (INT,), None, {N: INT}), + # Nested param: psort = MInt{N}, asort = MInt{Int} → N=Int via structural match + ('nested_param', _FOO_PROD, (MINT_INT,), None, {N: INT}), + # Subsort-aware: arg sort is Int, MInt{Int} <: Int in DEFN → N=Int via subsort iteration + # (structural match fails: MInt{N} ≠ Int; subsort check finds MInt{Int} ≤ Int) + ('subsort_aware', _FOO_PROD, (INT,), None, {N: INT}), + # matchExpected: baz() has no arg sorts; N is bound from the expected_sort MInt{Int} + ('expected_sort', _BAZ_PROD, (), MINT_INT, {N: INT}), + # None arg is skipped; no bindings → empty result + ('unbound_absent', _BAR_PROD, (None,), None, {}), +) + + +@pytest.mark.parametrize( + 'test_id,prod,actual_sorts,expected_sort,expected_bindings', + INFER_SORT_PARAMS_DATA, + ids=[test_id for test_id, *_ in INFER_SORT_PARAMS_DATA], +) +def test_infer_sort_params( + test_id: str, + prod: KProduction, + actual_sorts: tuple[KSort | None, ...], + expected_sort: KSort | None, + expected_bindings: dict[KSort, KSort], +) -> None: + assert DEFN.infer_sort_params(prod, actual_sorts, expected_sort) == expected_bindings + + +# --------------------------------------------------------------------------- +# _match_sort_params (module-level helper) +# --------------------------------------------------------------------------- +# +# Directly tests the three matching strategies described in the docstring. + + +def _subsorts_fn(s: KSort) -> frozenset[KSort]: + return frozenset({MINT_INT}) if s == INT else frozenset() + + +MATCH_SORT_PARAMS_DATA: Final[ + tuple[ + tuple[ + str, KSort, KSort, frozenset[KSort], Callable[[KSort], frozenset[KSort]] | None, dict[KSort, list[KSort]] + ], + ..., + ] +] = ( + # Case 1 – direct: parametric IS a sort param + ('direct', N, INT, frozenset({N}), None, {N: [INT]}), + # Case 2 – structural: same head, recurse on sub-params + ('structural', MINT_N, MINT_INT, frozenset({N}), None, {N: [INT]}), + # Case 2 fails (different heads), no subsorts_fn → empty + ('structural_no_match_no_subsorts', MINT_N, INT, frozenset({N}), None, {}), + # Case 3 – subsort-aware: MInt{N} vs Int; subsorts_fn yields MInt{Int} → N=Int + ('subsort_aware', MINT_N, INT, frozenset({N}), _subsorts_fn, {N: [INT]}), + # No match in any case + ('no_match', INT, KSort('Bool'), frozenset({N}), None, {}), +) + + +@pytest.mark.parametrize( + 'test_id,parametric,actual,params,subsorts_fn,expected', + MATCH_SORT_PARAMS_DATA, + ids=[test_id for test_id, *_ in MATCH_SORT_PARAMS_DATA], +) +def test_match_sort_params( + test_id: str, + parametric: KSort, + actual: KSort, + params: frozenset[KSort], + subsorts_fn: Callable[[KSort], frozenset[KSort]] | None, + expected: dict[KSort, list[KSort]], +) -> None: + assert _match_sort_params(parametric, actual, params, subsorts_fn) == expected + + +# --------------------------------------------------------------------------- +# _sort_contains (module-level helper) +# --------------------------------------------------------------------------- + +SORT_CONTAINS_DATA: Final = ( + ('param_itself', N, N, True), + ('nested_one_level', MINT_N, N, True), + ('nested_two_levels', KSort('Foo', (MINT_N,)), N, True), + ('concrete_not_param', MINT_INT, N, False), + ('unrelated', INT, N, False), +) + + +@pytest.mark.parametrize( + 'test_id,sort,param,expected', + SORT_CONTAINS_DATA, + ids=[test_id for test_id, *_ in SORT_CONTAINS_DATA], +) +def test_sort_contains(test_id: str, sort: KSort, param: KSort, expected: bool) -> None: + assert _sort_contains(sort, param) == expected + + # --------------------------------------------------------------------------- # KDefinition.add_cell_map_items # ---------------------------------------------------------------------------