Skip to content

Commit d359b52

Browse files
committed
Only keep models starting with specified model prefix (SNB_)
Thus, the baseline model (starting with (B_)) is taken account of, but is not used (so that regressions are not entailed).
1 parent 70e5d14 commit d359b52

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

khiops/sklearn/estimators.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,19 @@ def _get_main_dictionary(self):
284284
self._assert_is_fitted()
285285
return self.model_.get_dictionary(self.model_main_dictionary_name_)
286286

287+
def _read_model_from_dictionary_file(self, model_dictionary_file_path):
288+
"""Removes dictionaries that do not have the model prefix in their name
289+
290+
This function is necessary for the regression case because Khiops generates a
291+
baseline model which has to be removed for sklearn predictor.
292+
"""
293+
model = kh.read_dictionary_file(model_dictionary_file_path)
294+
assert self._khiops_model_prefix is not None
295+
for dictionary_name in [kdic.name for kdic in model.dictionaries]:
296+
if not dictionary_name.startswith(self._khiops_model_prefix):
297+
model.remove_dictionary(dictionary_name)
298+
return model
299+
287300
def export_report_file(self, report_file_path):
288301
"""Exports the model report to a JSON file
289302
@@ -309,7 +322,7 @@ def export_dictionary_file(self, dictionary_file_path):
309322

310323
def _import_model(self, kdic_path):
311324
"""Sets model instance attribute by importing model from ``.kdic``"""
312-
self.model_ = kh.read_dictionary_file(kdic_path)
325+
self.model_ = self._read_model_from_dictionary_file(kdic_path)
313326

314327
def _get_output_dir(self, fallback_dir):
315328
if self.output_dir:
@@ -806,7 +819,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
806819

807820
# Update the `model_` attribute of the coclustering estimator to the
808821
# new coclustering model
809-
self.model_ = kh.read_dictionary_file(
822+
self.model_ = self._read_model_from_dictionary_file(
810823
fs.get_child_path(
811824
output_dir, f"{self.model_main_dictionary_name_}_deployed.kdic"
812825
)
@@ -1019,7 +1032,7 @@ def _simplify(
10191032

10201033
# Set the `model_` attribute of the new coclustering estimator to
10211034
# the new coclustering model
1022-
simplified_cc.model_ = kh.read_dictionary_file(
1035+
simplified_cc.model_ = self._read_model_from_dictionary_file(
10231036
fs.get_child_path(
10241037
output_dir, f"{self.model_main_dictionary_name_}_deployed.kdic"
10251038
)
@@ -1204,6 +1217,7 @@ def __init__(
12041217
self.construction_rules = construction_rules
12051218
self._original_target_dtype = None
12061219
self._predicted_target_meta_data_tag = None
1220+
self._khiops_baseline_model_prefix = None
12071221

12081222
def __sklearn_tags__(self):
12091223
# If we don't implement this trivial method it's not found by the sklearn. This
@@ -1294,7 +1308,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
12941308
return
12951309

12961310
# Save the model domain object and report
1297-
self.model_ = kh.read_dictionary_file(model_kdic_file_path)
1311+
self.model_ = self._read_model_from_dictionary_file(model_kdic_file_path)
12981312
self.model_report_ = kh.read_analysis_results_file(report_file_path)
12991313

13001314
@abstractmethod
@@ -1383,15 +1397,24 @@ def _fit_training_post_process(self, ds):
13831397
self.model_main_dictionary_name_ = self.model_.dictionaries[0].name
13841398
else:
13851399
for dictionary in self.model_.dictionaries:
1386-
assert dictionary.name.startswith(self._khiops_model_prefix), (
1400+
1401+
# The baseline model is mandatory for regression;
1402+
# absent for classification and encoding
1403+
assert dictionary.name.startswith(
1404+
self._khiops_model_prefix
1405+
) or dictionary.name.startswith(self._khiops_baseline_model_prefix), (
13871406
f"Dictionary '{dictionary.name}' "
1388-
f"does not have prefix '{self._khiops_model_prefix}'"
1389-
)
1390-
initial_dictionary_name = dictionary.name.replace(
1391-
self._khiops_model_prefix, "", 1
1407+
f"does not have prefix '{self._khiops_model_prefix}' "
1408+
f"or '{self._khiops_baseline_model_prefix}'."
13921409
)
1393-
if initial_dictionary_name == ds.main_table.name:
1394-
self.model_main_dictionary_name_ = dictionary.name
1410+
1411+
# Skip baseline model
1412+
if dictionary.name.startswith(self._khiops_model_prefix):
1413+
initial_dictionary_name = dictionary.name.replace(
1414+
self._khiops_model_prefix, "", 1
1415+
)
1416+
if initial_dictionary_name == ds.main_table.name:
1417+
self.model_main_dictionary_name_ = dictionary.name
13951418
if self.model_main_dictionary_name_ is None:
13961419
raise ValueError("No model dictionary after Khiops call")
13971420

@@ -2183,6 +2206,7 @@ def __init__(
21832206
auto_sort=auto_sort,
21842207
)
21852208
self._khiops_model_prefix = "SNB_"
2209+
self._khiops_baseline_model_prefix = "B_"
21862210
self._predicted_target_meta_data_tag = "Mean"
21872211
self._predicted_target_name_prefix = "M"
21882212
self._original_target_dtype = np.float64
@@ -2284,6 +2308,9 @@ def predict(self, X):
22842308
- str (a path for the file containing the array) if X is a dataset spec
22852309
containing file-path tables.
22862310
"""
2311+
assert (
2312+
self._khiops_baseline_model_prefix is not None
2313+
), "Baseline model prefix is not set (mandatory for regression)"
22872314
# Call the parent's method
22882315
y_pred = super().predict(X)
22892316

0 commit comments

Comments
 (0)