diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 6b4f2b831..c60f0ad94 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -76,6 +76,9 @@ jobs:
       - name: "Run Torch tests"
        run: coverage run --append -m pysr test torch
        if: ${{ matrix.test-id == 'main' }}
+      - name: "Run Paddle tests"
+        run: coverage run --append -m pysr test paddle
+        if: ${{ matrix.test-id == 'main' }}
       - name: "Coveralls"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/CI_Windows.yml b/.github/workflows/CI_Windows.yml
index a62b7af6e..bfd8df5b4 100644
--- a/.github/workflows/CI_Windows.yml
+++ b/.github/workflows/CI_Windows.yml
@@ -55,3 +55,7 @@ jobs:
        run: pip install torch # (optional import)
       - name: "Run Torch tests"
        run: python -m pysr test torch
+      - name: "Install Paddle"
+        run: pip install paddlepaddle # (optional import)
+      - name: "Run Paddle tests"
+        run: python -m pysr test paddle
diff --git a/.github/workflows/CI_mac.yml b/.github/workflows/CI_mac.yml
index 68a940ee3..603653e09 100644
--- a/.github/workflows/CI_mac.yml
+++ b/.github/workflows/CI_mac.yml
@@ -55,3 +55,5 @@ jobs:
        run: python -m pysr test jax
       - name: "Run Torch tests"
        run: python -m pysr test torch
+      - name: "Run Paddle tests"
+        run: python -m pysr test paddle
diff --git a/examples/pysr_demo_paddle.py b/examples/pysr_demo_paddle.py
new file mode 100644
index 000000000..523b03c95
--- /dev/null
+++ b/examples/pysr_demo_paddle.py
@@ -0,0 +1,115 @@
+import os
+
+import numpy as np
+import paddle
+from paddle import nn
+from paddle.io import DataLoader, TensorDataset
+from sklearn.model_selection import train_test_split
+
+os.environ["PYTHON_JULIACALL_THREADS"] = "1"
+
+rstate = np.random.RandomState(0)
+
+N = 100000
+Nt = 10
+X = 6 * rstate.rand(N, Nt, 5) - 3
+y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])
+y = np.sum(y_i, axis=1) / y_i.shape[1]
+z = y**2
+
+
+hidden = 128
+total_steps = 50_000
+
+
+def mlp(size_in, size_out, act=nn.ReLU):
+    return nn.Sequential(
+        nn.Linear(size_in, hidden),
+        act(),
+        nn.Linear(hidden, hidden),
+        act(),
+        nn.Linear(hidden, hidden),
+        act(),
+        nn.Linear(hidden, size_out),
+    )
+
+
+class SumNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+        ########################################################
+        # The same inductive bias as above!
+        self.g = mlp(5, 1)
+        self.f = mlp(1, 1)
+
+    def forward(self, x):
+        y_i = self.g(x)[:, :, 0]
+        y = paddle.sum(y_i, axis=1, keepdim=True) / y_i.shape[1]
+        z = self.f(y)
+        return z[:, 0]
+
+
+Xt = paddle.to_tensor(X).astype("float32")
+zt = paddle.to_tensor(z).astype("float32")
+X_train, X_test, z_train, z_test = train_test_split(Xt, zt, random_state=0)
+train_set = TensorDataset([X_train, z_train])
+train = DataLoader(train_set, batch_size=128, shuffle=True)
+test_set = TensorDataset([X_test, z_test])
+test = DataLoader(test_set, batch_size=256)
+
+paddle.seed(0)
+
+model = SumNet()
+max_lr = 1e-2
+model = paddle.Model(model)
+scheduler = paddle.optimizer.lr.OneCycleLR(
+    max_learning_rate=max_lr, total_steps=total_steps, divide_factor=1e4
+)
+optim = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
+model.prepare(optim, paddle.nn.MSELoss())
+model.fit(train, test, num_iters=total_steps, eval_freq=1000)
+
+np.random.seed(0)
+idx = np.random.randint(0, 10000, size=1000)
+
+X_for_pysr = Xt[idx]
+y_i_for_pysr = model.network.g(X_for_pysr)[:, :, 0]
+y_for_pysr = paddle.sum(y_i_for_pysr, axis=1) / y_i_for_pysr.shape[1]
+z_for_pysr = zt[idx]  # Use true values.
+
+
+nnet_recordings = {
+    "g_input": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),
+    "g_output": y_i_for_pysr.detach().cpu().numpy().reshape(-1),
+    "f_input": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),
+    "f_output": z_for_pysr.detach().cpu().numpy().reshape(-1),
+}
+
+# Save the data for later use:
+import pickle as pkl
+
+with open("nnet_recordings.pkl", "wb") as f:
+    pkl.dump(nnet_recordings, f)
+
+# Reload the recordings (e.g., in a fresh session):
+with open("nnet_recordings.pkl", "rb") as f:
+    nnet_recordings = pkl.load(f)
+f_input = nnet_recordings["f_input"]
+f_output = nnet_recordings["f_output"]
+g_input = nnet_recordings["g_input"]
+g_output = nnet_recordings["g_output"]
+
+
+rstate = np.random.RandomState(0)
+f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)
+from pysr import PySRRegressor
+
+model = PySRRegressor(
+    niterations=50,
+    binary_operators=["+", "-", "*"],
+    unary_operators=["cos", "square"],
+)
+model.fit(g_input[f_sample_idx], g_output[f_sample_idx])
+
+print(model.equations_[["complexity", "loss", "equation"]])
diff --git a/examples/pysr_demo_pytorch.py b/examples/pysr_demo_pytorch.py
new file mode 100644
index 000000000..918b82ebb
--- /dev/null
+++ b/examples/pysr_demo_pytorch.py
@@ -0,0 +1,144 @@
+import os
+from multiprocessing import cpu_count
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from sklearn.model_selection import train_test_split
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, TensorDataset
+
+# Set before importing pysr, so that juliacall picks it up:
+os.environ["PYTHON_JULIACALL_THREADS"] = "1"
+
+from pysr import PySRRegressor
+
+rstate = np.random.RandomState(0)
+
+N = 100000
+Nt = 10
+X = 6 * rstate.rand(N, Nt, 5) - 3
+y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])
+y = np.sum(y_i, axis=1) / y_i.shape[1]
+z = y**2
+
+
+hidden = 128
+total_steps = 50_000
+
+
+def mlp(size_in, size_out, act=nn.ReLU):
+    return nn.Sequential(
+        nn.Linear(size_in, hidden),
+        act(),
+        nn.Linear(hidden, hidden),
+        act(),
+        nn.Linear(hidden, hidden),
+        act(),
+        nn.Linear(hidden, size_out),
+    )
+
+
+class SumNet(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        ########################################################
+        # The same inductive bias as above!
+        self.g = mlp(5, 1)
+        self.f = mlp(1, 1)
+
+    def forward(self, x):
+        y_i = self.g(x)[:, :, 0]
+        y = torch.sum(y_i, dim=1, keepdim=True) / y_i.shape[1]
+        z = self.f(y)
+        return z[:, 0]
+
+    ########################################################
+
+    # PyTorch Lightning bookkeeping:
+    def training_step(self, batch, batch_idx):
+        x, z = batch
+        predicted_z = self(x)
+        loss = F.mse_loss(predicted_z, z)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        return self.training_step(batch, batch_idx)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)
+        scheduler = {
+            "scheduler": torch.optim.lr_scheduler.OneCycleLR(
+                optimizer,
+                max_lr=self.max_lr,
+                total_steps=self.trainer.estimated_stepping_batches,
+                final_div_factor=1e4,
+            ),
+            "interval": "step",
+        }
+        return [optimizer], [scheduler]
+
+
+Xt = torch.tensor(X).float()
+zt = torch.tensor(z).float()
+X_train, X_test, z_train, z_test = train_test_split(Xt, zt, random_state=0)
+train_set = TensorDataset(X_train, z_train)
+train = DataLoader(
+    train_set, batch_size=128, num_workers=cpu_count(), shuffle=True, pin_memory=True
+)
+test_set = TensorDataset(X_test, z_test)
+test = DataLoader(test_set, batch_size=256, num_workers=cpu_count(), pin_memory=True)
+
+pl.seed_everything(0)
+model = SumNet()
+model.total_steps = total_steps
+model.max_lr = 1e-2
+
+trainer = pl.Trainer(max_steps=total_steps, accelerator="gpu", devices=1)
+trainer.fit(model, train_dataloaders=train, val_dataloaders=test)
+
+
+np.random.seed(0)
+idx = np.random.randint(0, 10000, size=1000)
+
+X_for_pysr = Xt[idx]
+y_i_for_pysr = model.g(X_for_pysr)[:, :, 0]
+y_for_pysr = torch.sum(y_i_for_pysr, dim=1) / y_i_for_pysr.shape[1]
+z_for_pysr = zt[idx]  # Use true values.
+
+print(X_for_pysr.shape, y_i_for_pysr.shape)
+
+
+nnet_recordings = {
+    "g_input": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),
+    "g_output": y_i_for_pysr.detach().cpu().numpy().reshape(-1),
+    "f_input": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),
+    "f_output": z_for_pysr.detach().cpu().numpy().reshape(-1),
+}
+
+# Save the data for later use:
+import pickle as pkl
+
+with open("nnet_recordings.pkl", "wb") as f:
+    pkl.dump(nnet_recordings, f)
+
+# Reload the recordings (e.g., in a fresh session):
+with open("nnet_recordings.pkl", "rb") as f:
+    nnet_recordings = pkl.load(f)
+f_input = nnet_recordings["f_input"]
+f_output = nnet_recordings["f_output"]
+g_input = nnet_recordings["g_input"]
+g_output = nnet_recordings["g_output"]
+
+
+rstate = np.random.RandomState(0)
+f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)
+
+model = PySRRegressor(
+    niterations=50,
+    binary_operators=["+", "-", "*"],
+    unary_operators=["cos", "square"],
+)
+model.fit(g_input[f_sample_idx], g_output[f_sample_idx])
+
+print(model.equations_[["complexity", "loss", "equation"]])
diff --git a/pyproject.toml b/pyproject.toml
index eb3ce6157..a5f07eb36 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ requires-python = ">=3.10"
 classifiers = [
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
-    "License :: OSI Approved :: Apache Software License"
+    "License :: OSI Approved :: Apache Software License",
 ]
 dependencies = [
     "sympy>=1.0.0,<2.0.0",
@@ -35,6 +35,7 @@ dev = [
     "ipykernel>=6,<7",
     "ipython>=8,<9",
     "jax[cpu]>=0.4,<0.5",
+    "paddlepaddle>=2.6.1",
     "jupyter>=1,<2",
     "mypy>=1,<2",
     "nbval>=0.11,<0.12",
@@ -51,10 +52,10 @@ dev = [
 [tool.setuptools]
 packages = ["pysr", "pysr._cli", "pysr.test"]
 include-package-data = false
-package-data = {pysr = ["juliapkg.json"]}
+package-data = { pysr = ["juliapkg.json"] }
 
 [tool.setuptools.dynamic]
-dependencies = {file = "requirements.txt"}
+dependencies = { file = "requirements.txt" }
 
 [tool.isort]
 profile = "black"
diff --git a/pysr/__init__.py b/pysr/__init__.py
index e26174aba..9626c5ee8 100644
--- a/pysr/__init__.py
+++ b/pysr/__init__.py
@@ -13,6 +13,7 @@
 from . import sklearn_monkeypatch
 from .deprecated import best, best_callable, best_row, best_tex, install, pysr
 from .export_jax import sympy2jax
+from .export_paddle import sympy2paddle
 from .export_torch import sympy2torch
 from .expression_specs import (
     AbstractExpressionSpec,
@@ -33,6 +34,7 @@
     "sklearn_monkeypatch",
     "sympy2jax",
     "sympy2torch",
+    "sympy2paddle",
     "install",
     "load_all_packages",
     "PySRRegressor",
diff --git a/pysr/_cli/main.py b/pysr/_cli/main.py
index b27b7cedc..30c443929 100644
--- a/pysr/_cli/main.py
+++ b/pysr/_cli/main.py
@@ -10,6 +10,7 @@
     runtests,
     runtests_dev,
     runtests_jax,
+    runtests_paddle,
     runtests_startup,
     runtests_torch,
 )
@@ -48,7 +49,7 @@ def _install(julia_project, quiet, precompile):
     )
 
 
-TEST_OPTIONS = {"main", "jax", "torch", "cli", "dev", "startup"}
+TEST_OPTIONS = {"main", "jax", "torch", "paddle", "cli", "dev", "startup"}
 
 
 @pysr.command("test")
@@ -63,7 +64,7 @@ def _install(julia_project, quiet, precompile):
 def _tests(tests, expressions):
     """Run parts of the PySR test suite.
 
-    Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
+    Choose from main, jax, torch, paddle, cli, dev, and startup. You can give multiple tests, separated by commas.
     """
     test_cases = []
     for test in tests.split(","):
@@ -73,6 +74,8 @@ def _tests(tests, expressions):
             test_cases.extend(runtests_jax(just_tests=True))
         elif test == "torch":
             test_cases.extend(runtests_torch(just_tests=True))
+        elif test == "paddle":
+            test_cases.extend(runtests_paddle(just_tests=True))
         elif test == "cli":
             runtests_cli = get_runtests_cli()
             test_cases.extend(runtests_cli(just_tests=True))
diff --git a/pysr/export.py b/pysr/export.py
index c1b589b84..3fbf1f7b6 100644
--- a/pysr/export.py
+++ b/pysr/export.py
@@ -7,6 +7,7 @@
 
 from .export_jax import sympy2jax
 from .export_numpy import sympy2numpy
+from .export_paddle import sympy2paddle
 from .export_sympy import create_sympy_symbols, pysr2sympy
 from .export_torch import sympy2torch
 from .utils import ArrayLike
@@ -22,6 +23,8 @@ def add_export_formats(
     output_torch_format: bool = False,
     extra_jax_mappings: dict[Callable, str] | None = None,
     output_jax_format: bool = False,
+    extra_paddle_mappings: dict[Callable, Callable] | None = None,
+    output_paddle_format: bool = False,
 ) -> pd.DataFrame:
     """Create export formats for an equations dataframe.
@@ -33,6 +36,7 @@ def add_export_formats(
     lambda_format = []
     jax_format = []
     torch_format = []
+    paddle_format = []
 
     for _, eqn_row in output.iterrows():
         eqn = pysr2sympy(
@@ -72,6 +76,16 @@
             )
             torch_format.append(module)
 
+        # Paddle:
+        if output_paddle_format:
+            module = sympy2paddle(
+                eqn,
+                sympy_symbols,
+                selection=selection_mask,
+                extra_paddle_mappings=extra_paddle_mappings,
+            )
+            paddle_format.append(module)
+
     exports = pd.DataFrame(
         {
             "sympy_format": sympy_format,
@@ -84,5 +98,7 @@ def add_export_formats(
         exports["jax_format"] = jax_format
     if output_torch_format:
         exports["torch_format"] = torch_format
+    if output_paddle_format:
+        exports["paddle_format"] = paddle_format
 
     return exports
diff --git a/pysr/export_paddle.py b/pysr/export_paddle.py
new file mode 100644
index 000000000..c61a171d6
--- /dev/null
+++ b/pysr/export_paddle.py
@@ -0,0 +1,224 @@
+import collections as co
+import functools as ft
+
+import numpy as np  # noqa: F401
+import sympy  # type: ignore
+
+
+def _reduce_add(*args):
+    return ft.reduce(lambda a, b: a + b, args)
+
+
+def _reduce_mul(*args):
+    return ft.reduce(lambda a, b: a * b, args)
+
+
+def _mod(a, b):
+    return a % b
+
+
+def _div(a, b):
+    return a / b
+
+
+paddle_initialized = False
+paddle = None
+SingleSymPyModule = None
+
+
+def _initialize_paddle():
+    global paddle_initialized
+    global paddle
+    global SingleSymPyModule
+
+    # Lazily import paddle only when this is first called,
+    # while still allowing this module to be loaded in __init__:
+    if not paddle_initialized:
+        import paddle as _paddle
+
+        paddle = _paddle
+
+        _global_func_lookup = {
+            sympy.Mul: _reduce_mul,
+            sympy.Add: _reduce_add,
+            sympy.div: _div,
+            sympy.Abs: paddle.abs,
+            sympy.sign: paddle.sign,
+            # Note: May raise error for ints.
+            sympy.ceiling: paddle.ceil,
+            sympy.floor: paddle.floor,
+            sympy.log: paddle.log,
+            sympy.exp: paddle.exp,
+            sympy.sqrt: paddle.sqrt,
+            sympy.cos: paddle.cos,
+            sympy.acos: paddle.acos,
+            sympy.sin: paddle.sin,
+            sympy.asin: paddle.asin,
+            sympy.tan: paddle.tan,
+            sympy.atan: paddle.atan,
+            sympy.atan2: paddle.atan2,
+            # Note: May give NaN for complex results.
+            sympy.cosh: paddle.cosh,
+            sympy.acosh: paddle.acosh,
+            sympy.sinh: paddle.sinh,
+            sympy.asinh: paddle.asinh,
+            sympy.tanh: paddle.tanh,
+            sympy.atanh: paddle.atanh,
+            sympy.Pow: paddle.pow,
+            sympy.re: paddle.real,
+            sympy.im: paddle.imag,
+            sympy.arg: paddle.angle,
+            # Note: May raise error for ints and complexes
+            sympy.erf: paddle.erf,
+            sympy.loggamma: paddle.lgamma,
+            sympy.Eq: paddle.equal,
+            sympy.Ne: paddle.not_equal,
+            sympy.StrictGreaterThan: paddle.greater_than,
+            sympy.StrictLessThan: paddle.less_than,
+            sympy.LessThan: paddle.less_equal,
+            sympy.GreaterThan: paddle.greater_equal,
+            sympy.And: paddle.logical_and,
+            sympy.Or: paddle.logical_or,
+            sympy.Not: paddle.logical_not,
+            # Use the elementwise variants; paddle.max/paddle.min are reductions.
+            sympy.Max: paddle.maximum,
+            sympy.Min: paddle.minimum,
+            sympy.Mod: _mod,
+            sympy.Heaviside: paddle.heaviside,
+            sympy.core.numbers.Half: (lambda: 0.5),
+            sympy.core.numbers.One: (lambda: 1.0),
+        }
+
+        class _Node(paddle.nn.Layer):
+            """Forked from https://github.com/patrick-kidger/sympytorch"""
+
+            def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
+                super().__init__(**kwargs)
+
+                self._sympy_func = expr.func
+                if issubclass(expr.func, sympy.Float):
+                    self._value = paddle.create_parameter(
+                        shape=[1],
+                        dtype="float32",
+                        default_initializer=paddle.nn.initializer.Assign(
+                            paddle.to_tensor(float(expr))
+                        ),
+                    )
+                    self._paddle_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Rational):
+                    # This is some fraction fixed in the operator.
+                    self._value = float(expr)
+                    self._paddle_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.UnevaluatedExpr):
+                    if len(expr.args) != 1 or not issubclass(
+                        expr.args[0].func, sympy.Float
+                    ):
+                        raise ValueError(
+                            "UnevaluatedExpr should only be used to wrap floats."
+                        )
+                    self.register_buffer(
+                        "_value", paddle.to_tensor(float(expr.args[0]))
+                    )
+                    self._paddle_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Integer):
+                    # Can get here if expr is one of the Integer special cases,
+                    # e.g. NegativeOne
+                    self._value = int(expr)
+                    self._paddle_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.NumberSymbol):
+                    # Can get here from exp(1) or exact pi
+                    self._value = float(expr)
+                    self._paddle_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Symbol):
+                    self._name = expr.name
+                    self._paddle_func = lambda value: value
+                    self._args = ((lambda memodict: memodict[expr.name]),)
+                else:
+                    try:
+                        self._paddle_func = _func_lookup[expr.func]
+                    except KeyError:
+                        raise KeyError(
+                            f"Function {expr.func} was not found in paddle function mappings. "
+                            "Please add it to extra_paddle_mappings in the format, e.g., "
+                            "{sympy.sqrt: paddle.sqrt}."
+                        )
+                    args = []
+                    for arg in expr.args:
+                        try:
+                            arg_ = _memodict[arg]
+                        except KeyError:
+                            arg_ = type(self)(
+                                expr=arg,
+                                _memodict=_memodict,
+                                _func_lookup=_func_lookup,
+                                **kwargs,
+                            )
+                            _memodict[arg] = arg_
+                        args.append(arg_)
+                    self._args = paddle.nn.LayerList(args)
+
+            def extra_repr(self):
+                return f"sympy_func={self._sympy_func}"
+
+            def forward(self, memodict):
+                args = []
+                for arg in self._args:
+                    try:
+                        arg_ = memodict[arg]
+                    except KeyError:
+                        arg_ = arg(memodict)
+                        memodict[arg] = arg_
+                    args.append(arg_)
+                return self._paddle_func(*args)
+
+        class _SingleSymPyModule(paddle.nn.Layer):
+            """Wraps a single SymPy expression as a callable paddle.nn.Layer."""
+
+            def __init__(
+                self, expression, symbols_in, selection=None, extra_funcs=None, **kwargs
+            ):
+                super().__init__(**kwargs)
+
+                if extra_funcs is None:
+                    extra_funcs = {}
+                _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
+
+                _memodict = {}
+                self._node = _Node(
+                    expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
+                )
+                self._expression_string = str(expression)
+                self._selection = selection
+                self.symbols_in = [str(symbol) for symbol in symbols_in]
+
+            def __repr__(self):
+                return f"{type(self).__name__}(expression={self._expression_string})"
+
+            def forward(self, X):
+                if self._selection is not None:
+                    X = X[:, self._selection]
+                symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
+                return self._node(symbols)
+
+        SingleSymPyModule = _SingleSymPyModule
+
+        paddle_initialized = True
+
+
+def sympy2paddle(expression, symbols_in, selection=None, extra_paddle_mappings=None):
+    """Returns a paddle module for a given sympy expression, with trainable parameters.
+
+    This function will assume the input to the module is a matrix X, where
+    each column corresponds to each symbol you pass in `symbols_in`.
+    """
+    global SingleSymPyModule
+
+    _initialize_paddle()
+
+    return SingleSymPyModule(
+        expression, symbols_in, selection=selection, extra_funcs=extra_paddle_mappings
+    )
diff --git a/pysr/expression_specs.py b/pysr/expression_specs.py
index f9a6eee7c..2e21b08bd 100644
--- a/pysr/expression_specs.py
+++ b/pysr/expression_specs.py
@@ -78,6 +78,10 @@ def supports_jax(self) -> bool:
     def supports_latex(self) -> bool:
         return False
 
+    @property
+    def supports_paddle(self) -> bool:
+        return False
+
 
 class ExpressionSpec(AbstractExpressionSpec):
     """The default expression specification, with no special behavior."""
@@ -99,10 +103,12 @@ def create_exports(
             feature_names_in=model.feature_names_in_,
             selection_mask=model.selection_mask_,
             extra_sympy_mappings=model.extra_sympy_mappings,
-            extra_torch_mappings=model.extra_torch_mappings,
-            output_jax_format=model.output_jax_format,
             extra_jax_mappings=model.extra_jax_mappings,
+            output_jax_format=model.output_jax_format,
+            extra_torch_mappings=model.extra_torch_mappings,
             output_torch_format=model.output_torch_format,
+            extra_paddle_mappings=model.extra_paddle_mappings,
+            output_paddle_format=model.output_paddle_format,
         )
 
     @property
@@ -121,6 +127,10 @@ def supports_jax(self):
     def supports_latex(self):
         return True
 
+    @property
+    def supports_paddle(self):
+        return True
+
 
 class TemplateExpressionSpec(AbstractExpressionSpec):
     """Spec for templated expressions.
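Usage note: with `export_paddle` in place, the converter can also be driven directly, independent of a search. A minimal sketch mirroring `test_sympy2paddle` below (assumes `paddlepaddle` is installed; the data is illustrative):

import numpy as np
import paddle
import sympy

from pysr import sympy2paddle

x, y, z = sympy.symbols("x y z")
expr = 1.0 * sympy.cos(x) + y  # the float 1.0 becomes a trainable parameter

module = sympy2paddle(expr, [x, y, z])  # a paddle.nn.Layer
X = paddle.to_tensor(np.random.randn(100, 3).astype("float32"))
out = module(X)  # column i of X is bound to the i-th symbol in `symbols_in`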
diff --git a/pysr/sr.py b/pysr/sr.py
index 002093780..3b5929523 100644
--- a/pysr/sr.py
+++ b/pysr/sr.py
@@ -31,6 +31,7 @@
     sympy2multilatextable,
     with_preamble,
 )
+from .export_paddle import sympy2paddle
 from .export_sympy import assert_valid_sympy_symbol
 from .expression_specs import (
     AbstractExpressionSpec,
@@ -205,8 +206,10 @@ def _check_assertions(
     )
 
 
-def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings):
-    # It is expected extra_jax/torch_mappings will be updated after fit.
+def _validate_export_mappings(
+    extra_jax_mappings, extra_torch_mappings, extra_paddle_mappings
+):
+    # It is expected that extra_jax/torch/paddle_mappings will be updated after fit.
     # Thus, validation is performed here instead of in _validate_init_params
     if extra_jax_mappings is not None:
         for value in extra_jax_mappings.values():
@@ -223,6 +226,14 @@ def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings):
                     "e.g., {sympy.sqrt: torch.sqrt}."
                 )
 
+    if extra_paddle_mappings is not None:
+        for value in extra_paddle_mappings.values():
+            if not callable(value):
+                raise ValueError(
+                    "extra_paddle_mappings must be callable functions! "
+                    "e.g., {sympy.sqrt: paddle.sqrt}."
+                )
+
 
 # Class validation constants
 VALID_OPTIMIZER_ALGORITHMS = ["BFGS", "NelderMead"]
@@ -661,6 +672,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
         Whether to create a 'torch_format' column in the output,
         containing a torch module with trainable parameters.
         Default is `False`.
+    output_paddle_format : bool
+        Whether to create a 'paddle_format' column in the output,
+        containing a paddle module with trainable parameters.
+        Default is `False`.
     extra_sympy_mappings : dict[str, Callable]
         Provides mappings between custom `binary_operators` or
         `unary_operators` defined in julia strings, to those same
@@ -681,6 +696,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
         pytorch expressions.
         For example: `extra_torch_mappings={sympy.sin: torch.sin}`.
         Default is `None`.
+    extra_paddle_mappings : dict[Callable, Callable]
+        The same as `extra_jax_mappings` but for model export
+        to paddle. Note that the dictionary values should be callable
+        paddle functions, keyed by the sympy functions they implement.
+        For example: `extra_paddle_mappings={sympy.sin: paddle.sin}`.
+        Default is `None`.
     denoise : bool
         Whether to use a Gaussian Process to denoise the data before
         inputting to PySR. Can help PySR fit noisy data.
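Usage note: custom operators reach the paddle export through paired mappings, exactly as `test_custom_operator` below exercises. A condensed sketch (the operator name `myop` is illustrative, not part of this diff):

import paddle
import sympy

from pysr import PySRRegressor

# A hypothetical operator defined on the Julia side as "myop(x) = sin(x)":
MyOp = sympy.Function("myop")

model = PySRRegressor(
    unary_operators=["myop(x) = sin(x)"],
    extra_sympy_mappings={"myop": MyOp},       # operator name -> sympy function
    extra_paddle_mappings={MyOp: paddle.sin},  # sympy function -> paddle function
    output_paddle_format=True,
)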
@@ -881,9 +902,11 @@ def __init__(
         update: bool = False,
         output_jax_format: bool = False,
         output_torch_format: bool = False,
+        output_paddle_format: bool = False,
         extra_sympy_mappings: dict[str, Callable] | None = None,
         extra_torch_mappings: dict[Callable, Callable] | None = None,
         extra_jax_mappings: dict[Callable, str] | None = None,
+        extra_paddle_mappings: dict[Callable, Callable] | None = None,
         denoise: bool = False,
         select_k_features: int | None = None,
         **kwargs,
@@ -989,9 +1012,11 @@
         self.update = update
         self.output_jax_format = output_jax_format
         self.output_torch_format = output_torch_format
+        self.output_paddle_format = output_paddle_format
         self.extra_sympy_mappings = extra_sympy_mappings
         self.extra_jax_mappings = extra_jax_mappings
         self.extra_torch_mappings = extra_torch_mappings
+        self.extra_paddle_mappings = extra_paddle_mappings
         # Pre-modelling transformation
         self.denoise = denoise
         self.select_k_features = select_k_features
@@ -1219,7 +1244,11 @@ def __getstate__(self) -> dict[str, Any]:
         show_pickle_warning = not (
             "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
         )
-        state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
+        state_keys_containing_lambdas = [
+            "extra_sympy_mappings",
+            "extra_torch_mappings",
+            "extra_paddle_mappings",
+        ]
         for state_key in state_keys_containing_lambdas:
             if state[state_key] is not None and show_pickle_warning:
                 warnings.warn(
@@ -1238,16 +1267,19 @@
         ):
             pickled_state["output_torch_format"] = False
             pickled_state["output_jax_format"] = False
+            pickled_state["output_paddle_format"] = False
             if self.nout_ == 1:
                 pickled_columns = ~pickled_state["equations_"].columns.isin(
-                    ["jax_format", "torch_format"]
+                    ["jax_format", "torch_format", "paddle_format"]
                 )
                 pickled_state["equations_"] = (
                     pickled_state["equations_"].loc[:, pickled_columns].copy()
                 )
             else:
                 pickled_columns = [
-                    ~dataframe.columns.isin(["jax_format", "torch_format"])
+                    ~dataframe.columns.isin(
+                        ["jax_format", "torch_format", "paddle_format"]
+                    )
                     for dataframe in pickled_state["equations_"]
                 ]
                 pickled_state["equations_"] = [
@@ -2544,6 +2576,41 @@ def pytorch(self, index=None):
         else:
             return best_equation["torch_format"]
 
+    def paddle(self, index=None):
+        """
+        Return paddle representation of the equation(s) chosen by `model_selection`.
+
+        Each equation (multiple given if there are multiple outputs) is a PaddlePaddle
+        module containing the parameters as trainable attributes. You can use the
+        module like any other PaddlePaddle module: `module(X)`, where `X` is a tensor
+        with the same column ordering as trained with.
+
+        Parameters
+        ----------
+        index : int | list[int]
+            If you wish to select a particular equation from
+            `self.equations_`, give the index number here. This overrides
+            the `model_selection` parameter. If there are multiple output
+            features, then pass a list of indices with the order the same
+            as the output feature.
+
+        Returns
+        -------
+        best_equation : paddle.nn.Layer
+            PaddlePaddle module representing the expression.
+        """
+        if not self.expression_spec_.supports_paddle:
+            raise ValueError(
+                f"`expression_spec={self.expression_spec_}` does not support paddle export."
+            )
+        self.set_params(output_paddle_format=True)
+        self.refresh()
+        best_equation = self.get_best(index=index)
+        if isinstance(best_equation, list):
+            return [eq["paddle_format"] for eq in best_equation]
+        else:
+            return best_equation["paddle_format"]
+
     def get_equation_file(self, i: int | None = None) -> Path:
         if i is not None:
             return (
@@ -2619,7 +2686,11 @@ def get_hof(self, search_output=None) -> pd.DataFrame | list[pd.DataFrame]:
         if should_read_from_file:
             self.equation_file_contents_ = self._read_equation_file()
 
-        _validate_export_mappings(self.extra_jax_mappings, self.extra_torch_mappings)
+        _validate_export_mappings(
+            self.extra_jax_mappings,
+            self.extra_torch_mappings,
+            self.extra_paddle_mappings,
+        )
 
         equation_file_contents = cast(list[pd.DataFrame], self.equation_file_contents_)
diff --git a/pysr/test/__init__.py b/pysr/test/__init__.py
index 4d977cccf..a1016e30e 100644
--- a/pysr/test/__init__.py
+++ b/pysr/test/__init__.py
@@ -2,6 +2,7 @@
 from .test_dev import runtests as runtests_dev
 from .test_jax import runtests as runtests_jax
 from .test_main import runtests
+from .test_paddle import runtests as runtests_paddle
 from .test_startup import runtests as runtests_startup
 from .test_torch import runtests as runtests_torch
 
@@ -9,6 +10,7 @@
     "runtests",
     "runtests_jax",
     "runtests_torch",
+    "runtests_paddle",
     "get_runtests_cli",
     "runtests_startup",
     "runtests_dev",
diff --git a/pysr/test/test_paddle.py b/pysr/test/test_paddle.py
new file mode 100644
index 000000000..af6914e94
--- /dev/null
+++ b/pysr/test/test_paddle.py
@@ -0,0 +1,243 @@
+import unittest
+from pathlib import Path
+
+import numpy as np
+import paddle
+import pandas as pd
+import sympy  # type: ignore
+
+import pysr
+from pysr import PySRRegressor, sympy2paddle
+
+
+class TestPaddle(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(0)
+
+    def test_sympy2paddle(self):
+        x, y, z = sympy.symbols("x y z")
+        cosx = 1.0 * sympy.cos(x) + y
+
+        X = paddle.to_tensor(np.random.randn(1000, 3))
+        true = 1.0 * paddle.cos(X[:, 0]) + X[:, 1]
+        paddle_module = sympy2paddle(cosx, [x, y, z])
+
+        self.assertTrue(
+            np.all(np.isclose(paddle_module(X).detach().numpy(), true.detach().numpy()))
+        )
+
+    def test_pipeline_pandas(self):
+        X = pd.DataFrame(np.random.randn(100, 10))
+        y = np.ones(X.shape[0])
+        model = PySRRegressor(
+            progress=False,
+            max_evals=10000,
+            model_selection="accuracy",
+            extra_sympy_mappings={},
+            output_paddle_format=True,
+        )
+        model.fit(X, y)
+
+        equations = pd.DataFrame(
+            {
+                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
+                "Loss": [1.0, 0.1, 1e-5],
+                "Complexity": [1, 2, 3],
+            }
+        )
+
+        for fname in ["hall_of_fame.csv.bak", "hall_of_fame.csv"]:
+            equations["Complexity Loss Equation".split(" ")].to_csv(
+                Path(model.output_directory_) / model.run_id_ / fname
+            )
+
+        model.refresh(run_directory=str(Path(model.output_directory_) / model.run_id_))
+
+        pdformat = model.paddle()
+        self.assertEqual(str(pdformat), "_SingleSymPyModule(expression=cos(x1)**2)")
+
+        np.testing.assert_almost_equal(
+            pdformat(paddle.to_tensor(X.values)).detach().numpy(),
+            np.square(np.cos(X.values[:, 1])),  # x1 is the 2nd feature (index 1)
+            decimal=3,
+        )
+
+    def test_pipeline(self):
+        X = np.random.randn(100, 10)
+        y = np.ones(X.shape[0])
+        model = PySRRegressor(
+            progress=False,
+            max_evals=10000,
+            model_selection="accuracy",
+            output_paddle_format=True,
+        )
+        model.fit(X, y)
+
+        equations = pd.DataFrame(
+            {
+                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
+                "Loss": [1.0, 0.1, 1e-5],
+                "Complexity": [1, 2, 3],
+            }
+        )
+
+        for fname in ["hall_of_fame.csv.bak", "hall_of_fame.csv"]:
+            equations["Complexity Loss Equation".split(" ")].to_csv(
+                Path(model.output_directory_) / model.run_id_ / fname
+            )
+
+        model.refresh(run_directory=str(Path(model.output_directory_) / model.run_id_))
+
+        pdformat = model.paddle()
+        self.assertEqual(str(pdformat), "_SingleSymPyModule(expression=cos(x1)**2)")
+
+        np.testing.assert_almost_equal(
+            pdformat(paddle.to_tensor(X)).detach().numpy(),
+            np.square(np.cos(X[:, 1])),  # x1 is the 2nd feature (index 1)
+            decimal=3,
+        )
+
+    def test_mod_mapping(self):
+        x, y, z = sympy.symbols("x y z")
+        expression = x**2 + sympy.atanh(sympy.Mod(y + 1, 2) - 1) * 3.2 * z
+
+        module = sympy2paddle(expression, [x, y, z])
+
+        X = paddle.rand((100, 3)).astype("float32") * 10
+
+        true_out = (
+            X[:, 0] ** 2
+            + paddle.atanh(
+                paddle.mod(X[:, 1] + 1, paddle.to_tensor(2).astype("float32")) - 1
+            )
+            * 3.2
+            * X[:, 2]
+        )
+        paddle_out = module(X)
+
+        np.testing.assert_array_almost_equal(
+            true_out.detach(), paddle_out.detach(), decimal=3
+        )
+
+    def test_custom_operator(self):
+        X = np.random.randn(100, 3)
+        y = np.ones(X.shape[0])
+        model = PySRRegressor(
+            progress=False,
+            max_evals=10000,
+            model_selection="accuracy",
+            output_paddle_format=True,
+        )
+        model.fit(X, y)
+
+        equations = pd.DataFrame(
+            {
+                "Equation": ["1.0", "mycustomoperator(x1)"],
+                "Loss": [1.0, 0.1],
+                "Complexity": [1, 2],
+            }
+        )
+
+        for fname in ["hall_of_fame.csv.bak", "hall_of_fame.csv"]:
+            equations["Complexity Loss Equation".split(" ")].to_csv(
+                Path(model.output_directory_) / model.run_id_ / fname
+            )
+
+        MyCustomOperator = sympy.Function("mycustomoperator")
+
+        model.set_params(
+            extra_sympy_mappings={"mycustomoperator": MyCustomOperator},
+            extra_paddle_mappings={MyCustomOperator: paddle.sin},
+        )
+        # TODO: We shouldn't need to specify the run directory here.
+        model.refresh(run_directory=str(Path(model.output_directory_) / model.run_id_))
+        # self.assertEqual(str(model.sympy()), "sin(x1)")
+        # Will automatically use the set global state from get_hof.
+
+        pdformat = model.paddle()
+        self.assertEqual(
+            str(pdformat), "_SingleSymPyModule(expression=mycustomoperator(x1))"
+        )
+
+        np.testing.assert_almost_equal(
+            pdformat(paddle.to_tensor(X)).detach().numpy(),
+            np.sin(X[:, 1]),
+            decimal=3,
+        )
+
+    def test_avoid_simplification(self):
+        # SymPy should not simplify without permission
+        ex = pysr.export_sympy.pysr2sympy(
+            "square(exp(sign(0.44796443))) + 1.5 * x1",
+            # ^ Normally this would become exp1 and require
+            #   its own mapping
+            feature_names_in=["x1"],
+            extra_sympy_mappings={"square": lambda x: x**2},
+        )
+        m = pysr.export_paddle.sympy2paddle(ex, ["x1"])
+        rng = np.random.RandomState(0)
+        X = rng.randn(10, 1)
+        np.testing.assert_almost_equal(
+            m(paddle.to_tensor(X)).detach().numpy(),
+            np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
+            decimal=3,
+        )
+
+    def test_issue_656(self):
+        # Should correctly map numeric symbols to floats
+        E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
+        m = pysr.export_paddle.sympy2paddle(E_plus_x1, ["x1"])
+        X = np.random.randn(10, 1)
+        np.testing.assert_almost_equal(
+            m(paddle.to_tensor(X)).detach().numpy(),
+            np.exp(1) + X[:, 0],
+            decimal=3,
+        )
+
+    def test_feature_selection_custom_operators(self):
+        rstate = np.random.RandomState(0)
+        X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
+
+        def cos_approx(x):
+            return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
+
+        y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
+
+        model = PySRRegressor(
+            progress=False,
+            unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
+            select_k_features=3,
+            maxsize=10,
+            early_stop_condition=1e-5,
+            extra_sympy_mappings={"cos_approx": cos_approx},
+            random_state=0,
+            deterministic=True,
+            parallelism="serial",
+        )
+        np.random.seed(0)
+        model.fit(X.values, y.values)
+        paddle_module = model.paddle()
+
+        np_output = model.predict(X.values)
+
+        paddle_output = paddle_module(paddle.to_tensor(X.values)).detach().numpy()
+
+        np.testing.assert_almost_equal(y.values, np_output, decimal=3)
+        np.testing.assert_almost_equal(y.values, paddle_output, decimal=3)
+
+
+def runtests(just_tests=False):
+    """Run all tests in test_paddle.py."""
+    tests = [TestPaddle]
+    if just_tests:
+        return tests
+    loader = unittest.TestLoader()
+    suite = unittest.TestSuite()
+    for test in tests:
+        suite.addTests(loader.loadTestsFromTestCase(test))
+    runner = unittest.TextTestRunner()
+    return runner.run(suite)
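Usage note: end to end, the new export path looks like the following (a minimal sketch; the dataset and hyperparameters are illustrative, not part of this diff):

import numpy as np
import paddle

from pysr import PySRRegressor

X = np.random.randn(200, 2).astype("float32")
y = 2.5 * np.cos(X[:, 0]) + X[:, 1]

model = PySRRegressor(
    niterations=20,
    binary_operators=["+", "*"],
    unary_operators=["cos"],
    output_paddle_format=True,
)
model.fit(X, y)

layer = model.paddle()              # paddle.nn.Layer; constants are trainable
preds = layer(paddle.to_tensor(X))  # same column ordering as during fit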