diff --git a/README.md b/README.md index adca9b2..e79ace1 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,8 @@ To run the app locally, you'll need to connect it to the `forecast development d OCF team members can connect to the `forecast development database` using [these Notion instructions](https://www.notion.so/openclimatefix/Connecting-to-AWS-RDS-bf35b3fbd61f40df9c974c240e042354). Add `DB_URL= (db_url from notion documents)` to a `secrets.toml` file. Follow the instructions in the Notion document to connect to the database v. +To connect to the database platform, use `DATA_PLATFORM_HOST` and `DATA_PLATFORM_PORT`. + Run app: ```shell diff --git a/pyproject.toml b/pyproject.toml index 25ef385..267734f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "plotly==5.24.1", "psycopg2-binary==2.9.10", "SQLAlchemy==2.0.36", - "streamlit==1.46.1", + "streamlit==1.51.0", "testcontainers==4.9.0", "uvicorn==0.34.0", "geopandas==1.0.1", @@ -35,6 +35,8 @@ dependencies = [ "torch @ https://download.pytorch.org/whl/cpu/torch-2.3.1%2Bcpu-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/cpu/torch-2.3.1-cp312-none-macosx_11_0_arm64.whl ; platform_system == 'Darwin' and platform_machine == 'arm64'", "matplotlib>=3.8,<4.0", + "dp-sdk", + "aiocache", ] [project.optional-dependencies] @@ -66,6 +68,9 @@ dev-dependencies = [ index-url = "https://download.pytorch.org/whl/cpu" extra-index-url = ["https://pypi.org/simple"] +[tool.uv.sources] +dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.13.2/dp_sdk-0.13.2-py3-none-any.whl" } + [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/src/dataplatform/__init__.py b/src/dataplatform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataplatform/forecast/__init__.py b/src/dataplatform/forecast/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataplatform/forecast/cache.py b/src/dataplatform/forecast/cache.py new file mode 100644 index 0000000..a14402b --- /dev/null +++ b/src/dataplatform/forecast/cache.py @@ -0,0 +1,27 @@ +"""Cache utilities for the forecast module.""" + +from datetime import UTC, datetime, timedelta + +from dp_sdk.ocf import dp + +from dataplatform.forecast.constant import cache_seconds + + +def key_builder_remove_client(func: callable, *args: list, **kwargs: dict) -> str: + """Custom key builder that ignores the client argument for caching purposes.""" + key = f"{func.__name__}:" + for arg in args: + if not isinstance(arg, dp.DataPlatformDataServiceStub): + key += f"{arg}-" + + for k, v in kwargs.items(): + key += f"{k}={v}-" + + # get the time now to the closest 5 minutes, this forces a new cache every 5 minutes + current_time = datetime.now(UTC).replace(second=0, microsecond=0) + current_time = current_time - timedelta( + minutes=current_time.minute % (int(cache_seconds / 60)), + ) + key += f"time={current_time}-" + + return key diff --git a/src/dataplatform/forecast/constant.py b/src/dataplatform/forecast/constant.py new file mode 100644 index 0000000..5c51245 --- /dev/null +++ b/src/dataplatform/forecast/constant.py @@ -0,0 +1,24 @@ +"""Constants for the forecast module.""" + +colours = [ + "#FFD480", + "#FF8F73", + "#4675C1", + "#65B0C9", + "#58B0A9", + "#FAA056", + "#306BFF", + "#FF4901", + "#B701FF", + "#17E58F", +] + +metrics = { + "MAE": "MAE is absolute mean error, average(abs(forecast-actual))", + "ME": "ME is mean (bias) error, average((forecast-actual))", +} + +cache_seconds = 300 # 5 minutes + +# This is used for a specific case for the UK National and GSP +observer_names = ["pvlive_in_day", "pvlive_day_after"] diff --git a/src/dataplatform/forecast/data.py b/src/dataplatform/forecast/data.py new file mode 100644 index 0000000..d4c3d46 --- /dev/null +++ b/src/dataplatform/forecast/data.py @@ -0,0 +1,243 @@ +"""Functions to get forecast and observation data from Data Platform.""" + +import time +from datetime import datetime, timedelta + +import betterproto +import pandas as pd +from aiocache import Cache, cached +from dp_sdk.ocf import dp + +from dataplatform.forecast.cache import key_builder_remove_client +from dataplatform.forecast.constant import cache_seconds, observer_names + + +async def get_forecast_data( + client: dp.DataPlatformDataServiceStub, + location: dp.ListLocationsResponseLocationSummary, + start_date: datetime, + end_date: datetime, + selected_forecasters: list[dp.Forecaster], +) -> pd.DataFrame: + """Get forecast data for the given location and time window.""" + all_data_df = [] + + for forecaster in selected_forecasters: + forecaster_data_df = await get_forecast_data_one_forecaster( + client, + location, + start_date, + end_date, + forecaster, + ) + if forecaster_data_df is not None: + all_data_df.append(forecaster_data_df) + + all_data_df = pd.concat(all_data_df, ignore_index=True) + + all_data_df["effective_capacity_watts"] = all_data_df["effective_capacity_watts"].astype(float) + + # get watt value + all_data_df["p50_watts"] = all_data_df["p50_fraction"] * all_data_df["effective_capacity_watts"] + + for col in ["p10", "p25", "p75", "p90"]: + col_fraction = f"{col}_fraction" + if col_fraction in all_data_df.columns: + all_data_df[f"{col}_watts"] = ( + all_data_df[col_fraction] * all_data_df["effective_capacity_watts"] + ) + + return all_data_df + + +@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) +async def get_forecast_data_one_forecaster( + client: dp, + location: dp.ListLocationsResponseLocationSummary, + start_date: datetime, + end_date: datetime, + selected_forecaster: dp.Forecaster, +) -> pd.DataFrame | None: + """Get forecast data for one forecaster for the given location and time window.""" + all_data_list_dict = [] + + # Grab all the data, in chunks of 30 days to avoid too large requests + temp_start_date = start_date + while temp_start_date <= end_date: + temp_end_date = min(temp_start_date + timedelta(days=30), end_date) + + # fetch data + stream_forecast_data_request = dp.StreamForecastDataRequest( + location_uuid=location.location_uuid, + energy_source=dp.EnergySource.SOLAR, + time_window=dp.TimeWindow( + start_timestamp_utc=temp_start_date, + end_timestamp_utc=temp_end_date, + ), + forecasters=[selected_forecaster], + ) + forecasts = [] + async for chunk in client.stream_forecast_data(stream_forecast_data_request): + forecasts.append( + chunk.to_dict(include_default_values=True, casing=betterproto.Casing.SNAKE), + ) + + if len(forecasts) > 0: + all_data_list_dict.extend(forecasts) + + temp_start_date = temp_start_date + timedelta(days=30) + + all_data_df = pd.DataFrame.from_dict(all_data_list_dict) + if len(all_data_df) == 0: + return None + + # get plevels into columns and rename them 'fraction + columns_before_expand = set(all_data_df.columns) + all_data_df = all_data_df.pipe( + lambda df: df.join(pd.json_normalize(df["other_statistics_fractions"])), + ).drop("other_statistics_fractions", axis=1) + new_columns = set(all_data_df.columns) - columns_before_expand + if len(new_columns) > 0: + all_data_df = all_data_df.rename(columns={col: f"{col}_fraction" for col in new_columns}) + + # create column forecaster_name, its forecaster_fullname with version removed + all_data_df["forecaster_name"] = all_data_df["forecaster_fullname"].apply( + lambda x: x.rsplit(":", 1)[0], # split from right, max 1 split + ) + + return all_data_df + + +@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) +async def get_all_observations( + client: dp.DataPlatformDataServiceStub, + location: dp.ListLocationsResponseLocationSummary, + start_date: datetime, + end_date: datetime, +) -> pd.DataFrame: + """Get all observations for the given location and time window.""" + all_observations_df = [] + + for observer_name in observer_names: + # Get all the observations for this observer_name, in chunks of 7 days + observation_one_df = [] + temp_start_date = start_date + while temp_start_date <= end_date: + temp_end_date = min(temp_start_date + timedelta(days=7), end_date) + + get_observations_request = dp.GetObservationsAsTimeseriesRequest( + observer_name=observer_name, + location_uuid=location.location_uuid, + energy_source=dp.EnergySource.SOLAR, + time_window=dp.TimeWindow(temp_start_date, temp_end_date), + ) + get_observations_response = await client.get_observations_as_timeseries( + get_observations_request, + ) + + observations = [] + for chunk in get_observations_response.values: + observations.append( + chunk.to_dict(include_default_values=True, casing=betterproto.Casing.SNAKE), + ) + + observation_one_df.append(pd.DataFrame.from_dict(observations)) + + temp_start_date = temp_start_date + timedelta(days=7) + + observation_one_df = pd.concat(observation_one_df, ignore_index=True) + observation_one_df = observation_one_df.sort_values(by="timestamp_utc") + observation_one_df["observer_name"] = observer_name + + all_observations_df.append(observation_one_df) + + all_observations_df = pd.concat(all_observations_df, ignore_index=True) + + all_observations_df["effective_capacity_watts"] = all_observations_df[ + "effective_capacity_watts" + ].astype(float) + + all_observations_df["value_watts"] = ( + all_observations_df["value_fraction"] * all_observations_df["effective_capacity_watts"] + ) + all_observations_df["timestamp_utc"] = pd.to_datetime(all_observations_df["timestamp_utc"]) + + return all_observations_df + + +async def get_all_data( + client: dp.DataPlatformDataServiceStub, + selected_location: dp.ListLocationsResponseLocationSummary, + start_date: datetime, + end_date: datetime, + selected_forecasters: list[dp.Forecaster], +) -> dict: + """Get all forecast and observation data, and merge them.""" + # get generation data + time_start = time.time() + all_observations_df = await get_all_observations( + client, + selected_location, + start_date, + end_date, + ) + observation_seconds = time.time() - time_start + + # get forcast all data + time_start = time.time() + all_forecast_data_df = await get_forecast_data( + client, + selected_location, + start_date, + end_date, + selected_forecasters, + ) + forecast_seconds = time.time() - time_start + + # If the observation data includes pvlive_day_after and pvlive_in_day, + # then lets just take pvlive_day_after + one_observations_df = all_observations_df.copy() + if "pvlive_day_after" in all_observations_df["observer_name"].values: + one_observations_df = all_observations_df[ + all_observations_df["observer_name"] == "pvlive_day_after" + ] + + # make target_timestamp_utc + all_forecast_data_df["init_timestamp"] = pd.to_datetime(all_forecast_data_df["init_timestamp"]) + all_forecast_data_df["target_timestamp_utc"] = all_forecast_data_df[ + "init_timestamp" + ] + pd.to_timedelta(all_forecast_data_df["horizon_mins"], unit="m") + + # take the foecast data, and group by horizonMins, forecasterFullName + # calculate mean absolute error between p50Fraction and observations valueFraction + merged_df = pd.merge( + all_forecast_data_df, + one_observations_df, + left_on=["target_timestamp_utc"], + right_on=["timestamp_utc"], + how="inner", + suffixes=("_forecast", "_observation"), + ) + + # error and absolute error + merged_df["error"] = merged_df["p50_watts"] - merged_df["value_watts"] + merged_df["absolute_error"] = merged_df["error"].abs() + + return { + "merged_df": merged_df, + "all_forecast_data_df": all_forecast_data_df, + "all_observations_df": all_observations_df, + "forecast_seconds": forecast_seconds, + "observation_seconds": observation_seconds, + } + + +def align_t0(merged_df: pd.DataFrame) -> pd.DataFrame: + """Align t0 forecasts for different forecasters.""" + # number of unique forecasters + num_forecasters = merged_df["forecaster_name"].nunique() + # Count number of forecasters that have each t0 time + counts = merged_df.groupby("init_timestamp")["forecaster_name"].nunique() + # Filter to just those t0s that all forecasters have + common_t0s = counts[counts == num_forecasters].index + return merged_df[merged_df["init_timestamp"].isin(common_t0s)] diff --git a/src/dataplatform/forecast/main.py b/src/dataplatform/forecast/main.py new file mode 100644 index 0000000..3e9d9cd --- /dev/null +++ b/src/dataplatform/forecast/main.py @@ -0,0 +1,299 @@ +"""Data Platform Forecast Streamlit Page Main Code.""" + +import asyncio +import os + +import pandas as pd +import streamlit as st +from dp_sdk.ocf import dp +from grpclib.client import Channel + +from dataplatform.forecast.constant import metrics, observer_names +from dataplatform.forecast.data import align_t0, get_all_data +from dataplatform.forecast.plot import ( + plot_forecast_metric_per_day, + plot_forecast_metric_vs_horizon_minutes, + plot_forecast_time_series, +) +from dataplatform.forecast.setup import setup_page + +data_platform_host = os.getenv("DATA_PLATFORM_HOST", "localhost") +data_platform_port = int(os.getenv("DATA_PLATFORM_PORT", "50051")) + + +def dp_forecast_page() -> None: + """Wrapper function that is not async to call the main async function.""" + asyncio.run(async_dp_forecast_page()) + + +async def async_dp_forecast_page() -> None: + """Async Main function for the Data Platform Forecast Streamlit page.""" + st.title("Data Platform Forecast Page") + st.write("This is the forecast page from the Data Platform module. This is very much a WIP") + + async with Channel(host=data_platform_host, port=data_platform_port) as channel: + client = dp.DataPlatformDataServiceStub(channel) + + setup_page_dict = await setup_page(client) + selected_location = setup_page_dict["selected_location"] + start_date = setup_page_dict["start_date"] + end_date = setup_page_dict["end_date"] + selected_forecasters = setup_page_dict["selected_forecasters"] + forecaster_names = setup_page_dict["forecaster_names"] + selected_metric = setup_page_dict["selected_metric"] + selected_forecast_type = setup_page_dict["selected_forecast_type"] + scale_factor = setup_page_dict["scale_factor"] + selected_forecast_horizon = setup_page_dict["selected_forecast_horizon"] + selected_t0s = setup_page_dict["selected_t0s"] + units = setup_page_dict["units"] + strict_horizon_filtering = setup_page_dict["strict_horizon_filtering"] + + ### 1. Get all the data ### + all_data_dict = await get_all_data( + client=client, + start_date=start_date, + end_date=end_date, + selected_forecasters=selected_forecasters, + selected_location=selected_location, + ) + + merged_df = all_data_dict["merged_df"] + all_forecast_data_df = all_data_dict["all_forecast_data_df"] + all_observations_df = all_data_dict["all_observations_df"] + forecast_seconds = all_data_dict["forecast_seconds"] + observation_seconds = all_data_dict["observation_seconds"] + + st.write(f"Selected Location uuid: `{selected_location.location_uuid}`.") + st.write( + f"Fetched `{len(all_forecast_data_df)}` rows of forecast data \ + in `{forecast_seconds:.2f}` seconds. \ + Fetched `{len(all_observations_df)}` rows of observation data \ + in `{observation_seconds:.2f}` seconds. \ + We cache data for 5 minutes to speed up repeated requests.", + ) + + # add download button + csv = merged_df.to_csv().encode("utf-8") + st.download_button( + label="⬇️ Download data", + data=csv, + file_name=f"site_forecast_{selected_location.location_uuid}_{start_date}_{end_date}.csv", + mime="text/csv", + help="Download the forecast and generation data as a CSV file.", + ) + + ### 2. Plot of raw forecast data. ### + st.header("Time Series Plot") + + show_probabilistic = st.checkbox("Show Probabilistic Forecasts", value=True) + + fig = plot_forecast_time_series( + all_forecast_data_df=all_forecast_data_df, + all_observations_df=all_observations_df, + forecaster_names=forecaster_names, + observer_names=observer_names, + scale_factor=scale_factor, + units=units, + selected_forecast_type=selected_forecast_type, + selected_forecast_horizon=selected_forecast_horizon, + selected_t0s=selected_t0s, + show_probabilistic=show_probabilistic, + strict_horizon_filtering=strict_horizon_filtering, + ) + st.plotly_chart(fig) + + ### 3. Summary Accuracy Graph. ### + st.header("Accuracy") + + st.write(metrics) + + align_t0s = st.checkbox( + "Align t0s (Only common t0s across all forecaster are used)", + value=True, + ) + if align_t0s: + merged_df = align_t0(merged_df) + + st.subheader("Metric vs Forecast Horizon") + + if selected_metric == "MAE": + show_sem = st.checkbox( + "Show Uncertainty", + value=True, + help="On the plot below show the uncertainty bands associated with the MAE. " + "This is done by looking at the " + "Standard Error of the Mean (SEM) of the absolute errors. " + "We plot the 5 to 95 percentile range around the MAE.", + ) + else: + show_sem = False + + summary_df = make_summary_data_metric_vs_horizon_minutes(merged_df) + + fig2 = plot_forecast_metric_vs_horizon_minutes( + summary_df, + forecaster_names, + selected_metric, + scale_factor, + units, + show_sem, + ) + + st.plotly_chart(fig2) + + csv = summary_df.to_csv().encode("utf-8") + st.download_button( + label="⬇️ Download summary", + data=csv, + file_name=f"summary_accuracy_{selected_location.location_uuid}_{start_date}_{end_date}.csv", + mime="text/csv", + help="Download the summary accuracy data as a CSV file.", + ) + + ### 4. Summary Accuracy Table, with slider to select min and max horizon mins. ### + st.subheader("Summary Accuracy Table") + + # add slider to select min and max horizon mins + default_min_horizon = int(summary_df["horizon_mins"].min()) + default_max_horizon = int(summary_df["horizon_mins"].max()) + min_horizon, max_horizon = st.slider( + "Select Horizon Mins Range", + default_min_horizon, + default_max_horizon, + ( + default_min_horizon, + default_max_horizon, + ), + step=30, + ) + + summary_table_df = make_summary_data( + merged_df=merged_df, + min_horizon=min_horizon, + max_horizon=max_horizon, + scale_factor=scale_factor, + units=units, + ) + + st.dataframe(summary_table_df) + + ### 4. Daily metric plots. ### + st.subheader("Daily Metrics Plots") + st.write( + "Plotted below are the daily MAE for each forecaster. " + "This is for all forecast horizons.", + ) + + fig3 = plot_forecast_metric_per_day( + merged_df=merged_df, + forecaster_names=forecaster_names, + scale_factor=scale_factor, + units=units, + selected_metric=selected_metric, + ) + + st.plotly_chart(fig3) + + st.header("Known Issues and TODOs") + + st.write("Add more metrics") + st.write("Group adjust and non-adjust") + st.write("speed up read, use async and more caching") + st.write("Get page working with no observations data") + + +def make_summary_data( + merged_df: pd.DataFrame, + min_horizon: int, + max_horizon: int, + scale_factor: float, + units: str, +) -> pd.DataFrame: + """Make summary data table for given min and max horizon mins.""" + # Reduce my horizon mins + summary_table_df = merged_df[ + (merged_df["horizon_mins"] >= min_horizon) & (merged_df["horizon_mins"] <= max_horizon) + ] + + capacity_watts_col = "effective_capacity_watts_observation" + + value_columns = [ + "error", + "absolute_error", + "value_watts", + capacity_watts_col, + ] + summary_table_df = summary_table_df[["forecaster_name", *value_columns]] + + # group by forecaster full name a + summary_table_df = summary_table_df.groupby("forecaster_name").mean() + + # scale by units + summary_table_df = summary_table_df / scale_factor + summary_table_df = summary_table_df.rename( + {col: f"{col} [{units}]" for col in summary_table_df.columns}, + axis=1, + ) + + # pivot table, so forecaster_name is columns + summary_table_df = summary_table_df.pivot_table( + columns=summary_table_df.index, + values=summary_table_df.columns.tolist(), + ) + + # rename + summary_table_df = summary_table_df.rename( + columns={ + "error": "ME", + "absolute_error": "MAE", + capacity_watts_col: "Mean Capacity", + "value_watts": "Mean Observed Generation", + }, + ) + + return summary_table_df + + +def make_summary_data_metric_vs_horizon_minutes( + merged_df: pd.DataFrame, +) -> pd.DataFrame: + """Make summary data for forecast metric vs horizon minutes.""" + # Get the mean observed generation + mean_observed_generation = merged_df["value_watts"].mean() + + # mean absolute error by horizonMins and forecasterFullName + summary_df = ( + merged_df.groupby(["horizon_mins", "forecaster_name"]) + .agg( + { + "absolute_error": ["mean", "std", "count"], + "error": "mean", + }, + ) + .reset_index() + ) + + summary_df.columns = ["_".join(col).strip() for col in summary_df.columns.values] + summary_df.columns = [col[:-1] if col.endswith("_") else col for col in summary_df.columns] + + # calculate sem of MAE + summary_df["sem"] = summary_df["absolute_error_std"] / ( + summary_df["absolute_error_count"] ** 0.5 + ) + + # TODO more metrics + + summary_df["effective_capacity_watts_observation"] = ( + merged_df.groupby(["horizon_mins", "forecaster_name"]) + .agg({"effective_capacity_watts_observation": "mean"}) + .reset_index()["effective_capacity_watts_observation"] + ) + + # rename absolute_error to MAE + summary_df = summary_df.rename(columns={"absolute_error_mean": "MAE", "error_mean": "ME"}) + summary_df["NMAE (by capacity)"] = ( + summary_df["MAE"] / summary_df["effective_capacity_watts_observation"] + ) + summary_df["NMAE (by mean observed generation)"] = summary_df["MAE"] / mean_observed_generation + + return summary_df diff --git a/src/dataplatform/forecast/plot.py b/src/dataplatform/forecast/plot.py new file mode 100644 index 0000000..ff53bc4 --- /dev/null +++ b/src/dataplatform/forecast/plot.py @@ -0,0 +1,273 @@ +"""Plotting functions for forecast analysis.""" + +from datetime import datetime + +import pandas as pd +import plotly.graph_objects as go + +from dataplatform.forecast.constant import colours + + +def make_time_series_trace( + fig: go.Figure, + forecaster_df: pd.DataFrame, + forecaster_name: str, + scale_factor: float, + i: int, + show_probabilistic: bool = True, +) -> go.Figure: + """Make time series trace for a forecaster. + + Include p10 and p90 shading if show_probabilistic is True. + """ + fig.add_trace( + go.Scatter( + x=forecaster_df["target_timestamp_utc"], + y=forecaster_df["p50_watts"] / scale_factor, + mode="lines", + name=forecaster_name, + line={"color": colours[i % len(colours)]}, + legendgroup=forecaster_name, + ), + ) + if ( + show_probabilistic + and "p10_watts" in forecaster_df.columns + and "p90_watts" in forecaster_df.columns + ): + fig.add_trace( + go.Scatter( + x=forecaster_df["target_timestamp_utc"], + y=forecaster_df["p10_watts"] / scale_factor, + mode="lines", + line={"color": colours[i % len(colours)], "width": 0}, + legendgroup=forecaster_name, + showlegend=False, + ), + ) + + fig.add_trace( + go.Scatter( + x=forecaster_df["target_timestamp_utc"], + y=forecaster_df["p90_watts"] / scale_factor, + mode="lines", + line={"color": colours[i % len(colours)], "width": 0}, + legendgroup=forecaster_name, + showlegend=False, + fill="tonexty", + ), + ) + + return fig + + +def plot_forecast_time_series( + all_forecast_data_df: pd.DataFrame, + all_observations_df: pd.DataFrame, + forecaster_names: list[str], + observer_names: list[str], + scale_factor: float, + units: str, + selected_forecast_type: str, + selected_forecast_horizon: int, + selected_t0s: list[datetime], + show_probabilistic: bool = True, + strict_horizon_filtering: bool = False, +) -> go.Figure: + """Plot forecast time series. + + This make a plot of the raw forecasts and observations, for mulitple forecast. + """ + if selected_forecast_type == "Current": + # Choose current forecast + # this is done by selecting the unique target_timestamp_utc with the the lowest horizonMins + # it should also be unique for each forecasterFullName + current_forecast_df = all_forecast_data_df.loc[ + all_forecast_data_df.groupby(["target_timestamp_utc", "forecaster_name"])[ + "horizon_mins" + ].idxmin() + ] + elif selected_forecast_type == "Horizon": + # Choose horizon forecast + if strict_horizon_filtering: + current_forecast_df = all_forecast_data_df[ + all_forecast_data_df["horizon_mins"] == selected_forecast_horizon + ] + else: + current_forecast_df = all_forecast_data_df[ + all_forecast_data_df["horizon_mins"] >= selected_forecast_horizon + ] + current_forecast_df = current_forecast_df.loc[ + current_forecast_df.groupby(["target_timestamp_utc", "forecaster_name"])[ + "horizon_mins" + ].idxmin() + ] + elif selected_forecast_type == "t0": + current_forecast_df = all_forecast_data_df[ + all_forecast_data_df["init_timestamp"].isin(selected_t0s) + ] + + # plot the results + fig = go.Figure() + for observer_name in observer_names: + obs_df = all_observations_df[all_observations_df["observer_name"] == observer_name] + + if observer_name == "pvlive_in_day": + # dashed white line + line = {"color": "white", "dash": "dash"} + elif observer_name == "pvlive_day_after": + line = {"color": "white"} + else: + line = {} + + fig.add_trace( + go.Scatter( + x=obs_df["timestamp_utc"], + y=obs_df["value_watts"] / scale_factor, + mode="lines", + name=observer_name, + line=line, + ), + ) + + for i, forecaster_name in enumerate(forecaster_names): + forecaster_df = current_forecast_df[ + current_forecast_df["forecaster_name"] == forecaster_name + ] + if selected_forecast_type in ["Current", "Horizon"]: + fig = make_time_series_trace( + fig, + forecaster_df, + forecaster_name, + scale_factor, + i, + show_probabilistic, + ) + elif selected_forecast_type == "t0": + for _, t0 in enumerate(selected_t0s): + forecaster_with_t0_df = forecaster_df[forecaster_df["init_timestamp"] == t0] + forecaster_name_wth_t0 = f"{forecaster_name} | t0: {t0}" + fig = make_time_series_trace( + fig, + forecaster_with_t0_df, + forecaster_name_wth_t0, + scale_factor, + i, + show_probabilistic, + ) + + fig.update_layout( + title="Current Forecast", + xaxis_title="Time", + yaxis_title=f"Generation [{units}]", + legend_title="Forecaster", + ) + + return fig + + +def plot_forecast_metric_vs_horizon_minutes( + summary_df: pd.DataFrame, + forecaster_names: list[str], + selected_metric: str, + scale_factor: float, + units: str, + show_sem: bool, +) -> go.Figure: + """Plot forecast metric vs horizon minutes.""" + fig2 = go.Figure() + + for i, forecaster_name in enumerate(forecaster_names): + forecaster_df = summary_df[summary_df["forecaster_name"] == forecaster_name] + fig2.add_trace( + go.Scatter( + x=forecaster_df["horizon_mins"], + y=forecaster_df[selected_metric] / scale_factor, + mode="lines+markers", + name=forecaster_name, + line={"color": colours[i % len(colours)]}, + legendgroup=forecaster_name, + ), + ) + + if show_sem: + fig2.add_trace( + go.Scatter( + x=forecaster_df["horizon_mins"], + y=(forecaster_df[selected_metric] - 1.96 * forecaster_df["sem"]) / scale_factor, + mode="lines", + line={"color": colours[i % len(colours)], "width": 0}, + legendgroup=forecaster_name, + showlegend=False, + ), + ) + + fig2.add_trace( + go.Scatter( + x=forecaster_df["horizon_mins"], + y=(forecaster_df[selected_metric] + 1.96 * forecaster_df["sem"]) / scale_factor, + mode="lines", + line={"color": colours[i % len(colours)], "width": 0}, + legendgroup=forecaster_name, + showlegend=False, + fill="tonexty", + ), + ) + + fig2.update_layout( + title=f"{selected_metric} by Horizon", + xaxis_title="Horizon (Minutes)", + yaxis_title=f"{selected_metric} [{units}]", + legend_title="Forecaster", + ) + + if selected_metric == "MAE": + fig2.update_yaxes(range=[0, None]) + + return fig2 + + +def plot_forecast_metric_per_day( + merged_df: pd.DataFrame, + forecaster_names: list, + selected_metric: str, + scale_factor: float, + units: str, +) -> go.Figure: + """Plot forecast metric per day.""" + daily_plots_df = merged_df + daily_plots_df["date_utc"] = daily_plots_df["timestamp_utc"].dt.date + + # group by forecaster name and date + daily_metrics_df = ( + daily_plots_df.groupby(["date_utc", "forecaster_name"]) + .agg({"absolute_error": "mean", "error": "mean"}) + .reset_index() + ) + + daily_metrics_df = daily_metrics_df.rename(columns={"absolute_error": "MAE", "error": "ME"}) + + fig3 = go.Figure() + for i, forecaster_name in enumerate(forecaster_names): + name_and_version = f"{forecaster_name}" + forecaster_df = daily_metrics_df[daily_metrics_df["forecaster_name"] == name_and_version] + fig3.add_trace( + go.Scatter( + x=forecaster_df["date_utc"], + y=forecaster_df[selected_metric] / scale_factor, + name=forecaster_name, + line={"color": colours[i % len(colours)]}, + ), + ) + + fig3.update_layout( + title=f"Daily {selected_metric}", + xaxis_title="Date", + yaxis_title=f"{selected_metric} [{units}]", + legend_title="Forecaster", + ) + + if selected_metric == "MAE": + fig3.update_yaxes(range=[0, None]) + + return fig3 diff --git a/src/dataplatform/forecast/setup.py b/src/dataplatform/forecast/setup.py new file mode 100644 index 0000000..a46db86 --- /dev/null +++ b/src/dataplatform/forecast/setup.py @@ -0,0 +1,151 @@ +"""Setup Forecast Streamlit Page.""" + +from datetime import UTC, datetime, timedelta + +import pandas as pd +import streamlit as st +from aiocache import Cache, cached +from dp_sdk.ocf import dp + +from dataplatform.forecast.cache import key_builder_remove_client +from dataplatform.forecast.constant import cache_seconds, metrics + + +@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) +async def get_location_names( + client: dp.DataPlatformDataServiceStub, + location_type: dp.LocationType, +) -> dict: + """Get location names for a given location type.""" + # List Location + list_locations_request = dp.ListLocationsRequest(location_type_filter=location_type) + list_locations_response = await client.list_locations(list_locations_request) + all_locations = list_locations_response.locations + + location_names = {loc.location_name: loc for loc in all_locations} + if location_type == dp.LocationType.GSP: + location_names = { + f"{int(loc.metadata.fields['gsp_id'].number_value)}:{loc.location_name}": loc + for loc in all_locations + } + # sort by gsp id + location_names = dict( + sorted(location_names.items(), key=lambda item: int(item[0].split(":")[0])), + ) + + return location_names + + +@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) +async def get_forecasters(client: dp.DataPlatformDataServiceStub) -> list[dp.Forecaster]: + """Get all forecasters.""" + get_forecasters_request = dp.ListForecastersRequest() + get_forecasters_response = await client.list_forecasters(get_forecasters_request) + forecasters = get_forecasters_response.forecasters + return forecasters + + +async def setup_page(client: dp.DataPlatformDataServiceStub) -> dict: + """Setup the Streamlit page with sidebar options.""" + # Select Country + st.sidebar.selectbox("TODO Select a Country", ["UK", "NL"], index=0) + + # Select Location Type + location_types = [ + dp.LocationType.NATION, + dp.LocationType.GSP, + dp.LocationType.SITE, + ] + location_type = st.sidebar.selectbox("Select a Location Type", location_types, index=0) + + # select locations + location_names = await get_location_names(client, location_type) + selected_location_name = st.sidebar.selectbox( + "Select a Location", + location_names.keys(), + index=0, + ) + selected_location = location_names[selected_location_name] + + # get models + forecasters = await get_forecasters(client) + forecaster_names = sorted({forecaster.forecaster_name for forecaster in forecasters}) + default_index = forecaster_names.index("pvnet_v2") if "pvnet_v2" in forecaster_names else 0 + selected_forecaster_name = st.sidebar.multiselect( + "Select a Forecaster", + forecaster_names, + default=forecaster_names[default_index], + ) + selected_forecasters = [ + forecaster + for forecaster in forecasters + if forecaster.forecaster_name in selected_forecaster_name + ] + + # select start and end date + start_date = st.sidebar.date_input( + "Start date:", + datetime.now(tz=UTC).date() - timedelta(days=7), + ) + end_date = st.sidebar.date_input("End date:", datetime.now(tz=UTC).date() + timedelta(days=3)) + start_date = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=UTC) + end_date = datetime.combine(end_date, datetime.min.time()).replace(tzinfo=UTC) - timedelta( + seconds=1, + ) + + # select forecast type + selected_forecast_type = st.sidebar.selectbox( + "Select a Forecast Type", + ["Current", "Horizon", "t0"], + index=0, + ) + + selected_forecast_horizon = None + strict_horizon_filtering = False + selected_t0s = None + if selected_forecast_type == "Horizon": + selected_forecast_horizon = st.sidebar.selectbox( + "Select a Forecast Horizon", + list(range(0, 36 * 60, 30)), + index=3, + ) + strict_horizon_filtering = st.sidebar.checkbox( + "Strict Horizon Filtering", + value=False, + help="Only show forecasts that exactly match the selected horizon, " + "if not, we use any forecast horizon greater or equal than", + ) + if selected_forecast_type == "t0": + # make datetimes every 30 minutes from start_date to end_date + all_t0s = ( + pd.date_range(start=start_date, end=end_date, freq="30min").to_pydatetime().tolist() + ) + + selected_t0s = st.sidebar.multiselect( + "Select t0s", + all_t0s, + default=all_t0s[: min(5, len(all_t0s))], + ) + + # select units + default_unit_index = 2 # MW + units = st.sidebar.selectbox("Select Units", ["W", "kW", "MW", "GW"], index=default_unit_index) + scale_factors = {"W": 1, "kW": 1e3, "MW": 1e6, "GW": 1e9} + scale_factor = scale_factors[units] + + selected_metric = st.sidebar.selectbox("Select a Metrics", metrics.keys(), index=0) + + return { + "selected_location": selected_location, + "selected_forecasters": selected_forecasters, + "start_date": start_date, + "end_date": end_date, + "selected_forecast_type": selected_forecast_type, + "scale_factor": scale_factor, + "selected_metric": selected_metric, + "forecaster_names": forecaster_names, + "selected_forecast_horizon": selected_forecast_horizon, + "selected_t0s": selected_t0s, + "units": units, + "strict_horizon_filtering": strict_horizon_filtering, + } diff --git a/src/main.py b/src/main.py index c4cec02..663be0c 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from nowcasting_datamodel.models.metric import MetricValue from auth import check_password +from dataplatform.forecast.main import dp_forecast_page from forecast import forecast_page from get_data import get_metric_value from plots.all_gsps import make_all_gsps_plots @@ -262,6 +263,7 @@ def main_page(): st.Page(status_page, title="🚦 Status"), st.Page(forecast_page, title="📈 Forecast"), st.Page(pvsite_forecast_page, title="📉 Site Forecast"), + st.Page(dp_forecast_page, title="📉 DP Forecast"), st.Page(sites_toolbox_page, title="🛠️ Sites Toolbox"), st.Page(user_page, title="👥 API Users"), st.Page(nwp_page, title="🌤️ NWP"),