@@ -111,8 +111,6 @@ def __init__(
         self.error_metric = error_metric
         self.predict_func = predict_func

-        self.data_store = dict()
-
         self.plot_best = plot_best

         self.plot_best_genealogy = plot_best_genealogy
@@ -123,9 +121,9 @@ def __init__(
         self.plot_freq = plot_freq
         self.preprocess_func = preprocess_func
         self.callback_func = callback_func
-        self.is_plot_best_individual_tree = plot_best_individual_tree
-        self.is_save_best_individual = save_best_individual
-        self.is_save_train_fit_history = save_train_fit_history
+        self.plot_best_individual_tree = plot_best_individual_tree
+        self.save_best_individual = save_best_individual
+        self.save_train_fit_history = save_train_fit_history
         self.output_path = output_path
         self.batch_size = batch_size

@@ -157,48 +155,14 @@ def __init__(

         self.frac_elitist = frac_elitist

-        # Elitism settings
-        self.n_elitist = int(self.frac_elitist * self.NINDIVIDUALS)
-
-        if self.common_data is not None:
-            # FIXME: does everything work when the functions do not have common args?
-            self.__store_fit_error_common_args(self.common_data)
-
-        # config individual creator and toolbox
-        self.__creator_toolbox_config()
-
         self.seed = seed

         if self.seed is not None:
             self.seed = [self.createIndividual.from_string(i, pset) for i in seed]

-        # Initialize variables for statistics
-        self.stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
-        self.stats_size = tools.Statistics(len)
-        self.mstats = tools.MultiStatistics(
-            fitness=self.stats_fit, size=self.stats_size
-        )
-        self.mstats.register("avg", lambda ind: np.around(np.mean(ind), 4))
-        self.mstats.register("std", lambda ind: np.around(np.std(ind), 4))
-        self.mstats.register("min", lambda ind: np.around(np.min(ind), 4))
-        self.mstats.register("max", lambda ind: np.around(np.max(ind), 4))
-
-        self.__init_logbook()
-
-        self.train_fit_history = []
-
-        # Create history object to build the genealogy tree
-        self.history = tools.History()
-
-        if self.plot_best_genealogy:
-            # Decorators for history
-            self.toolbox.decorate("mate", self.history.decorator)
-            self.toolbox.decorate("mutate", self.history.decorator)
-
-        self.__register_map()
-
-        self.plot_initialized = False
-        self.fig_id = 0
+    @property
+    def n_elitist(self):
+        return int(self.frac_elitist * self.NINDIVIDUALS)

     def get_params(self, deep=True):
         return self.__dict__
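
With this hunk the elite count is no longer frozen at construction time: n_elitist becomes a read-only property recomputed from frac_elitist and NINDIVIDUALS on every access, so later changes to frac_elitist (for instance through a scikit-learn-style set_params) stay consistent. A minimal sketch of the pattern, using a made-up class name and example values:

class Example:
    def __init__(self, frac_elitist, NINDIVIDUALS):
        self.frac_elitist = frac_elitist
        self.NINDIVIDUALS = NINDIVIDUALS

    @property
    def n_elitist(self):
        # recomputed on every access, so it tracks later changes to frac_elitist
        return int(self.frac_elitist * self.NINDIVIDUALS)

ex = Example(frac_elitist=0.1, NINDIVIDUALS=100)
print(ex.n_elitist)   # 10
ex.frac_elitist = 0.2
print(ex.n_elitist)   # 20, no stale cached value
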
@@ -420,9 +384,48 @@ def mapper(f, individuals, toolbox_ref):
        toolbox_ref = ray.put(self.toolbox)
        self.toolbox.register("map", mapper, toolbox_ref=toolbox_ref)

-    def fit(self, X_train, y_train=None, X_val=None, y_val=None):
+    def fit(self, X, y=None, X_val=None, y_val=None):
         """Fits the training data using GP-based symbolic regression."""
-        train_data = {"X": X_train, "y": y_train}
+
+        if not hasattr(self, "_is_fitted"):
+            self.data_store = dict()
+
+            if self.common_data is not None:
+                # FIXME: does everything work when the functions do not have common args?
+                self.__store_fit_error_common_args(self.common_data)
+
+            # config individual creator and toolbox
+            self.__creator_toolbox_config()
+
+            # Initialize variables for statistics
+            self.stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
+            self.stats_size = tools.Statistics(len)
+            self.mstats = tools.MultiStatistics(
+                fitness=self.stats_fit, size=self.stats_size
+            )
+            self.mstats.register("avg", lambda ind: np.around(np.mean(ind), 4))
+            self.mstats.register("std", lambda ind: np.around(np.std(ind), 4))
+            self.mstats.register("min", lambda ind: np.around(np.min(ind), 4))
+            self.mstats.register("max", lambda ind: np.around(np.max(ind), 4))
+
+            self.__init_logbook()
+
+            self.train_fit_history = []
+
+            # Create history object to build the genealogy tree
+            self.history = tools.History()
+
+            if self.plot_best_genealogy:
+                # Decorators for history
+                self.toolbox.decorate("mate", self.history.decorator)
+                self.toolbox.decorate("mutate", self.history.decorator)
+
+            self.__register_map()
+
+            self.plot_initialized = False
+            self.fig_id = 0
+
+        train_data = {"X": X, "y": y}
         if self.validate and X_val is not None:
             val_data = {"X": X_val, "y": y_val}
             datasets = {"train": train_data, "val": val_data}
@@ -433,6 +436,7 @@ def fit(self, X_train, y_train=None, X_val=None, y_val=None):
         if self.validate and self.error_metric is not None:
             self.__register_val_funcs()
         self.__run()
+        self._is_fitted = True
         return self

     def predict(self, X_test):
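
The heavy setup (data store, DEAP creator/toolbox configuration, statistics, logbook, genealogy history, map registration) moves out of __init__ and into the first call to fit, guarded by the _is_fitted flag that is set once __run() completes. The constructor is then limited to storing parameters, the convention scikit-learn-compatible estimators follow so that cloning and get_params/set_params have no side effects. A stripped-down sketch of the same guard, with hypothetical names:

class LazyEstimator:
    def __init__(self, param=1.0):
        self.param = param          # __init__ only stores parameters

    def fit(self, X, y=None):
        if not hasattr(self, "_is_fitted"):
            self._setup()           # expensive one-time configuration
        # ... run the training loop ...
        self._is_fitted = True
        return self

    def _setup(self):
        self.state_ = {}            # stand-in for toolbox/statistics setup
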
@@ -443,18 +447,18 @@ def predict(self, X_test):
         u_best = self.toolbox.map(self.toolbox.evaluate_test_sols, (self.best,))[0]
         return u_best

-    def score(self, X_test, y_test):
+    def score(self, X, y):
         """Computes the error metric (passed to the `GPSymbolicRegressor` constructor)
         on a given dataset.
         """
-        test_data = {"X": X_test, "y": y_test}
+        test_data = {"X": X, "y": y}
         datasets = {"test": test_data}
         self.__store_datasets(datasets)
         self.__register_score_func()
         score = self.toolbox.map(self.toolbox.evaluate_test_score, (self.best,))[0]
         return score

-    def immigration(self, pop, num_immigrants: int):
+    def __immigration(self, pop, num_immigrants: int):
         immigrants = self.toolbox.population(n=num_immigrants)
         for i in range(num_immigrants):
             idx_individual_to_replace = random.randint(0, self.NINDIVIDUALS - 1)
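
Renaming the arguments of fit and score to plain X/y (and hiding immigration behind a double underscore) aligns the public surface with the scikit-learn estimator API: fit(X, y), predict(X), score(X, y). A rough usage sketch; the constructor arguments below are illustrative only, since the full signature is not shown in this diff:

# Hypothetical usage; constructor arguments are illustrative only.
gpsr = GPSymbolicRegressor(pset=pset, error_metric=my_error, predict_func=my_predict)
gpsr.fit(X_train, y_train, X_val=X_val, y_val=y_val)
y_pred = gpsr.predict(X_test)
test_error = gpsr.score(X_test, y_test)
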
@@ -543,7 +547,7 @@ def __evolve_islands(self, cgen: int):
         for i in range(self.num_islands):
             if self.immigration_enabled:
                 if cgen % self.immigration_freq == 0:
-                    self.immigration(
+                    self.__immigration(
                         self.pop[i], int(self.immigration_frac * self.NINDIVIDUALS)
                     )

@@ -709,20 +713,20 @@ def __run(self):
         if self.plot_best_genealogy:
             self.__plot_genealogy(self.best)

-        if self.is_plot_best_individual_tree:
-            self.plot_best_individual_tree()
+        if self.plot_best_individual_tree:
+            self.__plot_best_individual_tree()

-        if self.is_save_best_individual and self.output_path is not None:
-            self.save_best_individual(self.output_path)
+        if self.save_best_individual and self.output_path is not None:
+            self.__save_best_individual(self.output_path)
             print("String of the best individual saved to disk.")

-        if self.is_save_train_fit_history and self.output_path is not None:
-            self.save_train_fit_history(self.output_path)
+        if self.save_train_fit_history and self.output_path is not None:
+            self.__save_train_fit_history(self.output_path)
             print("Training fitness history saved to disk.")

         # NOTE: ray.shutdown should be manually called by the user

-    def plot_best_individual_tree(self):
+    def __plot_best_individual_tree(self):
         """Plots the tree of the best individual at the end of the evolution."""
         nodes, edges, labels = gp.graph(self.best)
         graph = nx.Graph()
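
The plotting and saving helpers become double-underscore methods because the constructor flags now reuse the old public names (plot_best_individual_tree, save_best_individual and save_train_fit_history are plain booleans after the rename above). With the leading __, Python name mangling keeps the boolean attribute and the method from colliding. A small illustration of the mechanism, unrelated to this class:

class Demo:
    def __init__(self, save_report):
        self.save_report = save_report      # boolean flag, public attribute

    def __save_report(self, path):          # mangled to _Demo__save_report
        print(f"saving to {path}")

    def run(self):
        if self.save_report:                # reads the flag
            self.__save_report("out.txt")   # calls the method, no name clash

Demo(save_report=True).run()
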
@@ -736,13 +740,13 @@ def plot_best_individual_tree(self):
         plt.axis("off")
         plt.show()

-    def save_best_individual(self, output_path: str):
+    def __save_best_individual(self, output_path: str):
         """Saves the string of the best individual of the population in a .txt file."""
         file = open(join(output_path, "best_ind.txt"), "w")
         file.write(str(self.best))
         file.close()

-    def save_train_fit_history(self, output_path: str):
+    def __save_train_fit_history(self, output_path: str):
         np.save(join(output_path, "train_fit_history.npy"), self.train_fit_history)
         if self.validate:
             np.save(join(output_path, "val_fit_history.npy"), self.val_fit_history)