@@ -284,6 +284,19 @@ def _get_main_dictionary(self):
284
284
self ._assert_is_fitted ()
285
285
return self .model_ .get_dictionary (self .model_main_dictionary_name_ )
286
286
287
def _read_model_from_dictionary_file(self, model_dictionary_file_path):
    """Reads a model dictionary file and keeps only the model dictionaries

    Every dictionary whose name does not start with ``_khiops_model_prefix`` is
    removed from the loaded object before it is returned.

    This filtering is necessary for the regression case because Khiops generates
    a baseline model which has to be removed for the sklearn predictor.

    Parameters
    ----------
    model_dictionary_file_path
        Path of the dictionary file; it is passed as-is to
        ``kh.read_dictionary_file``.

    Returns
    -------
    The object returned by ``kh.read_dictionary_file``, from which the
    dictionaries without the model prefix have been removed.
    """
    # Check the invariant before any I/O, so a misconfigured estimator fails
    # fast instead of failing after the file has been read and parsed
    assert (
        self._khiops_model_prefix is not None
    ), "Model prefix is not set (mandatory to filter the model dictionaries)"

    model = kh.read_dictionary_file(model_dictionary_file_path)

    # Collect the names to discard before removing anything: presumably
    # `remove_dictionary` mutates `model.dictionaries`, so the removal must not
    # happen while that collection is being traversed
    non_model_dictionary_names = [
        kdic.name
        for kdic in model.dictionaries
        if not kdic.name.startswith(self._khiops_model_prefix)
    ]
    for dictionary_name in non_model_dictionary_names:
        model.remove_dictionary(dictionary_name)

    return model
299
+
287
300
def export_report_file (self , report_file_path ):
288
301
"""Exports the model report to a JSON file
289
302
@@ -309,7 +322,7 @@ def export_dictionary_file(self, dictionary_file_path):
309
322
310
323
def _import_model (self , kdic_path ):
311
324
"""Sets model instance attribute by importing model from ``.kdic``"""
312
- self .model_ = kh . read_dictionary_file (kdic_path )
325
+ self .model_ = self . _read_model_from_dictionary_file (kdic_path )
313
326
314
327
def _get_output_dir (self , fallback_dir ):
315
328
if self .output_dir :
@@ -806,7 +819,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
806
819
807
820
# Update the `model_` attribute of the coclustering estimator to the
808
821
# new coclustering model
809
- self .model_ = kh . read_dictionary_file (
822
+ self .model_ = self . _read_model_from_dictionary_file (
810
823
fs .get_child_path (
811
824
output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
812
825
)
@@ -1019,7 +1032,7 @@ def _simplify(
1019
1032
1020
1033
# Set the `model_` attribute of the new coclustering estimator to
1021
1034
# the new coclustering model
1022
- simplified_cc .model_ = kh . read_dictionary_file (
1035
+ simplified_cc .model_ = self . _read_model_from_dictionary_file (
1023
1036
fs .get_child_path (
1024
1037
output_dir , f"{ self .model_main_dictionary_name_ } _deployed.kdic"
1025
1038
)
@@ -1204,6 +1217,7 @@ def __init__(
1204
1217
self .construction_rules = construction_rules
1205
1218
self ._original_target_dtype = None
1206
1219
self ._predicted_target_meta_data_tag = None
1220
+ self ._khiops_baseline_model_prefix = None
1207
1221
1208
1222
def __sklearn_tags__ (self ):
1209
1223
# If we don't implement this trivial method it's not found by the sklearn. This
@@ -1294,7 +1308,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
1294
1308
return
1295
1309
1296
1310
# Save the model domain object and report
1297
- self .model_ = kh . read_dictionary_file (model_kdic_file_path )
1311
+ self .model_ = self . _read_model_from_dictionary_file (model_kdic_file_path )
1298
1312
self .model_report_ = kh .read_analysis_results_file (report_file_path )
1299
1313
1300
1314
@abstractmethod
@@ -1383,15 +1397,24 @@ def _fit_training_post_process(self, ds):
1383
1397
self .model_main_dictionary_name_ = self .model_ .dictionaries [0 ].name
1384
1398
else :
1385
1399
for dictionary in self .model_ .dictionaries :
1386
- assert dictionary .name .startswith (self ._khiops_model_prefix ), (
1400
+
1401
+ # The baseline model is mandatory for regression;
1402
+ # absent for classification and encoding
1403
+ assert dictionary .name .startswith (
1404
+ self ._khiops_model_prefix
1405
+ ) or dictionary .name .startswith (self ._khiops_baseline_model_prefix ), (
1387
1406
f"Dictionary '{ dictionary .name } ' "
1388
- f"does not have prefix '{ self ._khiops_model_prefix } '"
1389
- )
1390
- initial_dictionary_name = dictionary .name .replace (
1391
- self ._khiops_model_prefix , "" , 1
1407
+ f"does not have prefix '{ self ._khiops_model_prefix } ' "
1408
+ f"or '{ self ._khiops_baseline_model_prefix } '."
1392
1409
)
1393
- if initial_dictionary_name == ds .main_table .name :
1394
- self .model_main_dictionary_name_ = dictionary .name
1410
+
1411
+ # Skip baseline model
1412
+ if dictionary .name .startswith (self ._khiops_model_prefix ):
1413
+ initial_dictionary_name = dictionary .name .replace (
1414
+ self ._khiops_model_prefix , "" , 1
1415
+ )
1416
+ if initial_dictionary_name == ds .main_table .name :
1417
+ self .model_main_dictionary_name_ = dictionary .name
1395
1418
if self .model_main_dictionary_name_ is None :
1396
1419
raise ValueError ("No model dictionary after Khiops call" )
1397
1420
@@ -2183,6 +2206,7 @@ def __init__(
2183
2206
auto_sort = auto_sort ,
2184
2207
)
2185
2208
self ._khiops_model_prefix = "SNB_"
2209
+ self ._khiops_baseline_model_prefix = "B_"
2186
2210
self ._predicted_target_meta_data_tag = "Mean"
2187
2211
self ._predicted_target_name_prefix = "M"
2188
2212
self ._original_target_dtype = np .float64
@@ -2284,6 +2308,9 @@ def predict(self, X):
2284
2308
- str (a path for the file containing the array) if X is a dataset spec
2285
2309
containing file-path tables.
2286
2310
"""
2311
+ assert (
2312
+ self ._khiops_baseline_model_prefix is not None
2313
+ ), "Baseline model prefix is not set (mandatory for regression)"
2287
2314
# Call the parent's method
2288
2315
y_pred = super ().predict (X )
2289
2316
0 commit comments