Skip to content

Commit 5bbb2cf

Browse files
authored
Fix model parameter serialization for ETS and Theta (#1383)
2 parents 0f873ee + 52198d3 commit 5bbb2cf

4 files changed

Lines changed: 44 additions & 8 deletions

File tree

ads/opctl/operator/lowcode/forecast/model/ets.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,24 @@ def _auto_detect_ets_params(self, y: pd.Series, seasonal_period: int) -> Dict[st
9898

9999
return params
100100

101+
@staticmethod
102+
def _fit_statistics_as_dict(fit) -> Dict[str, Any]:
103+
"""Return JSON-safe scalar fit statistics for model parameter output."""
104+
return {
105+
name: float(getattr(fit, name))
106+
for name in ["aic", "bic", "aicc", "llf", "sse"]
107+
if hasattr(fit, name)
108+
}
109+
110+
@staticmethod
111+
def _fit_params_as_dict(fit) -> Dict[str, float]:
112+
"""Return named fitted ETS parameters as JSON-safe scalar values."""
113+
values = getattr(fit, "params", [])
114+
values = values.tolist() if isinstance(values, pd.Series) else list(values)
115+
names = getattr(fit, "param_names", None)
116+
names = names if names and len(names) == len(values) else range(len(values))
117+
return {str(name): float(value) for name, value in zip(names, values)}
118+
101119
def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, Any]):
102120
try:
103121
self.forecast_output.init_series_output(series_id=series_id, data_at_series=df)
@@ -199,13 +217,16 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
199217
self.models[series_id]["model"] = fit
200218
self.models[series_id]["le"] = self.le[series_id]
201219

202-
params = vars(model).copy()
203-
for param in ["arima_res_", "endog_index_"]:
204-
if param in params:
205-
params.pop(param)
206220
self.model_parameters[series_id] = {
207221
"framework": SupportedModels.ETSForecaster,
208-
**params,
222+
"error": model.error,
223+
"trend": model.trend,
224+
"damped_trend": model.damped_trend,
225+
"seasonal": model.seasonal,
226+
"seasonal_periods": model.seasonal_periods,
227+
"initialization_method": model.initialization_method,
228+
"fit_params": self._fit_params_as_dict(fit),
229+
"fit_statistics": self._fit_statistics_as_dict(fit),
209230
}
210231

211232
logger.debug("===========Done===========")

ads/opctl/operator/lowcode/forecast/model/theta.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ def _get_sp_candidates(self, y, freq):
9999
logger.debug(f"Found {valid_candidates} seasonality candidates")
100100
return sorted(list(valid_candidates))
101101

102+
@staticmethod
103+
def _fit_params_as_dict(model) -> Dict[str, float]:
104+
"""Return scalar fitted Theta parameters as JSON-safe values."""
105+
return {
106+
str(name): float(value)
107+
for name, value in model.get_fitted_params().items()
108+
if isinstance(value, (int, float, np.integer, np.floating))
109+
}
110+
102111
def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, Any]):
103112
try:
104113
self.forecast_output.init_series_output(series_id=series_id, data_at_series=df)
@@ -227,10 +236,15 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
227236
self.models[series_id]["model_params"] = model_kwargs
228237
self.models[series_id]["le"] = self.le[series_id]
229238

230-
params = vars(model).copy()
239+
model_params = model.get_params()
231240
self.model_parameters[series_id] = {
232241
"framework": SupportedModels.Theta,
233-
**params,
242+
"initial_level": model_params.get("initial_level"),
243+
"deseasonalize": model_params.get("deseasonalize"),
244+
"deseasonalize_model": model_kwargs.get("deseasonalize_model"),
245+
"sp": model_params.get("sp"),
246+
"uses_additive_deseasonalization": using_additive_deseasonalization,
247+
"fit_params": self._fit_params_as_dict(model),
234248
}
235249

236250
logger.debug("===========Done===========")

docs/source/user_guide/operators/forecast_operator/yaml_schema.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ Further Description
267267
* **format**: (Optional) Specify the format for output data (e.g., ``csv``, ``json``, ``excel``).
268268
* **options**: (Optional) Include any additional arguments, such as connection parameters for storage.
269269

270-
* **model**: (Optional) The name of the model framework to use. Defaults to ``auto-select``. Available options include ``arima``, ``prophet``, ``theta``, ``ets``, ``neuralprophet``, ``autots``, and ``auto-select``.
270+
* **model**: (Optional) The name of the model framework to use. Defaults to ``prophet``. Available options include ``prophet``, ``arima``, ``neuralprophet``, ``theta``, ``ets``, ``lgbforecast``, ``xgbforecast``, ``automlx``, ``autots``, ``auto-select``, and ``auto-select-series``.
271271

272272
* **model_kwargs**: (Optional) A dictionary of arguments to pass directly to the model framework, allowing for detailed control over modeling.
273273

tests/operators/forecast/test_explainers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"target_category_columns": [],
4747
"horizon": None,
4848
"generate_explanations": True,
49+
"generate_model_parameters" : True,
4950
},
5051
}
5152

0 commit comments

Comments
 (0)