Skip to content

Commit ab35ba2

Browse files
authored
LightGBM Support for Forecast Operator (#1306)
2 parents 5ef6e9c + e00bb12 commit ab35ba2

File tree

5 files changed

+59
-67
lines changed

5 files changed

+59
-67
lines changed

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 50 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def set_kwargs(self):
4141
model_kwargs["uppper_quantile"] = uppper_quantile
4242
return model_kwargs
4343

44+
4445
def preprocess(self, df, series_id):
4546
pass
4647

@@ -53,54 +54,70 @@ def preprocess(self, df, series_id):
5354
err_msg="lightgbm is not installed, please install it with 'pip install lightgbm'",
5455
)
5556
def _train_model(self, data_train, data_test, model_kwargs):
57+
import lightgbm as lgb
58+
from mlforecast import MLForecast
59+
from mlforecast.lag_transforms import ExpandingMean, RollingMean
60+
from mlforecast.target_transforms import Differences
61+
62+
def set_model_config(freq):
63+
seasonal_map = {
64+
"H": 24,
65+
"D": 7,
66+
"W": 52,
67+
"M": 12,
68+
"Q": 4,
69+
}
70+
sp = seasonal_map.get(freq.upper(), 7)
71+
series_lengths = data_train.groupby(ForecastOutputColumns.SERIES).size()
72+
min_len = series_lengths.min()
73+
max_allowed = min_len - sp
74+
75+
default_lags = [lag for lag in [1, sp, 2 * sp] if lag <= max_allowed]
76+
lags = model_kwargs.get("lags", default_lags)
77+
78+
default_roll = 2 * sp
79+
roll = model_kwargs.get("RollingMean", default_roll)
80+
81+
default_diff = sp if sp <= max_allowed else None
82+
diff = model_kwargs.get("Differences", default_diff)
83+
84+
return {
85+
"target_transforms": [Differences([diff])],
86+
"lags": lags,
87+
"lag_transforms": {
88+
1: [ExpandingMean()],
89+
sp: [RollingMean(window_size=roll, min_samples=1)]
90+
}
91+
}
92+
5693
try:
57-
import lightgbm as lgb
58-
from mlforecast import MLForecast
59-
from mlforecast.lag_transforms import ExpandingMean, RollingMean
60-
from mlforecast.target_transforms import Differences
6194

6295
lgb_params = {
6396
"verbosity": model_kwargs.get("verbosity", -1),
6497
"num_leaves": model_kwargs.get("num_leaves", 512),
6598
}
66-
additional_data_params = {}
67-
if len(self.datasets.get_additional_data_column_names()) > 0:
68-
additional_data_params = {
69-
"target_transforms": [
70-
Differences([model_kwargs.get("Differences", 12)])
71-
],
72-
"lags": model_kwargs.get("lags", [1, 6, 12]),
73-
"lag_transforms": (
74-
{
75-
1: [ExpandingMean()],
76-
12: [
77-
RollingMean(
78-
window_size=model_kwargs.get("RollingMean", 24),
79-
min_samples=1,
80-
)
81-
],
82-
}
83-
),
84-
}
99+
100+
data_freq = pd.infer_freq(data_train[self.date_col].drop_duplicates()) \
101+
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:])
102+
103+
additional_data_params = set_model_config(data_freq)
85104

86105
fcst = MLForecast(
87106
models={
88107
"forecast": lgb.LGBMRegressor(**lgb_params),
89-
# "p" + str(int(model_kwargs["uppper_quantile"] * 100))
90108
"upper": lgb.LGBMRegressor(
91109
**lgb_params,
92110
objective="quantile",
93111
alpha=model_kwargs["uppper_quantile"],
94112
),
95-
# "p" + str(int(model_kwargs["lower_quantile"] * 100))
96113
"lower": lgb.LGBMRegressor(
97114
**lgb_params,
98115
objective="quantile",
99116
alpha=model_kwargs["lower_quantile"],
100117
),
101118
},
102-
freq=pd.infer_freq(data_train[self.date_col].drop_duplicates())
103-
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:]),
119+
freq=data_freq,
120+
date_features=['year', 'month', 'day', 'dayofweek', 'dayofyear'],
104121
**additional_data_params,
105122
)
106123

@@ -158,6 +175,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
158175
self.model_parameters[s_id] = {
159176
"framework": SupportedModels.LGBForecast,
160177
**lgb_params,
178+
**fcst.models_['forecast'].get_params(),
161179
}
162180

163181
logger.debug("===========Done===========")
@@ -191,48 +209,21 @@ def _generate_report(self):
191209
Generates the report for the model
192210
"""
193211
import report_creator as rc
194-
from utilsforecast.plotting import plot_series
195212

196213
logging.getLogger("report_creator").setLevel(logging.WARNING)
197214

198-
# Section 1: Forecast Overview
199-
sec1_text = rc.Block(
200-
rc.Heading("Forecast Overview", level=2),
201-
rc.Text(
202-
"These plots show your forecast in the context of historical data."
203-
),
204-
)
205-
sec_1 = _select_plot_list(
206-
lambda s_id: plot_series(
207-
self.datasets.get_all_data_long(include_horizon=False),
208-
pd.concat(
209-
[self.fitted_values, self.outputs], axis=0, ignore_index=True
210-
),
211-
id_col=ForecastOutputColumns.SERIES,
212-
time_col=self.spec.datetime_column.name,
213-
target_col=self.original_target_column,
214-
seed=42,
215-
ids=[s_id],
216-
),
217-
self.datasets.list_series_ids(),
218-
)
219-
220215
# Section 2: LGBForecast Model Parameters
221216
sec2_text = rc.Block(
222217
rc.Heading("LGBForecast Model Parameters", level=2),
223218
rc.Text("These are the parameters used for the LGBForecast model."),
224219
)
225220

226-
blocks = [
227-
rc.Html(
228-
str(s_id[1]),
229-
label=s_id[0],
230-
)
231-
for _, s_id in enumerate(self.model_parameters.items())
232-
]
233-
sec_2 = rc.Select(blocks=blocks)
221+
k, v = next(iter(self.model_parameters.items()))
222+
sec_2 = rc.Html(
223+
pd.DataFrame(list(v.items())).to_html(index=False, header=False),
224+
)
234225

235-
all_sections = [sec1_text, sec_1, sec2_text, sec_2]
226+
all_sections = [sec2_text, sec_2]
236227
model_description = rc.Text(
237228
"LGBForecast uses mlforecast framework to perform time series forecasting using machine learning models"
238229
"with the option to scale to massive amounts of data using remote clusters."

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ spec:
455455
- prophet
456456
- arima
457457
- neuralprophet
458-
# - lgbforecast
458+
- lgbforecast
459459
- automlx
460460
- autots
461461
- auto-select

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ forecast = [
177177
"autots",
178178
"mlforecast",
179179
"neuralprophet>=0.7.0",
180+
"pytorch-lightning==2.5.5",
180181
"numpy<2.0.0",
181182
"oci-cli",
182183
"optuna",

tests/operators/forecast/test_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"prophet",
3333
"neuralprophet",
3434
"autots",
35-
# "lgbforecast",
35+
"lgbforecast",
3636
"auto-select",
3737
"auto-select-series",
3838
]
@@ -177,7 +177,7 @@ def test_load_datasets(model, data_details):
177177
subprocess.run(f"ls -a {output_data_path}", shell=True)
178178
if yaml_i["spec"]["generate_explanations"] and model not in [
179179
"automlx",
180-
# "lgbforecast",
180+
"lgbforecast",
181181
"auto-select",
182182
]:
183183
verify_explanations(

tests/operators/forecast/test_errors.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
"prophet",
144144
"neuralprophet",
145145
"autots",
146-
# "lgbforecast",
146+
"lgbforecast",
147147
]
148148

149149
TEMPLATE_YAML = {
@@ -415,8 +415,8 @@ def test_0_series(operator_setup, model):
415415
"local_explanation.csv",
416416
"global_explanation.csv",
417417
]
418-
if model == "autots":
419-
# explanations are not supported for autots
418+
if model in ["autots", "lgbforecast"]:
419+
# explanations are not supported for autots or lgbforecast
420420
output_files.remove("local_explanation.csv")
421421
output_files.remove("global_explanation.csv")
422422
for file in output_files:
@@ -709,7 +709,7 @@ def test_arima_automlx_errors(operator_setup, model):
709709
in error_content["13"]["model_fitting"]["error"]
710710
), f"Error message mismatch: {error_content}"
711711

712-
if model not in ["autots", "automlx"]: # , "lgbforecast"
712+
if model not in ["autots", "automlx", "lgbforecast"]:
713713
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
714714
global_fn = f"{tmpdirname}/results/global_explanation.csv"
715715
assert os.path.exists(
@@ -816,7 +816,7 @@ def test_date_format(operator_setup, model):
816816
@pytest.mark.parametrize("model", MODELS)
817817
def test_what_if_analysis(operator_setup, model):
818818
os.environ["TEST_MODE"] = "True"
819-
if model == "auto-select":
819+
if model in ["auto-select", "lgbforecast"]:
820820
pytest.skip("Skipping what-if scenario for auto-select")
821821
tmpdirname = operator_setup
822822
historical_data_path, additional_data_path = setup_small_rossman()

0 commit comments

Comments (0)