Skip to content

Commit 62cf433

Browse files
committed
Fix issue that prevented handling rule application for rules of the form pat->Condition[expr_,cond]
1 parent 2ffdf02 commit 62cf433

File tree

7 files changed

+228
-23
lines changed

7 files changed

+228
-23
lines changed

CHANGES.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Bugs
5050

5151
# ``0`` with a given precision (like in ```0`3```) is now parsed as ``0``, an integer number.
5252
#. ``RandomSample`` with one list argument now returns a random ordering of the list items. Previously it would return just one item.
53-
53+
#. Rules of the form ``pat->Condition[expr, cond]`` are handled as in WL. The same also works for nested `Condition` expressions. In particular, the comparison between two Rules with the same pattern but an iterated ``Condition`` expressionare considered equal if the conditions are the same.
54+
5455

5556
Enhancements
5657
++++++++++++

mathics/builtin/assignments/assignment.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,21 @@ class SetDelayed(Set):
170170
'Condition' ('/;') can be used with 'SetDelayed' to make an
171171
assignment that only holds if a condition is satisfied:
172172
>> f[x_] := p[x] /; x>0
173+
>> f[x_] := p[-x]/; x<-2
173174
>> f[3]
174175
= p[3]
175176
>> f[-3]
176-
= f[-3]
177-
It also works if the condition is set in the LHS:
177+
= p[3]
178+
>> f[-1]
179+
= f[-1]
180+
Notice that the LHS is the same in both definitions, but the second
181+
does not overwrite the first one.
182+
183+
To overwrite one of these definitions, we have to assign using the same condition:
184+
>> f[x_] := Sin[x] /; x>0
185+
>> f[3]
186+
= Sin[3]
187+
In a similar way, the condition can be set in the LHS:
178188
>> F[x_, y_] /; x < y /; x>0 := x / y;
179189
>> F[x_, y_] := y / x;
180190
>> F[2, 3]

mathics/builtin/patterns.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
172172
else:
173173
result = []
174174
for rule in rules:
175-
if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"):
175+
head_name = rule.get_head_name()
176+
if head_name not in ("System`Rule", "System`RuleDelayed"):
176177
evaluation.message(name, "reps", rule)
177178
return None, True
178179
elif len(rule.elements) != 2:
@@ -186,7 +187,13 @@ def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
186187
)
187188
return None, True
188189
else:
189-
result.append(Rule(rule.elements[0], rule.elements[1]))
190+
result.append(
191+
Rule(
192+
rule.elements[0],
193+
rule.elements[1],
194+
delayed=(head_name == "System`RuleDelayed"),
195+
)
196+
)
190197
return result, False
191198

192199

@@ -1690,7 +1697,7 @@ def __init__(self, rulelist, evaluation):
16901697
self._elements = None
16911698
self._head = SymbolDispatch
16921699

1693-
def get_sort_key(self) -> tuple:
1700+
def get_sort_key(self, pattern_sort=False) -> tuple:
16941701
return self.src.get_sort_key()
16951702

16961703
def get_atom_name(self):

mathics/core/definitions.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111

1212
from typing import List, Optional
1313

14-
from mathics.core.atoms import String
14+
from mathics.core.atoms import Integer, String
1515
from mathics.core.attributes import A_NO_ATTRIBUTES
1616
from mathics.core.convert.expression import to_mathics_list
1717
from mathics.core.element import fully_qualified_symbol_name
1818
from mathics.core.expression import Expression
19+
from mathics.core.pattern import Pattern
20+
from mathics.core.rules import Rule
1921
from mathics.core.symbols import (
2022
Atom,
2123
Symbol,
@@ -721,9 +723,6 @@ def get_ownvalue(self, name):
721723
return None
722724

723725
def set_ownvalue(self, name, value) -> None:
724-
from .expression import Symbol
725-
from .rules import Rule
726-
727726
name = self.lookup_name(name)
728727
self.add_rule(name, Rule(Symbol(name), value))
729728
self.clear_cache(name)
@@ -759,8 +758,6 @@ def get_config_value(self, name, default=None):
759758
return default
760759

761760
def set_config_value(self, name, new_value) -> None:
762-
from mathics.core.expression import Integer
763-
764761
self.set_ownvalue(name, Integer(new_value))
765762

766763
def set_line_no(self, line_no) -> None:
@@ -780,6 +777,25 @@ def get_history_length(self):
780777

781778

782779
def get_tag_position(pattern, name) -> Optional[str]:
780+
# Strip first the pattern from HoldPattern, Pattern
781+
# and Condition wrappings
782+
while True:
783+
# TODO: Not Atom/Expression,
784+
# pattern -> pattern.to_expression()
785+
if isinstance(pattern, Pattern):
786+
pattern = pattern.expr
787+
continue
788+
if pattern.has_form("System`HoldPattern", 1):
789+
pattern = pattern.elements[0]
790+
continue
791+
if pattern.has_form("System`Pattern", 2):
792+
pattern = pattern.elements[1]
793+
continue
794+
if pattern.has_form("System`Condition", 2):
795+
pattern = pattern.elements[0]
796+
continue
797+
break
798+
783799
if pattern.get_name() == name:
784800
return "own"
785801
elif isinstance(pattern, Atom):
@@ -788,10 +804,8 @@ def get_tag_position(pattern, name) -> Optional[str]:
788804
head_name = pattern.get_head_name()
789805
if head_name == name:
790806
return "down"
791-
elif head_name == "System`N" and len(pattern.elements) == 2:
807+
elif pattern.has_form("System`N", 2):
792808
return "n"
793-
elif head_name == "System`Condition" and len(pattern.elements) > 0:
794-
return get_tag_position(pattern.elements[0], name)
795809
elif pattern.get_lookup_name() == name:
796810
return "sub"
797811
else:
@@ -801,11 +815,18 @@ def get_tag_position(pattern, name) -> Optional[str]:
801815
return None
802816

803817

804-
def insert_rule(values, rule) -> None:
818+
def insert_rule(values: list, rule: Rule) -> None:
819+
rhs_conds = getattr(rule, "rhs_conditions", [])
805820
for index, existing in enumerate(values):
806821
if existing.pattern.sameQ(rule.pattern):
807-
del values[index]
808-
break
822+
# Check for coincidences in the replace conditions,
823+
# it they are there.
824+
# This ensures that the rules are equivalent even taking
825+
# into accound the RHS conditions.
826+
existing_rhs_conds = getattr(existing, "rhs_conditions", [])
827+
if existing_rhs_conds == rhs_conds:
828+
del values[index]
829+
break
809830
# use insort_left to guarantee that if equal rules exist, newer rules will
810831
# get higher precedence by being inserted before them. see DownValues[].
811832
bisect.insort_left(values, rule)

mathics/core/rules.py

+138-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mathics.core.element import KeyComparable
77
from mathics.core.expression import Expression
8-
from mathics.core.symbols import strip_context
8+
from mathics.core.symbols import strip_context, SymbolTrue
99
from mathics.core.pattern import Pattern, StopGenerator
1010

1111
from itertools import chain
@@ -19,6 +19,10 @@ def function_arguments(f):
1919
return _python_function_arguments(f)
2020

2121

22+
class StopMatchConditionFailed(StopGenerator):
23+
pass
24+
25+
2226
class StopGenerator_BaseRule(StopGenerator):
2327
pass
2428

@@ -59,7 +63,11 @@ def yield_match(vars, rest):
5963
if name.startswith("_option_"):
6064
options[name[len("_option_") :]] = value
6165
del vars[name]
62-
new_expression = self.do_replace(expression, vars, options, evaluation)
66+
try:
67+
new_expression = self.do_replace(expression, vars, options, evaluation)
68+
except StopMatchConditionFailed:
69+
return
70+
6371
if new_expression is None:
6472
new_expression = expression
6573
if rest[0] or rest[1]:
@@ -107,7 +115,7 @@ def yield_match(vars, rest):
107115
def do_replace(self):
108116
raise NotImplementedError
109117

110-
def get_sort_key(self) -> tuple:
118+
def get_sort_key(self, pattern_sort=False) -> tuple:
111119
# FIXME: check if this makes sense:
112120
return tuple((self.system, self.pattern.get_sort_key(True)))
113121

@@ -131,12 +139,131 @@ class Rule(BaseRule):
131139
``G[1.^2, a^2]``
132140
"""
133141

134-
def __init__(self, pattern, replace, system=False) -> None:
142+
def __ge__(self, other):
143+
if isinstance(other, Rule):
144+
sys, key, rhs_cond = self.get_sort_key()
145+
sys_other, key_other, rhs_cond_other = other.get_sort_key()
146+
if sys != sys_other:
147+
return sys > sys_other
148+
if key != key_other:
149+
return key > key_other
150+
151+
# larger and more complex conditions come first
152+
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
153+
if len_cond != len_cond_other:
154+
return len_cond_other > len_cond
155+
if len_cond == 0:
156+
return False
157+
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
158+
me_sk = me_cond.get_sort_key(True)
159+
o_sk = other_cond.get_sort_key(True)
160+
if me_sk > o_sk:
161+
return False
162+
return True
163+
# Follow the usual rule
164+
return self.get_sort_key(True) >= other.get_sort_key(True)
165+
166+
def __gt__(self, other):
167+
if isinstance(other, Rule):
168+
sys, key, rhs_cond = self.get_sort_key()
169+
sys_other, key_other, rhs_cond_other = other.get_sort_key()
170+
if sys != sys_other:
171+
return sys > sys_other
172+
if key != key_other:
173+
return key > key_other
174+
175+
# larger and more complex conditions come first
176+
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
177+
if len_cond != len_cond_other:
178+
return len_cond_other > len_cond
179+
if len_cond == 0:
180+
return False
181+
182+
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
183+
me_sk = me_cond.get_sort_key(True)
184+
o_sk = other_cond.get_sort_key(True)
185+
if me_sk > o_sk:
186+
return False
187+
return me_sk > o_sk
188+
# Follow the usual rule
189+
return self.get_sort_key(True) > other.get_sort_key(True)
190+
191+
def __le__(self, other):
192+
if isinstance(other, Rule):
193+
sys, key, rhs_cond = self.get_sort_key()
194+
sys_other, key_other, rhs_cond_other = other.get_sort_key()
195+
if sys != sys_other:
196+
return sys < sys_other
197+
if key != key_other:
198+
return key < key_other
199+
200+
# larger and more complex conditions come first
201+
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
202+
if len_cond != len_cond_other:
203+
return len_cond_other < len_cond
204+
if len_cond == 0:
205+
return False
206+
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
207+
me_sk = me_cond.get_sort_key(True)
208+
o_sk = other_cond.get_sort_key(True)
209+
if me_sk < o_sk:
210+
return False
211+
return True
212+
# Follow the usual rule
213+
return self.get_sort_key(True) <= other.get_sort_key(True)
214+
215+
def __lt__(self, other):
216+
if isinstance(other, Rule):
217+
sys, key, rhs_cond = self.get_sort_key()
218+
sys_other, key_other, rhs_cond_other = other.get_sort_key()
219+
if sys != sys_other:
220+
return sys < sys_other
221+
if key != key_other:
222+
return key < key_other
223+
224+
# larger and more complex conditions come first
225+
len_cond, len_cond_other = len(rhs_cond), len(rhs_cond_other)
226+
if len_cond != len_cond_other:
227+
return len_cond_other < len_cond
228+
if len_cond == 0:
229+
return False
230+
231+
for me_cond, other_cond in zip(rhs_cond, rhs_cond_other):
232+
me_sk = me_cond.get_sort_key(True)
233+
o_sk = other_cond.get_sort_key(True)
234+
if me_sk < o_sk:
235+
return False
236+
return me_sk > o_sk
237+
# Follow the usual rule
238+
return self.get_sort_key(True) < other.get_sort_key(True)
239+
240+
def __init__(self, pattern, replace, delayed=True, system=False) -> None:
135241
super(Rule, self).__init__(pattern, system=system)
136242
self.replace = replace
243+
self.delayed = delayed
244+
# If delayed is True, and replace is a nested
245+
# Condition expression, stores the conditions and the
246+
# remaining stripped expression.
247+
# This is going to be used to compare and sort rules,
248+
# and also to decide if the rule matches an expression.
249+
conds = []
250+
if delayed:
251+
while replace.has_form("System`Condition", 2):
252+
replace, cond = replace.elements
253+
conds.append(cond)
254+
255+
self.rhs_conditions = sorted(conds)
256+
self.strip_replace = replace
137257

138258
def do_replace(self, expression, vars, options, evaluation):
139-
new = self.replace.replace_vars(vars)
259+
replace = self.replace if self.rhs_conditions == [] else self.strip_replace
260+
for cond in self.rhs_conditions:
261+
cond = cond.replace_vars(vars)
262+
cond = cond.evaluate(evaluation)
263+
if cond is not SymbolTrue:
264+
raise StopMatchConditionFailed
265+
266+
new = replace.replace_vars(vars)
140267
new.options = options
141268

142269
# if options is a non-empty dict, we need to ensure reevaluation of the whole expression, since 'new' will
@@ -159,6 +286,12 @@ def do_replace(self, expression, vars, options, evaluation):
159286
def __repr__(self) -> str:
160287
return "<Rule: %s -> %s>" % (self.pattern, self.replace)
161288

289+
def get_sort_key(self, pattern_sort=False) -> tuple:
290+
# FIXME: check if this makes sense:
291+
return tuple(
292+
(self.system, self.pattern.get_sort_key(True), self.rhs_conditions)
293+
)
294+
162295

163296
class BuiltinRule(BaseRule):
164297
"""

mathics/core/systemsymbols.py

+5
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
SymbolSeries = Symbol("System`Series")
170170
SymbolSeriesData = Symbol("System`SeriesData")
171171
SymbolSet = Symbol("System`Set")
172+
SymbolSetDelayed = Symbol("System`SetDelayed")
172173
SymbolSign = Symbol("System`Sign")
173174
SymbolSimplify = Symbol("System`Simplify")
174175
SymbolSin = Symbol("System`Sin")
@@ -186,6 +187,8 @@
186187
SymbolSubsuperscriptBox = Symbol("System`SubsuperscriptBox")
187188
SymbolSuperscriptBox = Symbol("System`SuperscriptBox")
188189
SymbolTable = Symbol("System`Table")
190+
SymbolTagSet = Symbol("System`TagSet")
191+
SymbolTagSetDelayed = Symbol("System`TagSetDelayed")
189192
SymbolTeXForm = Symbol("System`TeXForm")
190193
SymbolThrow = Symbol("System`Throw")
191194
SymbolToString = Symbol("System`ToString")
@@ -194,5 +197,7 @@
194197
SymbolUndefined = Symbol("System`Undefined")
195198
SymbolUnequal = Symbol("System`Unequal")
196199
SymbolUnevaluated = Symbol("System`Unevaluated")
200+
SymbolUpSet = Symbol("System`UpSet")
201+
SymbolUpSetDelayed = Symbol("System`UpSetDelayed")
197202
SymbolUpValues = Symbol("System`UpValues")
198203
SymbolXor = Symbol("System`Xor")

0 commit comments

Comments
 (0)