@@ -93,43 +93,92 @@ def forecast_annotation_ui(
9393 export_button ()
9494
9595
96- def model_and_target_selection_ui (
97- observed_data_table : pl .DataFrame ,
96+ def model_selection_ui (
9897 forecast_table : pl .DataFrame ,
9998 loc_abbr : str ,
100- ) -> tuple [ list [str ], str | None ]:
99+ ) -> list [str ]:
101100 """
102- Streamlit widget for model and target selection.
101+ Renders a Streamlit multiselect widget for choosing
102+ models, with "All" and "None" buttons. Defaults to all
103+ models being selected.
103104
104105 Parameters
105106 ----------
106- observed_data_table : pl.DataFrame
107- A hubverse table of loaded data (possibly empty).
108107 forecast_table : pl.DataFrame
109- The hubverse formatted table of forecasted ED
110- visits and or hospital admissions ( possibly empty).
108+ Hubverse- formatted forecasts (must include
109+ "loc_abbr" and "model_id" columns; possibly empty).
111110 loc_abbr : str
112- The selection location, typically a US jurisdiction.
111+ The selection location, typically a US
112+ jurisdiction.
113113
114114 Returns
115115 -------
116- tuple
117- Returns a list of selected model names and the
118- selected target.
116+ list[str]
117+ The list of currently selected model_ids.
119118 """
119+
120120 models = (
121121 forecast_table .filter (pl .col ("loc_abbr" ) == loc_abbr )
122122 .get_column ("model_id" )
123123 .unique ()
124124 .sort ()
125125 .to_list ()
126126 )
127+ if not models :
128+ st .info ("Upload forecasts for this location to pick models." )
129+ return []
130+ if "model_selection" not in st .session_state :
131+ st .session_state .model_selection = models .copy ()
132+
133+ def _select_all ():
134+ st .session_state .model_selection = models .copy ()
135+
136+ def _select_none ():
137+ st .session_state .model_selection = []
138+
139+ all_button , none_button = st .columns (2 )
140+ all_button .button ("All" , on_click = _select_all )
141+ none_button .button ("None" , on_click = _select_none )
127142 selected_models = st .multiselect (
128143 "Model(s)" ,
129144 options = models ,
130- default = None ,
131145 key = "model_selection" ,
132146 )
147+
148+ return selected_models
149+
150+
151+ def target_selection_ui (
152+ observed_data_table : pl .DataFrame ,
153+ forecast_table : pl .DataFrame ,
154+ loc_abbr : str ,
155+ selected_models : list [str ],
156+ ) -> str | None :
157+ """
158+ Renders a Streamlit selectbox for choosing a
159+ target. Always defaults to the first alphabetical
160+ target whenever the set of models changes.
161+
162+ Parameters
163+ ----------
164+ observed_data_table : pl.DataFrame
165+ Hubverse formatted observed data table (must
166+ include "loc_abbr", "target"; possibly empty).
167+ forecast_table : pl.DataFrame
168+ Hubverse formatted forecast table (must include
169+ "loc_abbr", "model_id", "target"; possibly empty).
170+ loc_abbr : str
171+ The selection location, typically a US
172+ jurisdiction.
173+ selected_models : list[str]
174+ Models currently selected.
175+
176+ Returns
177+ -------
178+ str | None
179+ The selected target, or None if there are no
180+ targets.
181+ """
133182 forecast_targets = (
134183 forecast_table .filter (
135184 pl .col ("loc_abbr" ) == loc_abbr ,
@@ -147,13 +196,16 @@ def model_and_target_selection_ui(
147196 .sort ()
148197 .to_list ()
149198 )
199+ suffix = "_" .join (selected_models ) if selected_models else "none"
200+ target_selection_key = f"target_selection_{ suffix } "
150201 all_targets = sorted (set (forecast_targets + observed_data_targets ))
151202 selected_target = st .selectbox (
152203 "Target" ,
153204 options = all_targets ,
154- key = "target_selection" ,
205+ index = 0 ,
206+ key = target_selection_key ,
155207 )
156- return selected_models , selected_target
208+ return selected_target
157209
158210
159211@st .cache_data
@@ -212,12 +264,13 @@ def get_reference_dates(forecast_table: pl.DataFrame) -> list[datetime.date]:
212264 return forecast_table .get_column ("reference_date" ).unique ().to_list ()
213265
214266
215- def location_and_reference_data_ui (
216- observed_data_table : pl .DataFrame , forecast_table : pl .DataFrame
217- ) -> tuple [str , datetime .date ]:
267+ def location_selection_ui (
268+ observed_data_table : pl .DataFrame ,
269+ forecast_table : pl .DataFrame ,
270+ ) -> str :
218271 """
219- Streamlit widget for the reference date and location
220- selection .
272+ Streamlit widget for the selection of a location (two
273+ letter abbreviation for US jurisdiction) .
221274
222275 Parameters
223276 ----------
@@ -229,20 +282,20 @@ def location_and_reference_data_ui(
229282
230283 Returns
231284 -------
232- tuple
233- Returns a tuple of the two letter location
234- abbreviation and the selected reference date.
285+ str
286+ The selected two- letter location abbreviation.
287+
235288 """
236289 loc_lookup = get_available_locations (observed_data_table , forecast_table )
237290 if "locations_list" not in st .session_state :
238291 st .session_state .locations_list = (
239292 loc_lookup .get_column ("long_name" ).sort ().to_list ()
240293 )
241294
242- def go_to_prev_loc ():
295+ def _go_to_prev_loc ():
243296 st .session_state .current_loc_id -= 1
244297
245- def go_to_next_loc ():
298+ def _go_to_next_loc ():
246299 st .session_state .current_loc_id += 1
247300
248301 location_col , prev_col , next_col = st .columns (
@@ -259,7 +312,7 @@ def go_to_next_loc():
259312 st .button (
260313 "⏮️" ,
261314 disabled = (st .session_state .current_loc_id == 0 ),
262- on_click = go_to_prev_loc ,
315+ on_click = _go_to_prev_loc ,
263316 key = "prev_button" ,
264317 use_container_width = True ,
265318 )
@@ -270,7 +323,7 @@ def go_to_next_loc():
270323 st .session_state .current_loc_id
271324 == len (st .session_state .locations_list ) - 1
272325 ),
273- on_click = go_to_next_loc ,
326+ on_click = _go_to_next_loc ,
274327 key = "next_button" ,
275328 use_container_width = True ,
276329 )
@@ -281,14 +334,41 @@ def go_to_next_loc():
281334 .get_column ("short_name" )
282335 .item ()
283336 )
337+ return loc_abbr
338+
339+
340+ def reference_date_selection_ui (
341+ forecast_table : pl .DataFrame ,
342+ ) -> datetime .date | None :
343+ """
344+ Streamlit widget for the selection of forecast
345+ reference date.
346+
347+ Parameters
348+ ----------
349+ forecast_table : pl.DataFrame
350+ The hubverse formatted table of forecasted ED
351+ visits and or hospital admissions (possibly empty).
352+
353+ Returns
354+ -------
355+ datetime.date | None
356+ The selected reference date, or None if no dates
357+ are available.
358+ """
284359 ref_dates = sorted (get_reference_dates (forecast_table ), reverse = True )
360+ if not ref_dates :
361+ st .info ("Upload a forecast file to select a reference date." )
362+ return None
363+ if ref_dates and "ref_date_selection" not in st .session_state :
364+ st .session_state .ref_date_selection = ref_dates [0 ]
285365 selected_ref_date = st .selectbox (
286366 "Reference Date" ,
287367 options = ref_dates ,
288368 format_func = lambda d : d .strftime ("%Y-%m-%d" ),
289369 key = "ref_date_selection" ,
290370 )
291- return loc_abbr , selected_ref_date
371+ return selected_ref_date
292372
293373
294374def is_empty_chart (chart : alt .LayerChart ) -> bool :
@@ -523,6 +603,55 @@ def plotting_ui(
523603 base_chart .altair_chart (chart , use_container_width = False , key = chart_key )
524604
525605
606+ def filter_for_plotting (
607+ observed_data_table : pl .DataFrame ,
608+ forecast_table : pl .DataFrame ,
609+ selected_models : list [str ],
610+ selected_target : str | None ,
611+ selected_ref_date : datetime .date ,
612+ loc_abbr : str ,
613+ ) -> tuple [pl .DataFrame , pl .DataFrame ]:
614+ """
615+ Filter forecast and observed data tables for the
616+ selected models and target.
617+
618+ Parameters
619+ ----------
620+ observed_data_table : pl.DataFrame
621+ A hubverse table of loaded data (possibly empty).
622+ forecast_table : pl.DataFrame
623+ The hubverse formatted table of forecasted ED
624+ visits and or hospital admissions (possibly empty).
625+ selected_models : list[str]
626+ Selected models to annotate.
627+ selected_target : str
628+ The target for filtering in the forecast and or
629+ observed hubverse tables.
630+ selected_ref_date : datetime.date
631+ The selected reference date.
632+ loc_abbr
633+ The abbreviated US jurisdiction abbreviation.
634+
635+ Returns
636+ -------
637+ tuple
638+ A tuple of observed_data_table (pl.DataFrame) and
639+ forecast_table (pl.DataFrame) filtered by model,
640+ target, and location, to be used for plotting.
641+ """
642+ data_to_plot = observed_data_table .filter (
643+ pl .col ("loc_abbr" ) == loc_abbr ,
644+ pl .col ("target" ) == selected_target ,
645+ )
646+ forecasts_to_plot = forecast_table .filter (
647+ pl .col ("loc_abbr" ) == loc_abbr ,
648+ pl .col ("target" ) == selected_target ,
649+ pl .col ("model_id" ).is_in (selected_models ),
650+ pl .col ("reference_date" ) == selected_ref_date ,
651+ )
652+ return data_to_plot , forecasts_to_plot
653+
654+
526655def validate_schema (
527656 df : pl .DataFrame ,
528657 expected_schema : dict [str , pl .DataType ],
@@ -743,55 +872,6 @@ def load_data_ui() -> tuple[pl.DataFrame, pl.DataFrame]:
743872 return observed_data_table , forecast_table
744873
745874
746- def filter_for_plotting (
747- observed_data_table : pl .DataFrame ,
748- forecast_table : pl .DataFrame ,
749- selected_models : list [str ],
750- selected_target : str | None ,
751- selected_ref_date : datetime .date ,
752- loc_abbr : str ,
753- ) -> tuple [pl .DataFrame , pl .DataFrame ]:
754- """
755- Filter forecast and observed data tables for the
756- selected models and target.
757-
758- Parameters
759- ----------
760- observed_data_table : pl.DataFrame
761- A hubverse table of loaded data (possibly empty).
762- forecast_table : pl.DataFrame
763- The hubverse formatted table of forecasted ED
764- visits and or hospital admissions (possibly empty).
765- selected_models : list[str]
766- Selected models to annotate.
767- selected_target : str
768- The target for filtering in the forecast and or
769- observed hubverse tables.
770- selected_ref_date : datetime.date
771- The selected reference date.
772- loc_abbr
773- The abbreviated US jurisdiction abbreviation.
774-
775- Returns
776- -------
777- tuple
778- A tuple of observed_data_table (pl.DataFrame) and
779- forecast_table (pl.DataFrame) filtered by model,
780- target, and location, to be used for plotting.
781- """
782- data_to_plot = observed_data_table .filter (
783- pl .col ("loc_abbr" ) == loc_abbr ,
784- pl .col ("target" ) == selected_target ,
785- )
786- forecasts_to_plot = forecast_table .filter (
787- pl .col ("loc_abbr" ) == loc_abbr ,
788- pl .col ("target" ) == selected_target ,
789- pl .col ("model_id" ).is_in (selected_models ),
790- pl .col ("reference_date" ) == selected_ref_date ,
791- )
792- return data_to_plot , forecasts_to_plot
793-
794-
795875def main () -> None :
796876 # record session start time
797877 start_time = time .time ()
@@ -805,11 +885,14 @@ def main() -> None:
805885 "Please upload Observed Data or Hubverse Forecasts to begin."
806886 )
807887 return None
808- loc_abbr , selected_ref_date = location_and_reference_data_ui (
809- observed_data_table , forecast_table
810- )
811- selected_models , selected_target = model_and_target_selection_ui (
812- observed_data_table , forecast_table , loc_abbr
888+ loc_abbr = location_selection_ui (observed_data_table , forecast_table )
889+ selected_ref_date = reference_date_selection_ui (forecast_table )
890+ selected_models = model_selection_ui (forecast_table , loc_abbr )
891+ selected_target = target_selection_ui (
892+ observed_data_table ,
893+ forecast_table ,
894+ loc_abbr ,
895+ selected_models ,
813896 )
814897 scale = "log" if st .checkbox ("Log-scale" , value = True ) else "linear"
815898 grid = st .checkbox ("Gridlines" , value = True )
0 commit comments