def random_forest_regressor(x_train: list, x_test: list, train_user: list) -> float:
    """
    Fourth method: Random Forest Regressor
    Random Forest is an ensemble learning method for regression that operates
    by constructing a multitude of decision trees at training time and outputting
    the mean prediction of the individual trees.

    It is more robust than a single decision tree and less prone to overfitting.
    Good for capturing nonlinear relationships in data.

    input : training data (date, total_event) in list of float
            where x = list of set (date and total event)
    output : total user prediction for the first row of x_test, as a float

    The exact value depends on the bootstrap samples drawn by the installed
    scikit-learn version, so the example only checks that the prediction lies
    within the range of the training targets (a forest averages training
    targets, so it cannot leave that range).

    >>> prediction = random_forest_regressor(
    ...     [[5, 2], [1, 5], [6, 2]], [[3, 2]], [2, 1, 4]
    ... )
    >>> 1.0 <= prediction <= 4.0
    True
    """
    # Imported here so this function carries its own dependency.
    from sklearn.ensemble import RandomForestRegressor

    # A fixed seed keeps repeated runs reproducible on the same installation.
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(x_train, train_user)
    prediction = model.predict(x_test)
    # predict() returns one value per row of x_test; only the first is used.
    return float(prediction[0])


def plot_forecast(actual: list, predictions: list) -> None:
    """
    Plot the observed series together with each model's forecast.

    The values of ``actual`` are drawn as a line at x = 0 .. len(actual) - 1.
    The first four values of ``predictions`` are drawn as single markers at
    x = len(actual), in this order: Linear Reg, SARIMAX, SVR, RF.
    Any values beyond the fourth are ignored.

    input : actual      -> list of observed values
            predictions -> list with at least four forecast values
    output : None (the figure is displayed with ``plt.show()``)

    Raises IndexError before any figure is created when fewer than four
    predictions are supplied.
    """
    # (matplotlib format string, legend label) for each model, in vote order.
    model_markers = (
        ("ro", "Linear Reg"),
        ("go", "SARIMAX"),
        ("bo", "SVR"),
        ("yo", "RF"),
    )
    if len(predictions) < len(model_markers):
        # Fail early with a clear message instead of leaving a half-drawn
        # figure behind when indexing runs off the end of the list.
        msg = (
            f"predictions must contain at least {len(model_markers)} values, "
            f"got {len(predictions)}"
        )
        raise IndexError(msg)

    # Imported lazily so the forecasting functions stay usable (and
    # doctest-able) on machines without a plotting backend.
    import matplotlib.pyplot as plt

    forecast_day = len(actual)
    plt.figure(figsize=(10, 5))
    plt.plot(range(forecast_day), actual, label="Actual")
    for (marker, label), predicted in zip(model_markers, predictions):
        plt.plot(forecast_day, predicted, marker, label=label)
    plt.legend()
    plt.title("Data Safety Forecast")
    plt.xlabel("Days")
    plt.ylabel("Normalized User Count")
    plt.grid(True)
    plt.tight_layout()
    plt.show()