Skip to content

Commit 8144301

Browse files
committed
Only keep models starting with the specified model prefix (SNB_)
Thus, the baseline model (whose name starts with B_) is taken into account but is not used, so that no regressions are introduced.
1 parent ad6da19 commit 8144301

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

khiops/sklearn/estimators.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,19 @@ def _get_main_dictionary(self):
284284
self._assert_is_fitted()
285285
return self.model_.get_dictionary(self.model_main_dictionary_name_)
286286

287+
def _read_model_from_dictionary_file(self, model_dictionary_file_path):
    """Reads a dictionary file and keeps only the model-prefixed dictionaries

    Every dictionary whose name does not start with ``_khiops_model_prefix``
    is removed from the loaded domain before it is returned.

    This filtering is needed for the regression case: Khiops generates a
    baseline model there, and it has to be removed for the sklearn predictor.

    Parameters
    ----------
    model_dictionary_file_path
        Path of the dictionary file, passed as-is to
        ``kh.read_dictionary_file``.

    Returns
    -------
    The object returned by ``kh.read_dictionary_file``, with the
    non-prefixed dictionaries removed.
    """
    domain = kh.read_dictionary_file(model_dictionary_file_path)

    # Internal invariant: the estimator must have set its model prefix
    assert self._khiops_model_prefix is not None
    prefix = self._khiops_model_prefix

    # Collect the names to drop up front so that no removal happens while
    # iterating over the domain's dictionaries
    names_to_drop = [
        kdic.name for kdic in domain.dictionaries if not kdic.name.startswith(prefix)
    ]
    for name in names_to_drop:
        domain.remove_dictionary(name)

    return domain
299+
287300
def export_report_file(self, report_file_path):
288301
"""Exports the model report to a JSON file
289302
@@ -309,7 +322,7 @@ def export_dictionary_file(self, dictionary_file_path):
309322

310323
def _import_model(self, kdic_path):
311324
"""Sets model instance attribute by importing model from ``.kdic``"""
312-
self.model_ = kh.read_dictionary_file(kdic_path)
325+
self.model_ = self._read_model_from_dictionary_file(kdic_path)
313326

314327
def _get_output_dir(self, fallback_dir):
315328
if self.output_dir:
@@ -808,7 +821,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
808821

809822
# Update the `model_` attribute of the coclustering estimator to the
810823
# new coclustering model
811-
self.model_ = kh.read_dictionary_file(
824+
self.model_ = self._read_model_from_dictionary_file(
812825
fs.get_child_path(
813826
output_dir, f"{self.model_main_dictionary_name_}_deployed.kdic"
814827
)
@@ -1021,7 +1034,7 @@ def _simplify(
10211034

10221035
# Set the `model_` attribute of the new coclustering estimator to
10231036
# the new coclustering model
1024-
simplified_cc.model_ = kh.read_dictionary_file(
1037+
simplified_cc.model_ = self._read_model_from_dictionary_file(
10251038
fs.get_child_path(
10261039
output_dir, f"{self.model_main_dictionary_name_}_deployed.kdic"
10271040
)
@@ -1205,6 +1218,7 @@ def __init__(
12051218
self.construction_rules = construction_rules
12061219
self._original_target_dtype = None
12071220
self._predicted_target_meta_data_tag = None
1221+
self._khiops_baseline_model_prefix = None
12081222

12091223
def __sklearn_tags__(self):
12101224
# If we don't implement this trivial method it's not found by the sklearn. This
@@ -1295,7 +1309,7 @@ def _fit_train_model(self, ds, computation_dir, **kwargs):
12951309
return
12961310

12971311
# Save the model domain object and report
1298-
self.model_ = kh.read_dictionary_file(model_kdic_file_path)
1312+
self.model_ = self._read_model_from_dictionary_file(model_kdic_file_path)
12991313
self.model_report_ = kh.read_analysis_results_file(report_file_path)
13001314

13011315
@abstractmethod
@@ -1384,15 +1398,24 @@ def _fit_training_post_process(self, ds):
13841398
self.model_main_dictionary_name_ = self.model_.dictionaries[0].name
13851399
else:
13861400
for dictionary in self.model_.dictionaries:
1387-
assert dictionary.name.startswith(self._khiops_model_prefix), (
1401+
1402+
# The baseline model is mandatory for regression;
1403+
# absent for classification and encoding
1404+
assert dictionary.name.startswith(
1405+
self._khiops_model_prefix
1406+
) or dictionary.name.startswith(self._khiops_baseline_model_prefix), (
13881407
f"Dictionary '{dictionary.name}' "
1389-
f"does not have prefix '{self._khiops_model_prefix}'"
1390-
)
1391-
initial_dictionary_name = dictionary.name.replace(
1392-
self._khiops_model_prefix, "", 1
1408+
f"does not have prefix '{self._khiops_model_prefix}' "
1409+
f"or '{self._khiops_baseline_model_prefix}'."
13931410
)
1394-
if initial_dictionary_name == ds.main_table.name:
1395-
self.model_main_dictionary_name_ = dictionary.name
1411+
1412+
# Skip baseline model
1413+
if dictionary.name.startswith(self._khiops_model_prefix):
1414+
initial_dictionary_name = dictionary.name.replace(
1415+
self._khiops_model_prefix, "", 1
1416+
)
1417+
if initial_dictionary_name == ds.main_table.name:
1418+
self.model_main_dictionary_name_ = dictionary.name
13961419
if self.model_main_dictionary_name_ is None:
13971420
raise ValueError("No model dictionary after Khiops call")
13981421

@@ -2185,6 +2208,7 @@ def __init__(
21852208
auto_sort=auto_sort,
21862209
)
21872210
self._khiops_model_prefix = "SNB_"
2211+
self._khiops_baseline_model_prefix = "B_"
21882212
self._predicted_target_meta_data_tag = "Mean"
21892213
self._predicted_target_name_prefix = "M"
21902214
self._original_target_dtype = np.float64
@@ -2286,6 +2310,9 @@ def predict(self, X):
22862310
- str (a path for the file containing the array) if X is a dataset spec
22872311
containing file-path tables.
22882312
"""
2313+
assert (
2314+
self._khiops_baseline_model_prefix is not None
2315+
), "Baseline model prefix is not set (mandatory for regression)"
22892316
# Call the parent's method
22902317
y_pred = super().predict(X)
22912318

0 commit comments

Comments
 (0)