Skip to content

Commit b15138c

Browse files
authored
Fix auto-select on model failure (#998)
2 parents 0ed3daa + 31295b5 commit b15138c

File tree

2 files changed: +28 -24 lines changed

ads/opctl/operator/lowcode/forecast/model_evaluator.py

Lines changed: 25 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -121,23 +121,26 @@ def run_all_models(self, datasets: ForecastDatasets, operator_config: ForecastOp
121121
from .model.factory import ForecastOperatorModelFactory
122122
metrics[model] = {}
123123
for i in range(len(cut_offs)):
124-
backtest_historical_data = train_sets[i]
125-
backtest_additional_data = additional_data[i]
126-
backtest_test_data = test_sets[i]
127-
backtest_operator_config = self.create_operator_config(operator_config, i, model,
128-
backtest_historical_data,
129-
backtest_additional_data,
130-
backtest_test_data)
131-
datasets = ForecastDatasets(backtest_operator_config)
132-
ForecastOperatorModelFactory.get_model(
133-
backtest_operator_config, datasets
134-
).generate_report()
135-
test_metrics_filename = backtest_operator_config.spec.test_metrics_filename
136-
metrics_df = pd.read_csv(
137-
f"{backtest_operator_config.spec.output_directory.url}/{test_metrics_filename}")
138-
metrics_df["average_across_series"] = metrics_df.drop('metrics', axis=1).mean(axis=1)
139-
metrics_average_dict = dict(zip(metrics_df['metrics'].str.lower(), metrics_df['average_across_series']))
140-
metrics[model][i] = metrics_average_dict[operator_config.spec.metric]
124+
try:
125+
backtest_historical_data = train_sets[i]
126+
backtest_additional_data = additional_data[i]
127+
backtest_test_data = test_sets[i]
128+
backtest_operator_config = self.create_operator_config(operator_config, i, model,
129+
backtest_historical_data,
130+
backtest_additional_data,
131+
backtest_test_data)
132+
datasets = ForecastDatasets(backtest_operator_config)
133+
ForecastOperatorModelFactory.get_model(
134+
backtest_operator_config, datasets
135+
).generate_report()
136+
test_metrics_filename = backtest_operator_config.spec.test_metrics_filename
137+
metrics_df = pd.read_csv(
138+
f"{backtest_operator_config.spec.output_directory.url}/{test_metrics_filename}")
139+
metrics_df["average_across_series"] = metrics_df.drop('metrics', axis=1).mean(axis=1)
140+
metrics_average_dict = dict(zip(metrics_df['metrics'].str.lower(), metrics_df['average_across_series']))
141+
metrics[model][i] = metrics_average_dict[operator_config.spec.metric]
142+
except:
143+
logger.warn(f"Failed to calculate metrics for {model} and {i} backtest")
141144
return metrics
142145

143146
def find_best_model(self, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig):
@@ -147,10 +150,12 @@ def find_best_model(self, datasets: ForecastDatasets, operator_config: ForecastO
147150
model = SupportedModels.Prophet
148151
logger.error(f"Running {model} model as auto-select failed with the following error: {e.message}")
149152
return model
150-
avg_backtests_metrics = {key: sum(value.values()) / len(value.values()) for key, value in metrics.items()}
151-
best_model = min(avg_backtests_metrics, key=avg_backtests_metrics.get)
153+
nonempty_metrics = {model: metric for model, metric in metrics.items() if metric != {}}
154+
avg_backtests_metric = {model: sum(value.values()) / len(value.values())
155+
for model, value in nonempty_metrics.items()}
156+
best_model = min(avg_backtests_metric, key=avg_backtests_metric.get)
152157
logger.info(f"Among models {self.models}, {best_model} model shows better performance during backtesting.")
153-
backtest_stats = pd.DataFrame(metrics).rename_axis('backtest')
158+
backtest_stats = pd.DataFrame(nonempty_metrics).rename_axis('backtest')
154159
backtest_stats.reset_index(inplace=True)
155160
output_dir = operator_config.spec.output_directory.url
156161
backtest_report_name = "backtest_stats.csv"

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
mean_absolute_percentage_error,
2020
mean_squared_error,
2121
)
22-
try:
23-
from scipy.stats import linregress
24-
except:
25-
from sklearn.metrics import r2_score
22+
23+
from scipy.stats import linregress
24+
from sklearn.metrics import r2_score
2625

2726
from ads.common.object_storage_details import ObjectStorageDetails
2827
from ads.dataset.label_encoder import DataFrameLabelEncoder

0 commit comments

Comments
 (0)