Skip to content
247 changes: 165 additions & 82 deletions hubverse_annotator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,43 +93,92 @@ def forecast_annotation_ui(
export_button()


def model_selection_ui(
    forecast_table: pl.DataFrame,
    loc_abbr: str,
) -> list[str]:
    """
    Renders a Streamlit multiselect widget for choosing
    models, with "All" and "None" buttons. Defaults to all
    models being selected the first time it is rendered.

    The selection is stored in ``st.session_state`` under
    the key "model_selection". Before the widget is drawn,
    any stored model that is not available for ``loc_abbr``
    is dropped, so the stored value is always a subset of
    the widget's options.

    Parameters
    ----------
    forecast_table : pl.DataFrame
        Hubverse-formatted forecasts (must include
        "loc_abbr" and "model_id" columns; possibly empty).
    loc_abbr : str
        The selection location, typically a US
        jurisdiction.

    Returns
    -------
    list[str]
        The list of currently selected model_ids. Empty
        if no forecasts exist for the location.
    """
    models = (
        forecast_table.filter(pl.col("loc_abbr") == loc_abbr)
        .get_column("model_id")
        .unique()
        .sort()
        .to_list()
    )
    if not models:
        st.info("Upload forecasts for this location to pick models.")
        return []

    if "model_selection" not in st.session_state:
        st.session_state.model_selection = models.copy()
    else:
        # The available models depend on the location, so a selection made
        # for another location may name models that are not valid options
        # here. Prune it before the widget is instantiated.
        available = set(models)
        st.session_state.model_selection = [
            model
            for model in st.session_state.model_selection
            if model in available
        ]

    def _select_all():
        st.session_state.model_selection = models.copy()

    def _select_none():
        st.session_state.model_selection = []

    all_button, none_button = st.columns(2)
    all_button.button("All", on_click=_select_all)
    none_button.button("None", on_click=_select_none)
    # The widget value comes from st.session_state via `key`, which is why
    # no explicit default is passed.
    selected_models = st.multiselect(
        "Model(s)",
        options=models,
        default=None,
        key="model_selection",
    )

    return selected_models


def target_selection_ui(
observed_data_table: pl.DataFrame,
forecast_table: pl.DataFrame,
loc_abbr: str,
selected_models: list[str],
) -> str | None:
"""
Renders a Streamlit selectbox for choosing a
target. Always defaults to the first alphabetical
target whenever the set of models changes.

Parameters
----------
observed_data_table : pl.DataFrame
Hubverse formatted observed data table (must
include "loc_abbr", "target"; possibly empty).
forecast_table : pl.DataFrame
Hubverse formatted forecast table (must include
"loc_abbr", "model_id", "target"; possibly empty).
loc_abbr : str
The selection location, typically a US
jurisdiction.
selected_models : list[str]
Models currently selected.

Returns
-------
str | None
The selected target, or None if there are no
targets.
"""
forecast_targets = (
forecast_table.filter(
pl.col("loc_abbr") == loc_abbr,
Expand All @@ -147,13 +196,16 @@ def model_and_target_selection_ui(
.sort()
.to_list()
)
suffix = "_".join(selected_models) if selected_models else "none"
target_selection_key = f"target_selection_{suffix}"
all_targets = sorted(set(forecast_targets + observed_data_targets))
selected_target = st.selectbox(
"Target",
options=all_targets,
key="target_selection",
index=0,
key=target_selection_key,
)
return selected_models, selected_target
return selected_target


@st.cache_data
Expand Down Expand Up @@ -212,12 +264,13 @@ def get_reference_dates(forecast_table: pl.DataFrame) -> list[datetime.date]:
return forecast_table.get_column("reference_date").unique().to_list()


def location_and_reference_data_ui(
observed_data_table: pl.DataFrame, forecast_table: pl.DataFrame
) -> tuple[str, datetime.date]:
def location_selection_ui(
observed_data_table: pl.DataFrame,
forecast_table: pl.DataFrame,
) -> str:
"""
Streamlit widget for the reference date and location
selection.
Streamlit widget for the selection of a location (two
letter abbreviation for US jurisdiction).

Parameters
----------
Expand All @@ -229,20 +282,20 @@ def location_and_reference_data_ui(

Returns
-------
tuple
Returns a tuple of the two letter location
abbreviation and the selected reference date.
str
The selected two-letter location abbreviation.

"""
loc_lookup = get_available_locations(observed_data_table, forecast_table)
if "locations_list" not in st.session_state:
st.session_state.locations_list = (
loc_lookup.get_column("long_name").sort().to_list()
)

def go_to_prev_loc():
def _go_to_prev_loc():
st.session_state.current_loc_id -= 1

def go_to_next_loc():
def _go_to_next_loc():
st.session_state.current_loc_id += 1

location_col, prev_col, next_col = st.columns(
Expand All @@ -259,7 +312,7 @@ def go_to_next_loc():
st.button(
"⏮️",
disabled=(st.session_state.current_loc_id == 0),
on_click=go_to_prev_loc,
on_click=_go_to_prev_loc,
key="prev_button",
use_container_width=True,
)
Expand All @@ -270,7 +323,7 @@ def go_to_next_loc():
st.session_state.current_loc_id
== len(st.session_state.locations_list) - 1
),
on_click=go_to_next_loc,
on_click=_go_to_next_loc,
key="next_button",
use_container_width=True,
)
Expand All @@ -281,14 +334,41 @@ def go_to_next_loc():
.get_column("short_name")
.item()
)
return loc_abbr


def reference_date_selection_ui(
    forecast_table: pl.DataFrame,
) -> datetime.date | None:
    """
    Streamlit widget for the selection of forecast
    reference date.

    The available dates are listed newest first. The
    selection is stored in ``st.session_state`` under the
    key "ref_date_selection". It is (re)set to the newest
    date when nothing has been stored yet or when the
    stored date is not among the available dates.

    Parameters
    ----------
    forecast_table : pl.DataFrame
        The hubverse formatted table of forecasted ED
        visits and or hospital admissions (possibly empty).

    Returns
    -------
    datetime.date | None
        The selected reference date, or None if no dates
        are available.
    """
    ref_dates = sorted(get_reference_dates(forecast_table), reverse=True)
    if not ref_dates:
        st.info("Upload a forecast file to select a reference date.")
        return None
    # ref_dates is non-empty from here on. Fall back to the newest date when
    # there is no stored selection or the stored one is no longer an option.
    if (
        "ref_date_selection" not in st.session_state
        or st.session_state.ref_date_selection not in ref_dates
    ):
        st.session_state.ref_date_selection = ref_dates[0]
    selected_ref_date = st.selectbox(
        "Reference Date",
        options=ref_dates,
        format_func=lambda d: d.strftime("%Y-%m-%d"),
        key="ref_date_selection",
    )
    return selected_ref_date


def is_empty_chart(chart: alt.LayerChart) -> bool:
Expand Down Expand Up @@ -523,6 +603,55 @@ def plotting_ui(
base_chart.altair_chart(chart, use_container_width=False, key=chart_key)


def filter_for_plotting(
    observed_data_table: pl.DataFrame,
    forecast_table: pl.DataFrame,
    selected_models: list[str],
    selected_target: str | None,
    selected_ref_date: datetime.date,
    loc_abbr: str,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Narrow the observed and forecast tables down to the
    rows needed for plotting.

    Parameters
    ----------
    observed_data_table : pl.DataFrame
        A hubverse table of loaded data (possibly empty).
    forecast_table : pl.DataFrame
        The hubverse formatted table of forecasted ED
        visits and or hospital admissions (possibly empty).
    selected_models : list[str]
        Selected models to annotate.
    selected_target : str | None
        The target for filtering in the forecast and or
        observed hubverse tables.
    selected_ref_date : datetime.date
        The selected reference date.
    loc_abbr : str
        The abbreviated US jurisdiction name.

    Returns
    -------
    tuple
        A tuple of the observed data (pl.DataFrame),
        filtered by location and target, and the forecasts
        (pl.DataFrame), filtered by location, target,
        model, and reference date.
    """
    # Both tables are restricted to the same location and target.
    shared_predicates = [
        pl.col("loc_abbr") == loc_abbr,
        pl.col("target") == selected_target,
    ]
    observed_subset = observed_data_table.filter(*shared_predicates)
    # Forecasts are additionally restricted by model and reference date.
    forecast_subset = forecast_table.filter(
        *shared_predicates,
        pl.col("model_id").is_in(selected_models),
        pl.col("reference_date") == selected_ref_date,
    )
    return observed_subset, forecast_subset


def validate_schema(
df: pl.DataFrame,
expected_schema: dict[str, pl.DataType],
Expand Down Expand Up @@ -743,55 +872,6 @@ def load_data_ui() -> tuple[pl.DataFrame, pl.DataFrame]:
return observed_data_table, forecast_table


def filter_for_plotting(
    observed_data_table: pl.DataFrame,
    forecast_table: pl.DataFrame,
    selected_models: list[str],
    selected_target: str | None,
    selected_ref_date: datetime.date,
    loc_abbr: str,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Narrow the observed and forecast tables down to the
    rows needed for plotting.

    Parameters
    ----------
    observed_data_table : pl.DataFrame
        A hubverse table of loaded data (possibly empty).
    forecast_table : pl.DataFrame
        The hubverse formatted table of forecasted ED
        visits and or hospital admissions (possibly empty).
    selected_models : list[str]
        Selected models to annotate.
    selected_target : str | None
        The target for filtering in the forecast and or
        observed hubverse tables.
    selected_ref_date : datetime.date
        The selected reference date.
    loc_abbr : str
        The abbreviated US jurisdiction name.

    Returns
    -------
    tuple
        A tuple of the observed data (pl.DataFrame),
        filtered by location and target, and the forecasts
        (pl.DataFrame), filtered by location, target,
        model, and reference date.
    """
    # Both tables are restricted to the same location and target.
    shared_predicates = [
        pl.col("loc_abbr") == loc_abbr,
        pl.col("target") == selected_target,
    ]
    observed_subset = observed_data_table.filter(*shared_predicates)
    # Forecasts are additionally restricted by model and reference date.
    forecast_subset = forecast_table.filter(
        *shared_predicates,
        pl.col("model_id").is_in(selected_models),
        pl.col("reference_date") == selected_ref_date,
    )
    return observed_subset, forecast_subset


def main() -> None:
# record session start time
start_time = time.time()
Expand All @@ -805,11 +885,14 @@ def main() -> None:
"Please upload Observed Data or Hubverse Forecasts to begin."
)
return None
loc_abbr, selected_ref_date = location_and_reference_data_ui(
observed_data_table, forecast_table
)
selected_models, selected_target = model_and_target_selection_ui(
observed_data_table, forecast_table, loc_abbr
loc_abbr = location_selection_ui(observed_data_table, forecast_table)
selected_ref_date = reference_date_selection_ui(forecast_table)
selected_models = model_selection_ui(forecast_table, loc_abbr)
selected_target = target_selection_ui(
observed_data_table,
forecast_table,
loc_abbr,
selected_models,
)
scale = "log" if st.checkbox("Log-scale", value=True) else "linear"
grid = st.checkbox("Gridlines", value=True)
Expand Down