Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 180 additions & 23 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt
from .inner import (
KApply,
KAs,
KInner,
KLabel,
KRewrite,
Expand Down Expand Up @@ -1000,6 +1001,45 @@ def let(self, *, require: str | None = None) -> KRequire:
return KRequire(require=require)


def _sort_contains(sort: KSort, param: KSort) -> bool:
"""Return whether ``param`` appears anywhere in the sort tree of ``sort``."""
return sort == param or any(_sort_contains(p, param) for p in sort.params)


def _match_sort_params(
parametric: KSort,
actual: KSort,
params: frozenset[KSort],
subsorts_fn: Callable[[KSort], frozenset[KSort]] | None = None,
) -> dict[KSort, list[KSort]]:
"""Match ``parametric`` sort against ``actual``, collecting candidate bindings per sort param.

Three matching strategies, mirroring Java ``AddSortInjections.match()``:

1. Direct: ``parametric`` is itself a sort param — bind it to ``actual``.
2. Structural: same constructor head — recurse on sub-params.
3. Subsort-aware: iterate subsorts ``s ≤ actual`` with same head as ``parametric``,
collecting additional candidates for LUB resolution.
"""
if parametric in params:
return {parametric: [actual]}
if parametric.name == actual.name and len(parametric.params) == len(actual.params):
result: dict[KSort, list[KSort]] = {}
for p_sub, a_sub in zip(parametric.params, actual.params, strict=True):
for k, vs in _match_sort_params(p_sub, a_sub, params, subsorts_fn).items():
result.setdefault(k, []).extend(vs)
return result
if parametric.params and subsorts_fn is not None:
result = {}
for s in subsorts_fn(actual):
if s.name == parametric.name and len(s.params) == len(parametric.params):
for p_sub, a_sub in zip(parametric.params, s.params, strict=True):
for k, vs in _match_sort_params(p_sub, a_sub, params).items():
result.setdefault(k, []).extend(vs)
return result
return {}


@final
@dataclass(frozen=True)
class KDefinition(KOuter, WithKAtt, Iterable[KFlatModule]):
Expand Down Expand Up @@ -1346,6 +1386,8 @@ def sort(self, kast: KInner) -> KSort | None:
match kast:
case KToken(_, sort) | KVariable(_, sort):
return sort
case KAs(alias=KVariable(sort=sort)):
return sort
case KRewrite(lhs, rhs):
lhs_sort = self.sort(lhs)
rhs_sort = self.sort(rhs)
Expand All @@ -1355,8 +1397,11 @@ def sort(self, kast: KInner) -> KSort | None:
case KSequence(_):
return KSort('K')
case KApply(label, _):
sort, _ = self.resolve_sorts(label)
return sort
try:
sort, _ = self.resolve_sorts(label)
return sort
except (KeyError, ValueError):
return None
case _:
return None

Expand All @@ -1373,7 +1418,13 @@ def resolve_sorts(self, label: KLabel) -> tuple[KSort, tuple[KSort, ...]]:
sorts = dict(zip(prod.params, label.params, strict=True))

def resolve(sort: KSort) -> KSort:
return sorts.get(sort, sort)
# Direct match: sort IS one of the sort parameters.
if sort in sorts:
return sorts[sort]
# Recursive substitution: sort params may appear nested (e.g. MInt{Width} → MInt{8}).
if sort.params:
return KSort(sort.name, tuple(resolve(p) for p in sort.params))
return sort

return resolve(prod.sort), tuple(resolve(sort) for sort in prod.argument_sorts)

Expand Down Expand Up @@ -1499,31 +1550,137 @@ def transform(

return Subst(subst)(new_term)

def infer_sort_params(
self,
prod: KProduction,
actual_sorts: tuple[KSort | None, ...],
expected_sort: KSort | None = None,
) -> dict[KSort, KSort]:
"""Infer sort parameter bindings for a parametric production application.

Returns a (possibly partial) mapping from sort params to concrete sorts;
unbound parameters are absent from the result.
Mirrors ``AddSortInjections.substituteProd()`` in the Java frontend.

``actual_sorts`` must have the same length as ``prod.argument_sorts``.
``None`` entries are skipped (unsortable arguments).
If ``expected_sort`` is given, parameters that appear only in the result sort
(not in any argument sort) are also inferred from it — this is the
``matchExpected`` path in the Java algorithm.
"""
params = frozenset(prod.params)
candidates: dict[KSort, list[KSort]] = {}

for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True):
if asort is None:
continue
for k, vs in _match_sort_params(psort, asort, params, self.subsorts).items():
candidates.setdefault(k, []).extend(vs)

if expected_sort is not None:
unbound_result_params = frozenset(
p
for p in params
if _sort_contains(prod.sort, p)
and not any(_sort_contains(asort, p) for asort in actual_sorts if asort is not None)
)
if unbound_result_params:
for k, vs in _match_sort_params(prod.sort, expected_sort, unbound_result_params).items():
candidates.setdefault(k, []).extend(vs)

result: dict[KSort, KSort] = {}
for p in prod.params:
if p not in candidates:
continue
lub: KSort = candidates[p][0]
for s in candidates[p][1:]:
if lub == s:
continue
new_lub = self.least_common_supersort(lub, s)
if new_lub is None:
break
lub = new_lub
else:
result[p] = lub

return result

# Best-effort addition of sort parameters to klabels, context insensitive
def add_sort_params(self, kast: KInner) -> KInner:
"""Return a given term with the sort parameters on the `KLabel` filled in (which may be missing because of how the frontend works), best effort."""
# ML predicate labels whose result sort (Sort2) is context-dependent and not inferable
# from the arguments alone. When Sort1 can be determined but Sort2 cannot, we fill Sort2
# with the sentinel KSort('#SortParam') so that downstream Kore emission can introduce a
# universally-quantified sort variable (Q0) in the axiom.
_ML_PRED_RESULT_SORT_PARAM = KSort('#SortParam') # noqa: N806
_ML_PRED_LABELS = frozenset({'#Equals', '#Ceil', '#Floor', '#In'}) # noqa: N806

Comment thread
ehildenb marked this conversation as resolved.
def _add_sort_params(_k: KInner) -> KInner:
if type(_k) is KApply:
prod = self.symbols[_k.label.name]
if len(_k.label.params) == 0 and len(prod.params) > 0:
sort_dict: dict[KSort, KSort] = {}
for psort, asort in zip(prod.argument_sorts, map(self.sort, _k.args), strict=True):
if asort is None:
_LOGGER.warning(
f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}'
)
return _k
if psort in prod.params:
if psort in sort_dict and sort_dict[psort] != asort:
_LOGGER.warning(
f'Failed to add sort parameter, sort mismatch between different occurances of sort parameter: {(prod, psort, sort_dict[psort], asort)}'
)
return _k
elif psort not in sort_dict:
sort_dict[psort] = asort
if all(p in sort_dict for p in prod.params):
return _k.let(label=KLabel(_k.label.name, [sort_dict[p] for p in prod.params]))
if type(_k) is not KApply:
return _k
prod = self.symbols[_k.label.name]
if len(_k.label.params) != 0 or len(prod.params) == 0:
return _k

actual_sorts = tuple(map(self.sort, _k.args))
param_set = frozenset(prod.params)

# Separate sentinel args from real args; bail out on genuinely unsortable ones.
# Sentinels (#SortParam) propagate from nested ML preds and are handled below.
inference_sorts: list[KSort | None] = []
for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True):
if asort == _ML_PRED_RESULT_SORT_PARAM:
inference_sorts.append(None) # skip in inference, propagate as sentinel below
elif asort is None:
_LOGGER.warning(
f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}'
)
return _k
else:
inference_sorts.append(asort)

bindings = self.infer_sort_params(prod, tuple(inference_sorts))

# Sentinel propagation: if an arg carried the #SortParam sentinel (from a nested ML
# pred) and inference left that arg's param slot empty, fill it with the sentinel.
# Only direct-param positions (psort IS a param) propagate the sentinel; nested cases
# (psort = MInt{S}) do not, matching the current Java behaviour.
for psort, asort in zip(prod.argument_sorts, actual_sorts, strict=True):
if asort == _ML_PRED_RESULT_SORT_PARAM and psort in param_set and psort not in bindings:
bindings[psort] = _ML_PRED_RESULT_SORT_PARAM

if all(p in bindings for p in prod.params):
return _k.let(label=KLabel(_k.label.name, [bindings[p] for p in prod.params]))

# ML predicates have a context-dependent result sort (Sort2) that cannot be
# inferred from arguments. Fill it with the sentinel so that krule_to_kore can
# introduce a universally-quantified sort variable for the axiom.
if _k.label.name in _ML_PRED_LABELS:
unbound = [p for p in prod.params if p not in bindings]
# The single sentinel KSort('#SortParam') is only unambiguous when at most
# one parameter is unresolvable bottom-up. All current ML predicates
# (#Equals, #Ceil, #Floor, #In) have exactly two sort params {Sort1,
# Sort2}: Sort1 is always determined by the arguments, Sort2 (the result
# sort) is the one remaining unbound param. If more than one param is
# unbound, the sentinel scheme must be replaced with unique fresh params
# (e.g. KSort('#SortParam', (KSort('Q0'),)), KSort('#SortParam', (KSort('Q1'),)), ...)
# analogously to how Java's AddSortInjections generates #SortParam{Q0},
# #SortParam{Q1}, etc. _ksort_to_kore would also need updating to emit
# these as sort variables rather than sort applications.
if len(unbound) > 1:
raise NotImplementedError(
f'ML predicate {_k.label.name!r} has {len(unbound)} unbound sort parameters '
f'({unbound}); the single-sentinel scheme only handles at most one. '
f'Implement unique fresh sentinels analogous to Java #SortParam{{Q0}}, '
f'#SortParam{{Q1}}, ... and update _ksort_to_kore to emit them as sort variables.'
)
filled = {p: bindings.get(p, _ML_PRED_RESULT_SORT_PARAM) for p in prod.params}
return _k.let(label=KLabel(_k.label.name, [filled[p] for p in prod.params]))

unbound = [p for p in prod.params if p not in bindings]
_LOGGER.warning(
f'Failed to add sort parameter, could not infer sort params from arguments: {(prod, unbound)}'
)
return _k

return bottom_up(_add_sort_params, kast)
Expand Down
Loading
Loading