Skip to content

Commit 9b0a249

Browse files
authored
Round explanation values to 4 decimal places in ForecastOperator (#1160)
2 parents 1c43a76 + 7a298e7 commit 9b0a249

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -573,34 +573,44 @@ def _save_report(
573573
if self.spec.generate_explanations:
574574
try:
575575
if not self.formatted_global_explanation.empty:
576+
# Round to 4 decimal places before writing
577+
global_expl_rounded = self.formatted_global_explanation.copy()
578+
global_expl_rounded = global_expl_rounded.apply(
579+
lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
580+
)
576581
if self.spec.generate_explanation_files:
577582
write_data(
578-
data=self.formatted_global_explanation,
583+
data=global_expl_rounded,
579584
filename=os.path.join(
580585
unique_output_dir, self.spec.global_explanation_filename
581586
),
582587
format="csv",
583588
storage_options=storage_options,
584589
index=True,
585590
)
586-
results.set_global_explanations(self.formatted_global_explanation)
591+
results.set_global_explanations(global_expl_rounded)
587592
else:
588593
logger.warning(
589594
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
590595
)
591596

592597
if not self.formatted_local_explanation.empty:
598+
# Round to 4 decimal places before writing
599+
local_expl_rounded = self.formatted_local_explanation.copy()
600+
local_expl_rounded = local_expl_rounded.apply(
601+
lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
602+
)
593603
if self.spec.generate_explanation_files:
594604
write_data(
595-
data=self.formatted_local_explanation,
605+
data=local_expl_rounded,
596606
filename=os.path.join(
597607
unique_output_dir, self.spec.local_explanation_filename
598608
),
599609
format="csv",
600610
storage_options=storage_options,
601611
index=True,
602612
)
603-
results.set_local_explanations(self.formatted_local_explanation)
613+
results.set_local_explanations(local_expl_rounded)
604614
else:
605615
logger.warning(
606616
f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."

tests/operators/forecast/test_explainers.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,19 @@ def test_explanations_values(model, num_series, freq):
343343
if model == "automlx":
344344
pytest.xfail("automlx model does not provide fitted values")
345345

346+
# Check decimal precision for local explanations
347+
local_numeric = local_explanations.select_dtypes(include=["int64", "float64"])
348+
assert np.allclose(local_numeric, np.round(local_numeric, 4), atol=1e-8), \
349+
"Local explanations have values with more than 4 decimal places"
350+
351+
# Check decimal precision for global explanations
352+
global_explanations = results.get_global_explanations()
353+
global_numeric = global_explanations.select_dtypes(include=["int64", "float64"])
354+
assert np.allclose(global_numeric, np.round(global_numeric, 4), atol=1e-8), \
355+
"Global explanations have values with more than 4 decimal places"
356+
346357
local_explain_vals = (
347-
local_explanations.select_dtypes(include=["int64", "float64"]).sum(axis=1)
358+
local_numeric.sum(axis=1)
348359
+ forecast.fitted_value.mean()
349360
)
350361
assert np.allclose(

0 commit comments

Comments
 (0)