Skip to content

Commit 5e332cf

Browse files
author
Alessandro Lucantonio
committed
Working on making the regressor compliant with sklearn specs.
1 parent 3c60996 commit 5e332cf

File tree

2 files changed

+103
-11
lines changed

2 files changed

+103
-11
lines changed

src/alpine/gp/regressor.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
output_path: str | None = None,
105105
batch_size=1,
106106
):
107-
107+
super().__init__()
108108
self.pset = pset
109109

110110
self.fitness = fitness
@@ -122,16 +122,14 @@ def __init__(
122122
self.num_best_inds_str = num_best_inds_str
123123
self.plot_freq = plot_freq
124124
self.preprocess_func = preprocess_func
125-
self.callback_fun = callback_func
125+
self.callback_func = callback_func
126126
self.is_plot_best_individual_tree = plot_best_individual_tree
127127
self.is_save_best_individual = save_best_individual
128128
self.is_save_train_fit_history = save_train_fit_history
129129
self.output_path = output_path
130130
self.batch_size = batch_size
131131

132-
if common_data is not None:
133-
# FIXME: does everything work when the functions do not have common args?
134-
self.__store_fit_error_common_args(common_data)
132+
self.common_data = common_data
135133

136134
self.NINDIVIDUALS = NINDIVIDUALS
137135
self.NGEN = NGEN
@@ -157,8 +155,14 @@ def __init__(
157155
self.overlapping_generation = overlapping_generation
158156
self.validate = validate
159157

158+
self.frac_elitist = frac_elitist
159+
160160
# Elitism settings
161-
self.n_elitist = int(frac_elitist * self.NINDIVIDUALS)
161+
self.n_elitist = int(self.frac_elitist * self.NINDIVIDUALS)
162+
163+
if self.common_data is not None:
164+
# FIXME: does everything work when the functions do not have common args?
165+
self.__store_fit_error_common_args(self.common_data)
162166

163167
# config individual creator and toolbox
164168
self.__creator_toolbox_config()
@@ -196,6 +200,9 @@ def __init__(
196200
self.plot_initialized = False
197201
self.fig_id = 0
198202

203+
def get_params(self, deep=True):
    """Return this estimator's constructor parameters (scikit-learn API).

    scikit-learn requires ``get_params`` to return *only* the keyword
    arguments accepted by ``__init__`` (each stored verbatim on ``self``),
    so that ``clone()``/``GridSearchCV`` can rebuild an equivalent,
    unfitted estimator.  Returning ``self.__dict__`` would also expose
    internal/fitted state (toolbox, data store, plotting flags, ...),
    which breaks cloning.

    Parameters
    ----------
    deep : bool, default=True
        Accepted for API compatibility.  This estimator has no nested
        sub-estimators, so there are no ``<component>__<param>`` entries
        to expand and the flag has no effect.

    Returns
    -------
    dict
        Mapping of constructor-parameter name to its current value.
    """
    import inspect

    # Introspect __init__ rather than hard-coding the (long) parameter
    # list, so the method stays correct if the signature evolves.
    init_sig = inspect.signature(type(self).__init__)
    return {
        name: getattr(self, name)
        for name in init_sig.parameters
        if name != "self" and hasattr(self, name)
    }
205+
199206
def __creator_toolbox_config(self):
200207
"""Initialize toolbox and individual creator based on config file."""
201208
self.toolbox = base.Toolbox()
@@ -276,7 +283,8 @@ def __store_datasets(self, datasets: Dict[str, Dataset]):
276283
def __store_shared_objects(self, label: str, data: Dict):
277284
for key, value in data.items():
278285
# replace each item of the dataset with its obj ref
279-
data[key] = ray.put(value)
286+
if not isinstance(value, ray.ObjectRef):
287+
data[key] = ray.put(value)
280288
self.data_store[label] = data
281289

282290
def __init_logbook(self):
@@ -425,6 +433,7 @@ def fit(self, X_train, y_train=None, X_val=None, y_val=None):
425433
if self.validate and self.error_metric is not None:
426434
self.__register_val_funcs()
427435
self.__run()
436+
return self
428437

429438
def predict(self, X_test):
430439
test_data = {"X": X_test}
@@ -567,8 +576,8 @@ def __evolve_islands(self, cgen: int):
567576
fitnesses = self.__unflatten_list(fitnesses, [len(i) for i in invalid_inds])
568577

569578
for i in range(self.num_islands):
570-
if self.callback_fun is not None:
571-
self.callback_fun(invalid_inds[i], fitnesses[i])
579+
if self.callback_func is not None:
580+
self.callback_func(invalid_inds[i], fitnesses[i])
572581
else:
573582
for ind, fit in zip(invalid_inds[i], fitnesses[i]):
574583
ind.fitness.values = fit
@@ -626,8 +635,8 @@ def __run(self):
626635
for i in range(self.num_islands):
627636
fitnesses = self.toolbox.map(self.toolbox.evaluate_train, self.pop[i])
628637

629-
if self.callback_fun is not None:
630-
self.callback_fun(self.pop[i], fitnesses)
638+
if self.callback_func is not None:
639+
self.callback_func(self.pop[i], fitnesses)
631640
else:
632641
for ind, fit in zip(self.pop[i], fitnesses):
633642
ind.fitness.values = fit

tests/test_regressor.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from sklearn.utils.estimator_checks import check_estimator
2+
from alpine.gp.regressor import GPSymbolicRegressor
3+
from alpine.gp import util
4+
from deap import gp
5+
from sklearn.datasets import make_regression
6+
from sklearn.model_selection import train_test_split, GridSearchCV
7+
8+
9+
def test_regressor():
    """Smoke-test scikit-learn API compliance of GPSymbolicRegressor.

    Builds a minimal float-typed primitive set, constructs the regressor
    with a representative configuration, prints its parameters, and runs
    ``check_estimator`` against it.
    """
    primitive_set = gp.PrimitiveSetTyped("MAIN", [float], float)
    primitive_set.renameArguments(ARG0="x")

    # Same eight numpy primitives as before, built via a comprehension
    # instead of eight literal dicts.
    primitive_cfg = {
        "imports": {"alpine.gp.numpy_primitives": ["numpy_primitives"]},
        "used": [
            {"name": op, "dimension": None, "rank": None}
            for op in ("add", "sub", "mul", "div", "sin", "cos", "exp", "log")
        ],
    }
    primitive_set = util.add_primitives_to_pset_from_dict(
        primitive_set, primitive_cfg
    )

    regressor = GPSymbolicRegressor(
        pset=primitive_set,
        fitness=None,
        error_metric=None,
        predict_func=None,
        common_data={"penalty": {"reg_param": 0.0}},
        NINDIVIDUALS=100,
        num_islands=10,
        NGEN=200,
        MUTPB=0.1,
        min_height=2,
        max_height=6,
        crossover_prob=0.9,
        overlapping_generation=True,
        print_log=True,
        batch_size=100,
    )

    print(regressor.get_params())
    check_estimator(regressor)

    # --- Disabled grid-search example (kept from original, not yet run) ---
    # X, y = make_regression(n_samples=100, n_features=10, random_state=42)
    # X_train, X_test, y_train, y_test = train_test_split(
    #     X, y, test_size=0.2, random_state=42
    # )
    #
    # param_grid = {"NGEN": [10, 20]}
    #
    # grid_search = GridSearchCV(
    #     estimator=regressor,
    #     param_grid=param_grid,
    #     cv=3,
    #     scoring="r2",
    #     verbose=1,
    #     n_jobs=1,
    # )
    #
    # grid_search.fit(X_train, y_train)


if __name__ == "__main__":
    test_regressor()

0 commit comments

Comments
 (0)