Skip to content

Commit 7469250

Browse files
authored
Merge pull request #845 from MilesCranmer/fix-greater
fix: comparison operator parsing
2 parents d41d4be + 665d4a1 commit 7469250

File tree

6 files changed

+34
-6
lines changed

6 files changed

+34
-6
lines changed

docs/operators.md

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ it can exponentially increase the search space.
1616
|--------------|------------|----------|
1717
| `+` | `max` | `logical_or`[^2] |
1818
| `-` | `min` | `logical_and`[^3]|
19-
| `*` | `greater`[^4] | |
20-
| `/` | `cond`[^5] | |
21-
| `^` | `mod` | |
19+
| `*` | `>`[^4] | |
20+
| `/` | `>=` | |
21+
| `^` | `<` | |
22+
| | `<=` | |
23+
| | `cond`[^5] | |
24+
| | `mod` | |
2225

2326
**Unary Operators**
2427

@@ -74,5 +77,5 @@ any invalid values over the training dataset.
7477
[^1]: However, you will need to define a sympy equivalent in `extra_sympy_mapping` if you want to use a function not in the above list.
7578
[^2]: `logical_or` is equivalent to `(x, y) -> (x > 0 || y > 0) ? 1 : 0`
7679
[^3]: `logical_and` is equivalent to `(x, y) -> (x > 0 && y > 0) ? 1 : 0`
77-
[^4]: `greater` is equivalent to `(x, y) -> x > y ? 1 : 0`
80+
[^4]: `>` is equivalent to `(x, y) -> x > y ? 1 : 0`
7881
[^5]: `cond` is equivalent to `(x, y) -> x > 0 ? y : 0`

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "pysr"
7-
version = "1.5.0"
7+
version = "1.5.1"
88
authors = [
99
{name = "Miles Cranmer", email = "[email protected]"},
1010
]

pysr/export_sympy.py

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
"max": lambda x, y: sympy.Piecewise((y, x < y), (x, True)),
5656
"min": lambda x, y: sympy.Piecewise((x, x < y), (y, True)),
5757
"greater": lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)),
58+
"less": lambda x, y: sympy.Piecewise((1.0, x < y), (0.0, True)),
59+
"greater_equal": lambda x, y: sympy.Piecewise((1.0, x >= y), (0.0, True)),
60+
"less_equal": lambda x, y: sympy.Piecewise((1.0, x <= y), (0.0, True)),
5861
"cond": lambda x, y: sympy.Piecewise((y, x > 0), (0.0, True)),
5962
"logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
6063
"logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),

pysr/julia_import.py

+3
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,8 @@ def _import_juliacall():
6767
# Expose `D` operator:
6868
jl.seval("using SymbolicRegression: D")
6969

70+
# Expose other operators:
71+
jl.seval("using SymbolicRegression: less, greater_equal, less_equal")
72+
7073
jl.seval("using Pkg: Pkg")
7174
Pkg = jl.Pkg

pysr/juliapkg.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"packages": {
44
"SymbolicRegression": {
55
"uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
6-
"version": "~1.8.0"
6+
"version": "~1.9.0"
77
},
88
"Serialization": {
99
"uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",

pysr/test/test_main.py

+19
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,25 @@ def test_tensorboard_logger(self):
758758
# Verify model still works as expected
759759
self.assertLessEqual(model.get_best()["loss"], 1e-4)
760760

761+
def test_comparison_operator(self):
762+
X = self.rstate.randn(100, 2)
763+
y = ((X[:, 0] + X[:, 1]) < (X[:, 0] * X[:, 1])).astype(float)
764+
765+
model = PySRRegressor(
766+
binary_operators=["<", "+", "*"],
767+
**self.default_test_kwargs,
768+
early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity <= 7",
769+
)
770+
771+
model.fit(X, y)
772+
773+
best_equation = model.get_best()["equation"]
774+
self.assertIn("less", best_equation)
775+
self.assertLessEqual(model.get_best()["loss"], 1e-4)
776+
777+
y_pred = model.predict(X)
778+
np.testing.assert_array_almost_equal(y, y_pred, decimal=3)
779+
761780

762781
def manually_create_model(equations, feature_names=None):
763782
if feature_names is None:

0 commit comments

Comments
 (0)