diff --git a/hubverse_annotator/app.py b/hubverse_annotator/app.py index 86a08ce7..57f80774 100644 --- a/hubverse_annotator/app.py +++ b/hubverse_annotator/app.py @@ -93,30 +93,30 @@ def forecast_annotation_ui( export_button() -def model_and_target_selection_ui( - observed_data_table: pl.DataFrame, +def model_selection_ui( forecast_table: pl.DataFrame, loc_abbr: str, -) -> tuple[list[str], str | None]: +) -> list[str]: """ - Streamlit widget for model and target selection. + Renders a Streamlit multiselect widget for choosing + models, with "All" and "None" buttons. Defaults to all + models being selected. Parameters ---------- - observed_data_table : pl.DataFrame - A hubverse table of loaded data (possibly empty). forecast_table : pl.DataFrame - The hubverse formatted table of forecasted ED - visits and or hospital admissions (possibly empty). + Hubverse-formatted forecasts (must include + "loc_abbr" and "model_id" columns; possibly empty). loc_abbr : str - The selection location, typically a US jurisdiction. + The selection location, typically a US + jurisdiction. Returns ------- - tuple - Returns a list of selected model names and the - selected target. + list[str] + The list of currently selected model_ids. """ + models = ( forecast_table.filter(pl.col("loc_abbr") == loc_abbr) .get_column("model_id") @@ -124,12 +124,61 @@ def model_and_target_selection_ui( .sort() .to_list() ) + if not models: + st.info("Upload forecasts for this location to pick models.") + return [] + if "model_selection" not in st.session_state: + st.session_state.model_selection = models.copy() + + def _select_all(): + st.session_state.model_selection = models.copy() + + def _select_none(): + st.session_state.model_selection = [] + + all_button, none_button = st.columns(2) + all_button.button("All", on_click=_select_all) + none_button.button("None", on_click=_select_none) selected_models = st.multiselect( "Model(s)", options=models, - default=None, key="model_selection", ) + + return selected_models + + +def target_selection_ui( + observed_data_table: pl.DataFrame, + forecast_table: pl.DataFrame, + loc_abbr: str, + selected_models: list[str], +) -> str | None: + """ + Renders a Streamlit selectbox for choosing a + target. Always defaults to the first alphabetical + target whenever the set of models changes. + + Parameters + ---------- + observed_data_table : pl.DataFrame + Hubverse formatted observed data table (must + include "loc_abbr", "target"; possibly empty). + forecast_table : pl.DataFrame + Hubverse formatted forecast table (must include + "loc_abbr", "model_id", "target"; possibly empty). + loc_abbr : str + The selection location, typically a US + jurisdiction. + selected_models : list[str] + Models currently selected. + + Returns + ------- + str | None + The selected target, or None if there are no + targets. + """ forecast_targets = ( forecast_table.filter( pl.col("loc_abbr") == loc_abbr, @@ -147,13 +196,16 @@ def model_and_target_selection_ui( .sort() .to_list() ) + suffix = "_".join(selected_models) if selected_models else "none" + target_selection_key = f"target_selection_{suffix}" all_targets = sorted(set(forecast_targets + observed_data_targets)) selected_target = st.selectbox( "Target", options=all_targets, - key="target_selection", + index=0, + key=target_selection_key, ) - return selected_models, selected_target + return selected_target @st.cache_data @@ -212,12 +264,13 @@ def get_reference_dates(forecast_table: pl.DataFrame) -> list[datetime.date]: return forecast_table.get_column("reference_date").unique().to_list() -def location_and_reference_data_ui( - observed_data_table: pl.DataFrame, forecast_table: pl.DataFrame -) -> tuple[str, datetime.date]: +def location_selection_ui( + observed_data_table: pl.DataFrame, + forecast_table: pl.DataFrame, +) -> str: """ - Streamlit widget for the reference date and location - selection. + Streamlit widget for the selection of a location (two + letter abbreviation for US jurisdiction). Parameters ---------- @@ -229,9 +282,9 @@ def location_and_reference_data_ui( Returns ------- - tuple - Returns a tuple of the two letter location - abbreviation and the selected reference date. + str + The selected two-letter location abbreviation. + """ loc_lookup = get_available_locations(observed_data_table, forecast_table) if "locations_list" not in st.session_state: @@ -239,10 +292,10 @@ def location_and_reference_data_ui( loc_lookup.get_column("long_name").sort().to_list() ) - def go_to_prev_loc(): + def _go_to_prev_loc(): st.session_state.current_loc_id -= 1 - def go_to_next_loc(): + def _go_to_next_loc(): st.session_state.current_loc_id += 1 location_col, prev_col, next_col = st.columns( @@ -259,7 +312,7 @@ def go_to_next_loc(): st.button( "⏮️", disabled=(st.session_state.current_loc_id == 0), - on_click=go_to_prev_loc, + on_click=_go_to_prev_loc, key="prev_button", use_container_width=True, ) @@ -270,7 +323,7 @@ def go_to_next_loc(): st.session_state.current_loc_id == len(st.session_state.locations_list) - 1 ), - on_click=go_to_next_loc, + on_click=_go_to_next_loc, key="next_button", use_container_width=True, ) @@ -281,14 +334,41 @@ def go_to_next_loc(): .get_column("short_name") .item() ) + return loc_abbr + + +def reference_date_selection_ui( + forecast_table: pl.DataFrame, +) -> datetime.date | None: + """ + Streamlit widget for the selection of forecast + reference date. + + Parameters + ---------- + forecast_table : pl.DataFrame + The hubverse formatted table of forecasted ED + visits and or hospital admissions (possibly empty). + + Returns + ------- + datetime.date | None + The selected reference date, or None if no dates + are available. + """ ref_dates = sorted(get_reference_dates(forecast_table), reverse=True) + if not ref_dates: + st.info("Upload a forecast file to select a reference date.") + return None + if ref_dates and "ref_date_selection" not in st.session_state: + st.session_state.ref_date_selection = ref_dates[0] selected_ref_date = st.selectbox( "Reference Date", options=ref_dates, format_func=lambda d: d.strftime("%Y-%m-%d"), key="ref_date_selection", ) - return loc_abbr, selected_ref_date + return selected_ref_date def is_empty_chart(chart: alt.LayerChart) -> bool: @@ -523,6 +603,55 @@ def plotting_ui( base_chart.altair_chart(chart, use_container_width=False, key=chart_key) +def filter_for_plotting( + observed_data_table: pl.DataFrame, + forecast_table: pl.DataFrame, + selected_models: list[str], + selected_target: str | None, + selected_ref_date: datetime.date, + loc_abbr: str, +) -> tuple[pl.DataFrame, pl.DataFrame]: + """ + Filter forecast and observed data tables for the + selected models and target. + + Parameters + ---------- + observed_data_table : pl.DataFrame + A hubverse table of loaded data (possibly empty). + forecast_table : pl.DataFrame + The hubverse formatted table of forecasted ED + visits and or hospital admissions (possibly empty). + selected_models : list[str] + Selected models to annotate. + selected_target : str + The target for filtering in the forecast and or + observed hubverse tables. + selected_ref_date : datetime.date + The selected reference date. + loc_abbr + The abbreviated US jurisdiction abbreviation. + + Returns + ------- + tuple + A tuple of observed_data_table (pl.DataFrame) and + forecast_table (pl.DataFrame) filtered by model, + target, and location, to be used for plotting. + """ + data_to_plot = observed_data_table.filter( + pl.col("loc_abbr") == loc_abbr, + pl.col("target") == selected_target, + ) + forecasts_to_plot = forecast_table.filter( + pl.col("loc_abbr") == loc_abbr, + pl.col("target") == selected_target, + pl.col("model_id").is_in(selected_models), + pl.col("reference_date") == selected_ref_date, + ) + return data_to_plot, forecasts_to_plot + + def validate_schema( df: pl.DataFrame, expected_schema: dict[str, pl.DataType], @@ -743,55 +872,6 @@ def load_data_ui() -> tuple[pl.DataFrame, pl.DataFrame]: return observed_data_table, forecast_table -def filter_for_plotting( - observed_data_table: pl.DataFrame, - forecast_table: pl.DataFrame, - selected_models: list[str], - selected_target: str | None, - selected_ref_date: datetime.date, - loc_abbr: str, -) -> tuple[pl.DataFrame, pl.DataFrame]: - """ - Filter forecast and observed data tables for the - selected models and target. - - Parameters - ---------- - observed_data_table : pl.DataFrame - A hubverse table of loaded data (possibly empty). - forecast_table : pl.DataFrame - The hubverse formatted table of forecasted ED - visits and or hospital admissions (possibly empty). - selected_models : list[str] - Selected models to annotate. - selected_target : str - The target for filtering in the forecast and or - observed hubverse tables. - selected_ref_date : datetime.date - The selected reference date. - loc_abbr - The abbreviated US jurisdiction abbreviation. - - Returns - ------- - tuple - A tuple of observed_data_table (pl.DataFrame) and - forecast_table (pl.DataFrame) filtered by model, - target, and location, to be used for plotting. - """ - data_to_plot = observed_data_table.filter( - pl.col("loc_abbr") == loc_abbr, - pl.col("target") == selected_target, - ) - forecasts_to_plot = forecast_table.filter( - pl.col("loc_abbr") == loc_abbr, - pl.col("target") == selected_target, - pl.col("model_id").is_in(selected_models), - pl.col("reference_date") == selected_ref_date, - ) - return data_to_plot, forecasts_to_plot - - def main() -> None: # record session start time start_time = time.time() @@ -805,11 +885,14 @@ def main() -> None: "Please upload Observed Data or Hubverse Forecasts to begin." ) return None - loc_abbr, selected_ref_date = location_and_reference_data_ui( - observed_data_table, forecast_table - ) - selected_models, selected_target = model_and_target_selection_ui( - observed_data_table, forecast_table, loc_abbr + loc_abbr = location_selection_ui(observed_data_table, forecast_table) + selected_ref_date = reference_date_selection_ui(forecast_table) + selected_models = model_selection_ui(forecast_table, loc_abbr) + selected_target = target_selection_ui( + observed_data_table, + forecast_table, + loc_abbr, + selected_models, ) scale = "log" if st.checkbox("Log-scale", value=True) else "linear" grid = st.checkbox("Gridlines", value=True)