Skip to content

Commit bfb129b

Browse files
authored
Default Plot, Targets, And Selections (#92)
* all or none buttons * move filter for plotting location * separate funcs for model and target selection * internal use indicators * unique keys for different models selections * minor correction to target key * cut back obs data if forecast file present * default for ref date * split location and reference ui into two * wait with model and ref date selection until forecast file uploaded * docstring edit * remove cutting feature
1 parent d50884b commit bfb129b

File tree

1 file changed

+165
-82
lines changed

1 file changed

+165
-82
lines changed

hubverse_annotator/app.py

Lines changed: 165 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -93,43 +93,92 @@ def forecast_annotation_ui(
9393
export_button()
9494

9595

96-
def model_and_target_selection_ui(
97-
observed_data_table: pl.DataFrame,
96+
def model_selection_ui(
9897
forecast_table: pl.DataFrame,
9998
loc_abbr: str,
100-
) -> tuple[list[str], str | None]:
99+
) -> list[str]:
101100
"""
102-
Streamlit widget for model and target selection.
101+
Renders a Streamlit multiselect widget for choosing
102+
models, with "All" and "None" buttons. Defaults to all
103+
models being selected.
103104
104105
Parameters
105106
----------
106-
observed_data_table : pl.DataFrame
107-
A hubverse table of loaded data (possibly empty).
108107
forecast_table : pl.DataFrame
109-
The hubverse formatted table of forecasted ED
110-
visits and or hospital admissions (possibly empty).
108+
Hubverse-formatted forecasts (must include
109+
"loc_abbr" and "model_id" columns; possibly empty).
111110
loc_abbr : str
112-
The selection location, typically a US jurisdiction.
111+
The selection location, typically a US
112+
jurisdiction.
113113
114114
Returns
115115
-------
116-
tuple
117-
Returns a list of selected model names and the
118-
selected target.
116+
list[str]
117+
The list of currently selected model_ids.
119118
"""
119+
120120
models = (
121121
forecast_table.filter(pl.col("loc_abbr") == loc_abbr)
122122
.get_column("model_id")
123123
.unique()
124124
.sort()
125125
.to_list()
126126
)
127+
if not models:
128+
st.info("Upload forecasts for this location to pick models.")
129+
return []
130+
if "model_selection" not in st.session_state:
131+
st.session_state.model_selection = models.copy()
132+
133+
def _select_all():
134+
st.session_state.model_selection = models.copy()
135+
136+
def _select_none():
137+
st.session_state.model_selection = []
138+
139+
all_button, none_button = st.columns(2)
140+
all_button.button("All", on_click=_select_all)
141+
none_button.button("None", on_click=_select_none)
127142
selected_models = st.multiselect(
128143
"Model(s)",
129144
options=models,
130-
default=None,
131145
key="model_selection",
132146
)
147+
148+
return selected_models
149+
150+
151+
def target_selection_ui(
152+
observed_data_table: pl.DataFrame,
153+
forecast_table: pl.DataFrame,
154+
loc_abbr: str,
155+
selected_models: list[str],
156+
) -> str | None:
157+
"""
158+
Renders a Streamlit selectbox for choosing a
159+
target. Always defaults to the first alphabetical
160+
target whenever the set of models changes.
161+
162+
Parameters
163+
----------
164+
observed_data_table : pl.DataFrame
165+
Hubverse formatted observed data table (must
166+
include "loc_abbr", "target"; possibly empty).
167+
forecast_table : pl.DataFrame
168+
Hubverse formatted forecast table (must include
169+
"loc_abbr", "model_id", "target"; possibly empty).
170+
loc_abbr : str
171+
The selection location, typically a US
172+
jurisdiction.
173+
selected_models : list[str]
174+
Models currently selected.
175+
176+
Returns
177+
-------
178+
str | None
179+
The selected target, or None if there are no
180+
targets.
181+
"""
133182
forecast_targets = (
134183
forecast_table.filter(
135184
pl.col("loc_abbr") == loc_abbr,
@@ -147,13 +196,16 @@ def model_and_target_selection_ui(
147196
.sort()
148197
.to_list()
149198
)
199+
suffix = "_".join(selected_models) if selected_models else "none"
200+
target_selection_key = f"target_selection_{suffix}"
150201
all_targets = sorted(set(forecast_targets + observed_data_targets))
151202
selected_target = st.selectbox(
152203
"Target",
153204
options=all_targets,
154-
key="target_selection",
205+
index=0,
206+
key=target_selection_key,
155207
)
156-
return selected_models, selected_target
208+
return selected_target
157209

158210

159211
@st.cache_data
@@ -212,12 +264,13 @@ def get_reference_dates(forecast_table: pl.DataFrame) -> list[datetime.date]:
212264
return forecast_table.get_column("reference_date").unique().to_list()
213265

214266

215-
def location_and_reference_data_ui(
216-
observed_data_table: pl.DataFrame, forecast_table: pl.DataFrame
217-
) -> tuple[str, datetime.date]:
267+
def location_selection_ui(
268+
observed_data_table: pl.DataFrame,
269+
forecast_table: pl.DataFrame,
270+
) -> str:
218271
"""
219-
Streamlit widget for the reference date and location
220-
selection.
272+
Streamlit widget for the selection of a location (two
273+
letter abbreviation for US jurisdiction).
221274
222275
Parameters
223276
----------
@@ -229,20 +282,20 @@ def location_and_reference_data_ui(
229282
230283
Returns
231284
-------
232-
tuple
233-
Returns a tuple of the two letter location
234-
abbreviation and the selected reference date.
285+
str
286+
The selected two-letter location abbreviation.
287+
235288
"""
236289
loc_lookup = get_available_locations(observed_data_table, forecast_table)
237290
if "locations_list" not in st.session_state:
238291
st.session_state.locations_list = (
239292
loc_lookup.get_column("long_name").sort().to_list()
240293
)
241294

242-
def go_to_prev_loc():
295+
def _go_to_prev_loc():
243296
st.session_state.current_loc_id -= 1
244297

245-
def go_to_next_loc():
298+
def _go_to_next_loc():
246299
st.session_state.current_loc_id += 1
247300

248301
location_col, prev_col, next_col = st.columns(
@@ -259,7 +312,7 @@ def go_to_next_loc():
259312
st.button(
260313
"⏮️",
261314
disabled=(st.session_state.current_loc_id == 0),
262-
on_click=go_to_prev_loc,
315+
on_click=_go_to_prev_loc,
263316
key="prev_button",
264317
use_container_width=True,
265318
)
@@ -270,7 +323,7 @@ def go_to_next_loc():
270323
st.session_state.current_loc_id
271324
== len(st.session_state.locations_list) - 1
272325
),
273-
on_click=go_to_next_loc,
326+
on_click=_go_to_next_loc,
274327
key="next_button",
275328
use_container_width=True,
276329
)
@@ -281,14 +334,41 @@ def go_to_next_loc():
281334
.get_column("short_name")
282335
.item()
283336
)
337+
return loc_abbr
338+
339+
340+
def reference_date_selection_ui(
341+
forecast_table: pl.DataFrame,
342+
) -> datetime.date | None:
343+
"""
344+
Streamlit widget for the selection of forecast
345+
reference date.
346+
347+
Parameters
348+
----------
349+
forecast_table : pl.DataFrame
350+
The hubverse formatted table of forecasted ED
351+
visits and or hospital admissions (possibly empty).
352+
353+
Returns
354+
-------
355+
datetime.date | None
356+
The selected reference date, or None if no dates
357+
are available.
358+
"""
284359
ref_dates = sorted(get_reference_dates(forecast_table), reverse=True)
360+
if not ref_dates:
361+
st.info("Upload a forecast file to select a reference date.")
362+
return None
363+
if ref_dates and "ref_date_selection" not in st.session_state:
364+
st.session_state.ref_date_selection = ref_dates[0]
285365
selected_ref_date = st.selectbox(
286366
"Reference Date",
287367
options=ref_dates,
288368
format_func=lambda d: d.strftime("%Y-%m-%d"),
289369
key="ref_date_selection",
290370
)
291-
return loc_abbr, selected_ref_date
371+
return selected_ref_date
292372

293373

294374
def is_empty_chart(chart: alt.LayerChart) -> bool:
@@ -523,6 +603,55 @@ def plotting_ui(
523603
base_chart.altair_chart(chart, use_container_width=False, key=chart_key)
524604

525605

606+
def filter_for_plotting(
607+
observed_data_table: pl.DataFrame,
608+
forecast_table: pl.DataFrame,
609+
selected_models: list[str],
610+
selected_target: str | None,
611+
selected_ref_date: datetime.date,
612+
loc_abbr: str,
613+
) -> tuple[pl.DataFrame, pl.DataFrame]:
614+
"""
615+
Filter forecast and observed data tables for the
616+
selected models and target.
617+
618+
Parameters
619+
----------
620+
observed_data_table : pl.DataFrame
621+
A hubverse table of loaded data (possibly empty).
622+
forecast_table : pl.DataFrame
623+
The hubverse formatted table of forecasted ED
624+
visits and or hospital admissions (possibly empty).
625+
selected_models : list[str]
626+
Selected models to annotate.
627+
selected_target : str
628+
The target for filtering in the forecast and or
629+
observed hubverse tables.
630+
selected_ref_date : datetime.date
631+
The selected reference date.
632+
loc_abbr
633+
The abbreviated US jurisdiction abbreviation.
634+
635+
Returns
636+
-------
637+
tuple
638+
A tuple of observed_data_table (pl.DataFrame) and
639+
forecast_table (pl.DataFrame) filtered by model,
640+
target, and location, to be used for plotting.
641+
"""
642+
data_to_plot = observed_data_table.filter(
643+
pl.col("loc_abbr") == loc_abbr,
644+
pl.col("target") == selected_target,
645+
)
646+
forecasts_to_plot = forecast_table.filter(
647+
pl.col("loc_abbr") == loc_abbr,
648+
pl.col("target") == selected_target,
649+
pl.col("model_id").is_in(selected_models),
650+
pl.col("reference_date") == selected_ref_date,
651+
)
652+
return data_to_plot, forecasts_to_plot
653+
654+
526655
def validate_schema(
527656
df: pl.DataFrame,
528657
expected_schema: dict[str, pl.DataType],
@@ -743,55 +872,6 @@ def load_data_ui() -> tuple[pl.DataFrame, pl.DataFrame]:
743872
return observed_data_table, forecast_table
744873

745874

746-
def filter_for_plotting(
747-
observed_data_table: pl.DataFrame,
748-
forecast_table: pl.DataFrame,
749-
selected_models: list[str],
750-
selected_target: str | None,
751-
selected_ref_date: datetime.date,
752-
loc_abbr: str,
753-
) -> tuple[pl.DataFrame, pl.DataFrame]:
754-
"""
755-
Filter forecast and observed data tables for the
756-
selected models and target.
757-
758-
Parameters
759-
----------
760-
observed_data_table : pl.DataFrame
761-
A hubverse table of loaded data (possibly empty).
762-
forecast_table : pl.DataFrame
763-
The hubverse formatted table of forecasted ED
764-
visits and or hospital admissions (possibly empty).
765-
selected_models : list[str]
766-
Selected models to annotate.
767-
selected_target : str
768-
The target for filtering in the forecast and or
769-
observed hubverse tables.
770-
selected_ref_date : datetime.date
771-
The selected reference date.
772-
loc_abbr
773-
The abbreviated US jurisdiction abbreviation.
774-
775-
Returns
776-
-------
777-
tuple
778-
A tuple of observed_data_table (pl.DataFrame) and
779-
forecast_table (pl.DataFrame) filtered by model,
780-
target, and location, to be used for plotting.
781-
"""
782-
data_to_plot = observed_data_table.filter(
783-
pl.col("loc_abbr") == loc_abbr,
784-
pl.col("target") == selected_target,
785-
)
786-
forecasts_to_plot = forecast_table.filter(
787-
pl.col("loc_abbr") == loc_abbr,
788-
pl.col("target") == selected_target,
789-
pl.col("model_id").is_in(selected_models),
790-
pl.col("reference_date") == selected_ref_date,
791-
)
792-
return data_to_plot, forecasts_to_plot
793-
794-
795875
def main() -> None:
796876
# record session start time
797877
start_time = time.time()
@@ -805,11 +885,14 @@ def main() -> None:
805885
"Please upload Observed Data or Hubverse Forecasts to begin."
806886
)
807887
return None
808-
loc_abbr, selected_ref_date = location_and_reference_data_ui(
809-
observed_data_table, forecast_table
810-
)
811-
selected_models, selected_target = model_and_target_selection_ui(
812-
observed_data_table, forecast_table, loc_abbr
888+
loc_abbr = location_selection_ui(observed_data_table, forecast_table)
889+
selected_ref_date = reference_date_selection_ui(forecast_table)
890+
selected_models = model_selection_ui(forecast_table, loc_abbr)
891+
selected_target = target_selection_ui(
892+
observed_data_table,
893+
forecast_table,
894+
loc_abbr,
895+
selected_models,
813896
)
814897
scale = "log" if st.checkbox("Log-scale", value=True) else "linear"
815898
grid = st.checkbox("Gridlines", value=True)

0 commit comments

Comments
 (0)