Skip to content

Commit 8834cd0

Browse files
author
Alessandro Lucantonio
committed
Adapt the arguments of the `fit`, `predict` and `score` functions to take separate X/y inputs instead of `Dataset` objects. Tests are still a work in progress.
1 parent 2f8cc63 commit 8834cd0

File tree

6 files changed

+79
-72
lines changed

6 files changed

+79
-72
lines changed

bench/bench.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# import matplotlib.pyplot as plt
22
from deap import gp
33

4-
from alpine.gp import gpsymbreg as gps
4+
from alpine.gp import regressor as gps
55
from alpine.data import Dataset
66
from alpine.gp import util
77
import numpy as np

examples/simple_sr.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from deap import gp
3-
from alpine.gp.gpsymbreg import GPSymbolicRegressor
3+
from alpine.gp.regressor import GPSymbolicRegressor
44
from alpine.data import Dataset
55
import numpy as np
66
import ray
@@ -40,44 +40,44 @@ def get_features_batch(
4040
return individ_length, nested_trigs, num_trigs
4141

4242

43-
def eval_MSE_sol(individual, true_data):
43+
def eval_MSE_sol(individual, X, y):
4444
warnings.filterwarnings("ignore")
4545

46-
y_pred = individual(true_data.X)
47-
MSE = np.mean(np.square(y_pred - true_data.y))
46+
y_pred = individual(X)
47+
MSE = np.mean(np.square(y_pred - y))
4848
if np.isnan(MSE):
4949
MSE = 1e5
5050
return MSE, y_pred
5151

5252

5353
@ray.remote
54-
def predict(individuals_str, toolbox, true_data, penalty):
54+
def predict(individuals_str, toolbox, X_test, penalty):
5555

5656
callables = compile_individuals(toolbox, individuals_str)
5757

5858
u = [None] * len(individuals_str)
5959

6060
for i, ind in enumerate(callables):
61-
_, u[i] = eval_MSE_sol(ind, true_data)
61+
_, u[i] = eval_MSE_sol(ind, X_test, None)
6262

6363
return u
6464

6565

6666
@ray.remote
67-
def score(individuals_str, toolbox, true_data, penalty):
67+
def score(individuals_str, toolbox, X_test, y_test, penalty):
6868

6969
callables = compile_individuals(toolbox, individuals_str)
7070

7171
MSE = [None] * len(individuals_str)
7272

7373
for i, ind in enumerate(callables):
74-
MSE[i], _ = eval_MSE_sol(ind, true_data)
74+
MSE[i], _ = eval_MSE_sol(ind, X_test, y_test)
7575

7676
return MSE
7777

7878

7979
@ray.remote
80-
def fitness(individuals_str, toolbox, true_data, penalty):
80+
def fitness(individuals_str, toolbox, X_train, y_train, penalty):
8181
callables = compile_individuals(toolbox, individuals_str)
8282

8383
individ_length, nested_trigs, num_trigs = get_features_batch(individuals_str)
@@ -87,7 +87,7 @@ def fitness(individuals_str, toolbox, true_data, penalty):
8787
if individ_length[i] >= 50:
8888
fitnesses[i] = (1e8,)
8989
else:
90-
MSE, _ = eval_MSE_sol(ind, true_data)
90+
MSE, _ = eval_MSE_sol(ind, X_train, y_train)
9191

9292
fitnesses[i] = (
9393
MSE
@@ -131,8 +131,7 @@ def main():
131131
**regressor_params
132132
)
133133

134-
train_data = Dataset("true_data", x, y)
135-
gpsr.fit(train_data)
134+
gpsr.fit(x, y)
136135

137136
ray.shutdown()
138137

examples/simple_sr_noyaml.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from deap import gp
2-
from alpine.gp.gpsymbreg import GPSymbolicRegressor
2+
from alpine.gp.regressor import GPSymbolicRegressor
33
from alpine.data import Dataset
44
import numpy as np
55
import ray
@@ -39,44 +39,44 @@ def get_features_batch(
3939
return individ_length, nested_trigs, num_trigs
4040

4141

42-
def eval_MSE_sol(individual, true_data):
42+
def eval_MSE_sol(individual, X, y):
4343
warnings.filterwarnings("ignore")
4444

45-
y_pred = individual(true_data.X)
46-
MSE = np.mean(np.square(y_pred - true_data.y))
45+
y_pred = individual(X)
46+
MSE = np.mean(np.square(y_pred - y))
4747
if np.isnan(MSE):
4848
MSE = 1e5
4949
return MSE, y_pred
5050

5151

5252
@ray.remote
53-
def predict(individuals_str, toolbox, true_data, penalty):
53+
def predict(individuals_str, toolbox, X_test, penalty):
5454

5555
callables = compile_individuals(toolbox, individuals_str)
5656

5757
u = [None] * len(individuals_str)
5858

5959
for i, ind in enumerate(callables):
60-
_, u[i] = eval_MSE_sol(ind, true_data)
60+
_, u[i] = eval_MSE_sol(ind, X_test, None)
6161

6262
return u
6363

6464

6565
@ray.remote
66-
def score(individuals_str, toolbox, true_data, penalty):
66+
def score(individuals_str, toolbox, X_test, y_test, penalty):
6767

6868
callables = compile_individuals(toolbox, individuals_str)
6969

7070
MSE = [None] * len(individuals_str)
7171

7272
for i, ind in enumerate(callables):
73-
MSE[i], _ = eval_MSE_sol(ind, true_data)
73+
MSE[i], _ = eval_MSE_sol(ind, X_test, y_test)
7474

7575
return MSE
7676

7777

7878
@ray.remote
79-
def fitness(individuals_str, toolbox, true_data, penalty):
79+
def fitness(individuals_str, toolbox, X_train, y_train, penalty):
8080
callables = compile_individuals(toolbox, individuals_str)
8181

8282
individ_length, nested_trigs, num_trigs = get_features_batch(individuals_str)
@@ -86,7 +86,7 @@ def fitness(individuals_str, toolbox, true_data, penalty):
8686
if individ_length[i] >= 50:
8787
fitnesses[i] = (1e8,)
8888
else:
89-
MSE, _ = eval_MSE_sol(ind, true_data)
89+
MSE, _ = eval_MSE_sol(ind, X_train, y_train)
9090

9191
fitnesses[i] = (
9292
MSE
@@ -145,8 +145,7 @@ def main():
145145
batch_size=100,
146146
)
147147

148-
train_data = Dataset("true_data", x, y)
149-
gpsr.fit(train_data)
148+
gpsr.fit(x, y)
150149

151150
ray.shutdown()
152151

src/alpine/gp/gpsymbreg.py src/alpine/gp/regressor.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import ray
1212
import random
1313
from itertools import chain
14+
from sklearn.base import BaseEstimator, RegressorMixin
1415

1516
# reducing the number of threads launched by fitness evaluations
1617
os.environ["MKL_NUM_THREADS"] = "1"
@@ -24,7 +25,7 @@
2425
)
2526

2627

27-
class GPSymbolicRegressor:
28+
class GPSymbolicRegressor(RegressorMixin, BaseEstimator):
2829
"""Symbolic regression problem via Genetic Programming.
2930
3031
Args:
@@ -130,7 +131,7 @@ def __init__(
130131

131132
if common_data is not None:
132133
# FIXME: does everything work when the functions do not have common args?
133-
self.store_fit_error_common_args(common_data)
134+
self.__store_fit_error_common_args(common_data)
134135

135136
self.NINDIVIDUALS = NINDIVIDUALS
136137
self.NGEN = NGEN
@@ -161,9 +162,6 @@ def __init__(
161162

162163
# config individual creator and toolbox
163164
self.__creator_toolbox_config()
164-
# self.createIndividual = individualCreator
165-
166-
# self.toolbox = toolbox
167165

168166
self.seed = seed
169167

@@ -253,7 +251,7 @@ def __creator_toolbox_config(self):
253251

254252
self.createIndividual = createIndividual
255253

256-
def store_fit_error_common_args(self, data: Dict):
254+
def __store_fit_error_common_args(self, data: Dict):
257255
"""Store names and values of the arguments that are in common between
258256
the fitness and the error metric functions in the common object space.
259257
@@ -262,7 +260,7 @@ def store_fit_error_common_args(self, data: Dict):
262260
"""
263261
self.__store_shared_objects("common", data)
264262

265-
def store_datasets(self, datasets: Dict[str, Dataset]):
263+
def __store_datasets(self, datasets: Dict[str, Dataset]):
266264
"""Store datasets with the corresponding label ("train", "val" or "test")
267265
in the common object space. The datasets are passed as parameters to
268266
the fitness, and possibly to the error metric and the prediction functions.
@@ -272,12 +270,12 @@ def store_datasets(self, datasets: Dict[str, Dataset]):
272270
the validation and the test datasets, respectively. The associated
273271
values are `Dataset` objects.
274272
"""
275-
for dataset_label in datasets.keys():
276-
dataset_name_data = {datasets[dataset_label].name: datasets[dataset_label]}
277-
self.__store_shared_objects(dataset_label, dataset_name_data)
273+
for dataset_label, dataset_data in datasets.items():
274+
self.__store_shared_objects(dataset_label, dataset_data)
278275

279276
def __store_shared_objects(self, label: str, data: Dict):
280277
for key, value in data.items():
278+
# replace each item of the dataset with its obj ref
281279
data[key] = ray.put(value)
282280
self.data_store[label] = data
283281

@@ -414,31 +412,35 @@ def mapper(f, individuals, toolbox_ref):
414412
toolbox_ref = ray.put(self.toolbox)
415413
self.toolbox.register("map", mapper, toolbox_ref=toolbox_ref)
416414

417-
def fit(self, train_data: Dataset, val_data: Dataset | None = None):
415+
def fit(self, X_train, y_train=None, X_val=None, y_val=None):
418416
"""Fits the training data using GP-based symbolic regression."""
419-
if self.validate and val_data is not None:
417+
train_data = {"X_train": X_train, "y_train": y_train}
418+
if self.validate and X_val is not None:
419+
val_data = {"X_val": X_val, "y_val": y_val}
420420
datasets = {"train": train_data, "val": val_data}
421421
else:
422422
datasets = {"train": train_data}
423-
self.store_datasets(datasets)
423+
self.__store_datasets(datasets)
424424
self.__register_fitness_func()
425425
if self.validate and self.error_metric is not None:
426426
self.__register_val_funcs()
427427
self.__run()
428428

429-
def predict(self, test_data: Dataset):
429+
def predict(self, X_test):
430+
test_data = {"X_test": X_test}
430431
datasets = {"test": test_data}
431-
self.store_datasets(datasets)
432+
self.__store_datasets(datasets)
432433
self.__register_predict_func()
433434
u_best = self.toolbox.map(self.toolbox.evaluate_test_sols, (self.best,))[0]
434435
return u_best
435436

436-
def score(self, test_data: Dataset):
437+
def score(self, X_test, y_test):
437438
"""Computes the error metric (passed to the `GPSymbolicRegressor` constructor)
438439
on a given dataset.
439440
"""
441+
test_data = {"X_test": X_test, "y_test": y_test}
440442
datasets = {"test": test_data}
441-
self.store_datasets(datasets)
443+
self.__store_datasets(datasets)
442444
self.__register_score_func()
443445
score = self.toolbox.map(self.toolbox.evaluate_test_score, (self.best,))[0]
444446
return score

tests/test_basic_sr.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from dctkit import config
33
from deap import gp
4-
from alpine.gp.gpsymbreg import GPSymbolicRegressor
4+
from alpine.gp.regressor import GPSymbolicRegressor
55
from alpine.data import Dataset
66
from alpine.gp import util
77
import jax.numpy as jnp
@@ -21,53 +21,56 @@ def compile_individuals(toolbox, individuals_str_batch):
2121
y = x**4 + x**3 + x**2 + x
2222

2323

24-
def eval_MSE_sol(individual, true_data):
24+
def eval_MSE_sol(individual, X, y):
2525
import os
2626

2727
os.environ["JAX_PLATFORMS"] = "cpu"
2828
config()
2929
# Evaluate the mean squared error between the expression
3030
# and the real function : x**4 + x**3 + x**2 + x
31-
y_pred = individual(true_data.X)
32-
MSE = jnp.sum(jnp.square(y_pred - true_data.y)) / len(true_data.X)
33-
if jnp.isnan(MSE):
34-
MSE = 1e5
31+
y_pred = individual(X)
32+
MSE = None
33+
34+
if y is not None:
35+
MSE = jnp.mean(jnp.sum(jnp.square(y_pred - y)))
36+
MSE = jnp.nan_to_num(MSE, nan=1e5)
37+
3538
return MSE, y_pred
3639

3740

3841
@ray.remote
39-
def predict(individuals_str, toolbox, true_data):
42+
def predict(individuals_str, toolbox, X_test):
4043

4144
callables = compile_individuals(toolbox, individuals_str)
4245

4346
u = [None] * len(individuals_str)
4447

4548
for i, ind in enumerate(callables):
46-
_, u[i] = eval_MSE_sol(ind, true_data)
49+
_, u[i] = eval_MSE_sol(ind, X_test, None)
4750

4851
return u
4952

5053

5154
@ray.remote
52-
def score(individuals_str, toolbox, true_data):
55+
def score(individuals_str, toolbox, X_test, y_test):
5356

5457
callables = compile_individuals(toolbox, individuals_str)
5558

5659
MSE = [None] * len(individuals_str)
5760

5861
for i, ind in enumerate(callables):
59-
MSE[i], _ = eval_MSE_sol(ind, true_data)
62+
MSE[i], _ = eval_MSE_sol(ind, X_test, y_test)
6063

6164
return MSE
6265

6366

6467
@ray.remote
65-
def fitness(individuals_str, toolbox, true_data):
68+
def fitness(individuals_str, toolbox, X_train, y_train):
6669
callables = compile_individuals(toolbox, individuals_str)
6770

6871
fitnesses = [None] * len(individuals_str)
6972
for i, ind in enumerate(callables):
70-
MSE, _ = eval_MSE_sol(ind, true_data)
73+
MSE, _ = eval_MSE_sol(ind, X_train, y_train)
7174

7275
fitnesses[i] = (MSE,)
7376

@@ -110,10 +113,12 @@ def test_basic_sr(set_test_dir):
110113
**regressor_params
111114
)
112115

113-
train_data = Dataset("true_data", x, y)
114-
gpsr.fit(train_data)
116+
# train_data = Dataset("true_data", x, y)
117+
gpsr.fit(x, y)
118+
119+
fit_score = gpsr.score(x, y)
115120

116-
fit_score = gpsr.score(train_data)
121+
y_pred = gpsr.predict(x)
117122

118123
ray.shutdown()
119124

0 commit comments

Comments
 (0)