Skip to content

Commit 0d9b1c3

Browse files
author
Alessandro Lucantonio
committed
Working on check regressor. Cannot pickle fitness and score functions when passed as parameters.
1 parent 5e332cf commit 0d9b1c3

File tree

4 files changed

+82
-62
lines changed

4 files changed

+82
-62
lines changed

src/alpine/gp/regressor.py

+61-57
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ def __init__(
111111
self.error_metric = error_metric
112112
self.predict_func = predict_func
113113

114-
self.data_store = dict()
115-
116114
self.plot_best = plot_best
117115

118116
self.plot_best_genealogy = plot_best_genealogy
@@ -123,9 +121,9 @@ def __init__(
123121
self.plot_freq = plot_freq
124122
self.preprocess_func = preprocess_func
125123
self.callback_func = callback_func
126-
self.is_plot_best_individual_tree = plot_best_individual_tree
127-
self.is_save_best_individual = save_best_individual
128-
self.is_save_train_fit_history = save_train_fit_history
124+
self.plot_best_individual_tree = plot_best_individual_tree
125+
self.save_best_individual = save_best_individual
126+
self.save_train_fit_history = save_train_fit_history
129127
self.output_path = output_path
130128
self.batch_size = batch_size
131129

@@ -157,48 +155,14 @@ def __init__(
157155

158156
self.frac_elitist = frac_elitist
159157

160-
# Elitism settings
161-
self.n_elitist = int(self.frac_elitist * self.NINDIVIDUALS)
162-
163-
if self.common_data is not None:
164-
# FIXME: does everything work when the functions do not have common args?
165-
self.__store_fit_error_common_args(self.common_data)
166-
167-
# config individual creator and toolbox
168-
self.__creator_toolbox_config()
169-
170158
self.seed = seed
171159

172160
if self.seed is not None:
173161
self.seed = [self.createIndividual.from_string(i, pset) for i in seed]
174162

175-
# Initialize variables for statistics
176-
self.stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
177-
self.stats_size = tools.Statistics(len)
178-
self.mstats = tools.MultiStatistics(
179-
fitness=self.stats_fit, size=self.stats_size
180-
)
181-
self.mstats.register("avg", lambda ind: np.around(np.mean(ind), 4))
182-
self.mstats.register("std", lambda ind: np.around(np.std(ind), 4))
183-
self.mstats.register("min", lambda ind: np.around(np.min(ind), 4))
184-
self.mstats.register("max", lambda ind: np.around(np.max(ind), 4))
185-
186-
self.__init_logbook()
187-
188-
self.train_fit_history = []
189-
190-
# Create history object to build the genealogy tree
191-
self.history = tools.History()
192-
193-
if self.plot_best_genealogy:
194-
# Decorators for history
195-
self.toolbox.decorate("mate", self.history.decorator)
196-
self.toolbox.decorate("mutate", self.history.decorator)
197-
198-
self.__register_map()
199-
200-
self.plot_initialized = False
201-
self.fig_id = 0
163+
@property
164+
def n_elitist(self):
165+
return int(self.frac_elitist * self.NINDIVIDUALS)
202166

203167
def get_params(self, deep=True):
204168
return self.__dict__
@@ -420,9 +384,48 @@ def mapper(f, individuals, toolbox_ref):
420384
toolbox_ref = ray.put(self.toolbox)
421385
self.toolbox.register("map", mapper, toolbox_ref=toolbox_ref)
422386

423-
def fit(self, X_train, y_train=None, X_val=None, y_val=None):
387+
def fit(self, X, y=None, X_val=None, y_val=None):
424388
"""Fits the training data using GP-based symbolic regression."""
425-
train_data = {"X": X_train, "y": y_train}
389+
390+
if not hasattr(self, "_is_fitted"):
391+
self.data_store = dict()
392+
393+
if self.common_data is not None:
394+
# FIXME: does everything work when the functions do not have common args?
395+
self.__store_fit_error_common_args(self.common_data)
396+
397+
# config individual creator and toolbox
398+
self.__creator_toolbox_config()
399+
400+
# Initialize variables for statistics
401+
self.stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
402+
self.stats_size = tools.Statistics(len)
403+
self.mstats = tools.MultiStatistics(
404+
fitness=self.stats_fit, size=self.stats_size
405+
)
406+
self.mstats.register("avg", lambda ind: np.around(np.mean(ind), 4))
407+
self.mstats.register("std", lambda ind: np.around(np.std(ind), 4))
408+
self.mstats.register("min", lambda ind: np.around(np.min(ind), 4))
409+
self.mstats.register("max", lambda ind: np.around(np.max(ind), 4))
410+
411+
self.__init_logbook()
412+
413+
self.train_fit_history = []
414+
415+
# Create history object to build the genealogy tree
416+
self.history = tools.History()
417+
418+
if self.plot_best_genealogy:
419+
# Decorators for history
420+
self.toolbox.decorate("mate", self.history.decorator)
421+
self.toolbox.decorate("mutate", self.history.decorator)
422+
423+
self.__register_map()
424+
425+
self.plot_initialized = False
426+
self.fig_id = 0
427+
428+
train_data = {"X": X, "y": y}
426429
if self.validate and X_val is not None:
427430
val_data = {"X": X_val, "y": y_val}
428431
datasets = {"train": train_data, "val": val_data}
@@ -433,6 +436,7 @@ def fit(self, X_train, y_train=None, X_val=None, y_val=None):
433436
if self.validate and self.error_metric is not None:
434437
self.__register_val_funcs()
435438
self.__run()
439+
self._is_fitted = True
436440
return self
437441

438442
def predict(self, X_test):
@@ -443,18 +447,18 @@ def predict(self, X_test):
443447
u_best = self.toolbox.map(self.toolbox.evaluate_test_sols, (self.best,))[0]
444448
return u_best
445449

446-
def score(self, X_test, y_test):
450+
def score(self, X, y):
447451
"""Computes the error metric (passed to the `GPSymbolicRegressor` constructor)
448452
on a given dataset.
449453
"""
450-
test_data = {"X": X_test, "y": y_test}
454+
test_data = {"X": X, "y": y}
451455
datasets = {"test": test_data}
452456
self.__store_datasets(datasets)
453457
self.__register_score_func()
454458
score = self.toolbox.map(self.toolbox.evaluate_test_score, (self.best,))[0]
455459
return score
456460

457-
def immigration(self, pop, num_immigrants: int):
461+
def __immigration(self, pop, num_immigrants: int):
458462
immigrants = self.toolbox.population(n=num_immigrants)
459463
for i in range(num_immigrants):
460464
idx_individual_to_replace = random.randint(0, self.NINDIVIDUALS - 1)
@@ -543,7 +547,7 @@ def __evolve_islands(self, cgen: int):
543547
for i in range(self.num_islands):
544548
if self.immigration_enabled:
545549
if cgen % self.immigration_freq == 0:
546-
self.immigration(
550+
self.__immigration(
547551
self.pop[i], int(self.immigration_frac * self.NINDIVIDUALS)
548552
)
549553

@@ -709,20 +713,20 @@ def __run(self):
709713
if self.plot_best_genealogy:
710714
self.__plot_genealogy(self.best)
711715

712-
if self.is_plot_best_individual_tree:
713-
self.plot_best_individual_tree()
716+
if self.plot_best_individual_tree:
717+
self.__plot_best_individual_tree()
714718

715-
if self.is_save_best_individual and self.output_path is not None:
716-
self.save_best_individual(self.output_path)
719+
if self.save_best_individual and self.output_path is not None:
720+
self.__save_best_individual(self.output_path)
717721
print("String of the best individual saved to disk.")
718722

719-
if self.is_save_train_fit_history and self.output_path is not None:
720-
self.save_train_fit_history(self.output_path)
723+
if self.save_train_fit_history and self.output_path is not None:
724+
self.__save_train_fit_history(self.output_path)
721725
print("Training fitness history saved to disk.")
722726

723727
# NOTE: ray.shutdown should be manually called by the user
724728

725-
def plot_best_individual_tree(self):
729+
def __plot_best_individual_tree(self):
726730
"""Plots the tree of the best individual at the end of the evolution."""
727731
nodes, edges, labels = gp.graph(self.best)
728732
graph = nx.Graph()
@@ -736,13 +740,13 @@ def plot_best_individual_tree(self):
736740
plt.axis("off")
737741
plt.show()
738742

739-
def save_best_individual(self, output_path: str):
743+
def __save_best_individual(self, output_path: str):
740744
"""Saves the string of the best individual of the population in a .txt file."""
741745
file = open(join(output_path, "best_ind.txt"), "w")
742746
file.write(str(self.best))
743747
file.close()
744748

745-
def save_train_fit_history(self, output_path: str):
749+
def __save_train_fit_history(self, output_path: str):
746750
np.save(join(output_path, "train_fit_history.npy"), self.train_fit_history)
747751
if self.validate:
748752
np.save(join(output_path, "val_fit_history.npy"), self.val_fit_history)

src/alpine/gp/util.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import yaml
22
from .primitives import add_primitives_to_pset
33
from importlib import import_module
4+
import ray
45

56

67
def add_primitives_to_pset_from_dict(pset, primitives_dict):
@@ -91,3 +92,18 @@ def detect_nested_trigonometric_functions(equation):
9192
i += 1
9293

9394
return nested
95+
96+
97+
@ray.remote
98+
def dummy_fitness(individuals_str, toolbox, X, y):
99+
fitnesses = [(0.0,)] * len(individuals_str)
100+
101+
return fitnesses
102+
103+
104+
@ray.remote
105+
def dummy_score(individuals_str, toolbox, X, y):
106+
107+
MSE = [0.0] * len(individuals_str)
108+
109+
return MSE

tests/test_poisson1d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def test_poisson1d(set_test_dir, yamlfile):
228228

229229
fit_score = gpsr.score(X_train, y_train)
230230

231-
gpsr.save_best_test_sols(X_train, "./")
231+
gpsr.__save_best_test_sols(X_train, "./")
232232

233233
ray.shutdown()
234234
assert np.allclose(u.coeffs.flatten(), np.ravel(u_best))

tests/test_regressor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from deap import gp
55
from sklearn.datasets import make_regression
66
from sklearn.model_selection import train_test_split, GridSearchCV
7+
from alpine.gp.util import dummy_fitness, dummy_score
78

89

910
def test_regressor():
@@ -32,13 +33,12 @@ def test_regressor():
3233

3334
pset = util.add_primitives_to_pset_from_dict(pset, primitives)
3435

35-
penalty = {"reg_param": 0.0}
36-
common_data = {"penalty": penalty}
36+
common_data = {}
3737

3838
gpsr = GPSymbolicRegressor(
3939
pset=pset,
40-
fitness=None,
41-
error_metric=None,
40+
fitness=dummy_fitness.remote,
41+
error_metric=dummy_score.remote,
4242
predict_func=None,
4343
common_data=common_data,
4444
NINDIVIDUALS=100,

0 commit comments

Comments
 (0)