@@ -546,8 +546,9 @@ y = np.sin(X[:, 0] + X[:, 1]) + X[:, 2]**2
546546
547547# Define template: we want sin(f(x1, x2)) + g(x3)
548548template = TemplateExpressionSpec(
549- function_symbols = [" f" , " g" ],
550- combine = " ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)" ,
549+ expressions = [" f" , " g" ],
550+ variable_names = [" x1" , " x2" , " x3" ],
551+ combine = " sin(f(x1, x2)) + g(x3)" ,
551552)
552553
553554model = PySRRegressor(
@@ -559,15 +560,23 @@ model = PySRRegressor(
559560model.fit(X, y)
560561```
561562
562- You can also use no argument-functions for learning constants, like :
563+ You can also use parameters in your template expressions, which will be optimized during the search :
563564
564565``` python
565566template = TemplateExpressionSpec(
566- function_symbols = [" a" , " f" ],
567- combine = " ((; a, f), (x, y)) -> a() * sin(f(x, y))" ,
567+ expressions = [" f" , " g" ],
568+ variable_names = [" x1" , " x2" , " x3" ],
569+ parameters = {" p1" : 2 , " p2" : 1 }, # p1 has length 2, p2 has length 1
570+ combine = " p1[1] * sin(f(x1, x2)) + p1[2] * g(x3) + p2[1]" ,
568571)
569572```
570573
574+ This will learn an equation of the form:
575+
576+ $$ y = \alpha_1 \sin(f(x_1, x_2)) + \alpha_2 g(x_3) + \beta $$
577+
578+ where $\alpha_1, \alpha_2$ are stored in ` p1 ` and $\beta$ is stored in ` p2 ` . The parameters will be optimized during the search.
579+
571580### Parametric Expressions
572581
573582When your data has categories with shared equation structure but different parameters,
@@ -609,6 +618,20 @@ model.fit(X, y, category=category)
609618
610619See [ Expression Specifications] ( /api/#expression-specifications ) for more details.
611620
621+ You can also use ` TemplateExpressionSpec ` in the same way, passing
622+ the category as a column of ` X ` :
623+
624+ ``` python
625+ spec = TemplateExpressionSpec(
626+ expressions = [" f" , " g" ],
627+ variable_names = [" x1" , " x2" , " class" ],
628+ combine = " p1[class] * sin(f(x1, x2)) + p2[class]" ,
629+ )
630+ ```
631+
632+ this column will automatically be converted to integers.
633+
634+
612635## 12. Using TensorBoard for Logging
613636
614637You can use TensorBoard to visualize the search progress, as well as
0 commit comments