@@ -104,7 +104,7 @@ def __init__(
104
104
output_path : str | None = None ,
105
105
batch_size = 1 ,
106
106
):
107
-
107
+ super (). __init__ ()
108
108
self .pset = pset
109
109
110
110
self .fitness = fitness
@@ -122,16 +122,14 @@ def __init__(
122
122
self .num_best_inds_str = num_best_inds_str
123
123
self .plot_freq = plot_freq
124
124
self .preprocess_func = preprocess_func
125
- self .callback_fun = callback_func
125
+ self .callback_func = callback_func
126
126
self .is_plot_best_individual_tree = plot_best_individual_tree
127
127
self .is_save_best_individual = save_best_individual
128
128
self .is_save_train_fit_history = save_train_fit_history
129
129
self .output_path = output_path
130
130
self .batch_size = batch_size
131
131
132
- if common_data is not None :
133
- # FIXME: does everything work when the functions do not have common args?
134
- self .__store_fit_error_common_args (common_data )
132
+ self .common_data = common_data
135
133
136
134
self .NINDIVIDUALS = NINDIVIDUALS
137
135
self .NGEN = NGEN
@@ -157,8 +155,14 @@ def __init__(
157
155
self .overlapping_generation = overlapping_generation
158
156
self .validate = validate
159
157
158
+ self .frac_elitist = frac_elitist
159
+
160
160
# Elitism settings
161
- self .n_elitist = int (frac_elitist * self .NINDIVIDUALS )
161
+ self .n_elitist = int (self .frac_elitist * self .NINDIVIDUALS )
162
+
163
+ if self .common_data is not None :
164
+ # FIXME: does everything work when the functions do not have common args?
165
+ self .__store_fit_error_common_args (self .common_data )
162
166
163
167
# config individual creator and toolbox
164
168
self .__creator_toolbox_config ()
@@ -196,6 +200,9 @@ def __init__(
196
200
self .plot_initialized = False
197
201
self .fig_id = 0
198
202
203
+ def get_params (self , deep = True ):
204
+ return self .__dict__
205
+
199
206
def __creator_toolbox_config (self ):
200
207
"""Initialize toolbox and individual creator based on config file."""
201
208
self .toolbox = base .Toolbox ()
@@ -276,7 +283,8 @@ def __store_datasets(self, datasets: Dict[str, Dataset]):
276
283
def __store_shared_objects (self , label : str , data : Dict ):
277
284
for key , value in data .items ():
278
285
# replace each item of the dataset with its obj ref
279
- data [key ] = ray .put (value )
286
+ if not isinstance (value , ray .ObjectRef ):
287
+ data [key ] = ray .put (value )
280
288
self .data_store [label ] = data
281
289
282
290
def __init_logbook (self ):
@@ -425,6 +433,7 @@ def fit(self, X_train, y_train=None, X_val=None, y_val=None):
425
433
if self .validate and self .error_metric is not None :
426
434
self .__register_val_funcs ()
427
435
self .__run ()
436
+ return self
428
437
429
438
def predict (self , X_test ):
430
439
test_data = {"X" : X_test }
@@ -567,8 +576,8 @@ def __evolve_islands(self, cgen: int):
567
576
fitnesses = self .__unflatten_list (fitnesses , [len (i ) for i in invalid_inds ])
568
577
569
578
for i in range (self .num_islands ):
570
- if self .callback_fun is not None :
571
- self .callback_fun (invalid_inds [i ], fitnesses [i ])
579
+ if self .callback_func is not None :
580
+ self .callback_func (invalid_inds [i ], fitnesses [i ])
572
581
else :
573
582
for ind , fit in zip (invalid_inds [i ], fitnesses [i ]):
574
583
ind .fitness .values = fit
@@ -626,8 +635,8 @@ def __run(self):
626
635
for i in range (self .num_islands ):
627
636
fitnesses = self .toolbox .map (self .toolbox .evaluate_train , self .pop [i ])
628
637
629
- if self .callback_fun is not None :
630
- self .callback_fun (self .pop [i ], fitnesses )
638
+ if self .callback_func is not None :
639
+ self .callback_func (self .pop [i ], fitnesses )
631
640
else :
632
641
for ind , fit in zip (self .pop [i ], fitnesses ):
633
642
ind .fitness .values = fit
0 commit comments