diff --git a/hubverse_annotator/ui.py b/hubverse_annotator/ui.py index a8b7858..3d0e47d 100644 --- a/hubverse_annotator/ui.py +++ b/hubverse_annotator/ui.py @@ -17,6 +17,7 @@ import streamlit as st from streamlit_shortcuts import add_shortcuts from utils import ( + build_ci_specs_from_df, get_available_locations, get_initial_window_range, get_reference_dates, @@ -31,6 +32,7 @@ CHART_TITLE_FONT_SIZE = 18 REF_DATE_STROKE_WIDTH = 2.5 REF_DATE_STROKE_DASH = [6, 6] +MARKER_SIZE = 65 ROOT = pathlib.Path(__file__).resolve().parent.parent @@ -406,11 +408,44 @@ def plotting_ui( # empty streamlit object (DeltaGenerator) needed for # plots to reload successfully with new data. base_chart = st.empty() - forecast_layer = quantile_forecast_chart( - forecasts_to_plot, selected_target, scale=scale, grid=show_grid - ) + + has_obs = not data_to_plot.is_empty() + has_fc = not forecasts_to_plot.is_empty() + ci_specs = build_ci_specs_from_df(forecasts_to_plot) if has_fc else {} + + legend_labels = ["Observations"] if has_obs else [] + color_range = ["limegreen"] if has_obs else [] + + if has_fc and ci_specs: + legend_labels.extend(ci_specs.keys()) + color_range.extend(["blue"] * len(ci_specs)) + + if len(legend_labels) > 1: + color_enc = alt.Color( + "legend_label:N", + title=None, + scale=alt.Scale(domain=legend_labels, scheme="blues"), + ) + else: + color_enc = alt.Color( + "legend_label:N", + title=None, + scale=alt.Scale(domain=legend_labels, range=color_range), + ) observed_layer = target_data_chart( - data_to_plot, selected_target, scale=scale, grid=show_grid + data_to_plot, + selected_target, + color_enc=color_enc, + scale=scale, + grid=show_grid, + ) + forecast_layer = quantile_forecast_chart( + forecasts_to_plot, + selected_target, + ci_specs, + color_enc=color_enc, + scale=scale, + grid=show_grid, ) sub_layers = [ layer for layer in [forecast_layer, observed_layer] if not is_empty_chart(layer) @@ -462,6 +497,13 @@ def plotting_ui( .interactive() .resolve_scale(y="independent") .resolve_axis(x="independent") + .configure_legend( + orient="top", + direction="horizontal", + symbolType="circle", + symbolSize=MARKER_SIZE, + titleAnchor="middle", + ) ) base_chart.altair_chart( chart, diff --git a/hubverse_annotator/utils.py b/hubverse_annotator/utils.py index 607f47c..51d2d29 100644 --- a/hubverse_annotator/utils.py +++ b/hubverse_annotator/utils.py @@ -18,11 +18,14 @@ import polars as pl import polars.selectors as cs import streamlit as st +from babel.numbers import format_percent from streamlit.runtime.uploaded_file_manager import UploadedFile PLOT_WIDTH = 625 STROKE_WIDTH = 2 -MARKER_SIZE = 55 +MARKER_SIZE = 65 +MAX_NUM_CIS = 7 + type ScaleType = Literal["linear", "log"] @@ -130,6 +133,65 @@ def get_initial_window_range( return (start_date, end_date) +def pivot_quantile_df(forecast_table: pl.DataFrame) -> pl.DataFrame: + """ + Converts long-format quantile forecast table into + wide format where each quantile is a column. + """ + return ( + forecast_table.filter(pl.col("output_type") == "quantile") + .pivot( + on="output_type_id", + index=cs.exclude("output_type_id", "value"), + values="value", + ) + .rename({"0.5": "median"}) + ) + + +def build_ci_specs_from_df( + forecast_table: pl.DataFrame, +) -> dict[str, dict[str, str]]: + """ + Automatically constructs a CI_SPECS-style dict for + altair legend creation by finding quantile columns in + a wide forecast table. + + Parameters + ---------- + forecast_table : pl.DataFrame + A Polars DataFrame of forecasts, with a quantile + forecast columns. + + Returns + ------- + dict[str, dict[str, str]] + Mapping from CI label (e.g. "95% CI") to its + bounds and color. + """ + df_wide = pivot_quantile_df(forecast_table) + quant_vals = sorted(float(col) for col in df_wide.columns if col.startswith("0.")) + ci_pairs = sorted( + [(q, 1 - q) for q in quant_vals if q < 0.5 and (1 - q) in quant_vals], + key=lambda p: p[1] - p[0], + reverse=True, + )[:MAX_NUM_CIS] + + labels = [ + f"{format_percent(high - low, locale='en_US', format='#,##0%')} CI" + for low, high in ci_pairs + ] + + specs = { + label: { + "low": f"{low:.3f}".rstrip("0").rstrip("."), + "high": f"{high:.3f}".rstrip("0").rstrip("."), + } + for (low, high), label in zip(ci_pairs, labels, strict=False) + } + return specs + + def is_empty_chart(chart: alt.LayerChart) -> bool: """ Checks if an altair layer is empty. Primarily used for @@ -168,6 +230,7 @@ def is_empty_chart(chart: alt.LayerChart) -> bool: def target_data_chart( observed_data_table: pl.DataFrame, selected_target: str, + color_enc: alt.Color, scale: ScaleType = "log", grid: bool = True, ) -> alt.Chart | alt.LayerChart: @@ -182,6 +245,9 @@ def target_data_chart( selected_target : str The target for filtering in the forecast and or observed hubverse tables. + color_enc : alt.Color + An Altair color encoding used for plotting the + observations color and legend. scale : str The scale to use for the Y axis during plotting. Defaults to logarithmic. @@ -199,7 +265,6 @@ def target_data_chart( x_enc = alt.X( "date:T", axis=alt.Axis(title="Date", grid=grid, ticks=True, labels=True), - scale=alt.Scale(type=scale), ) y_enc = alt.Y( "observation:Q", @@ -214,10 +279,15 @@ def target_data_chart( ) obs_layer = ( alt.Chart(observed_data_table, width=PLOT_WIDTH) - .mark_point(filled=True, size=MARKER_SIZE, color="limegreen") + .transform_calculate(legend_label="'Observations'") + .mark_point( + filled=True, + size=MARKER_SIZE, + ) .encode( x=x_enc, y=y_enc, + color=color_enc, tooltip=[ alt.Tooltip("date:T", title="Date"), alt.Tooltip("observation:Q", title="Value"), @@ -230,6 +300,8 @@ def target_data_chart( def quantile_forecast_chart( forecast_table: pl.DataFrame, selected_target: str, + ci_specs, + color_enc: alt.Color, scale: ScaleType = "log", grid: bool = True, ) -> alt.LayerChart: @@ -246,6 +318,9 @@ def quantile_forecast_chart( selected_target : str The target for filtering in the forecast and or observed hubverse tables. + color_enc : alt.Color + An Altair color encoding used for plotting the + quantile bands color and legend. scale : str The scale to use for the Y axis during plotting. Defaults to logarithmic. @@ -260,24 +335,26 @@ def quantile_forecast_chart( """ if forecast_table.is_empty(): return alt.layer() - df_wide = ( - forecast_table.filter(pl.col("output_type") == "quantile") - .pivot( - on="output_type_id", - index=cs.exclude("output_type_id", "value"), - values="value", - ) - .rename({"0.5": "median"}) - ) + df_wide = pivot_quantile_df(forecast_table) x_enc = alt.X("target_end_date:T", title="Date", axis=alt.Axis(grid=grid)) y_enc = alt.Y( "median:Q", axis=alt.Axis(grid=grid), scale=alt.Scale(type=scale), ) - base = alt.Chart(df_wide, width=PLOT_WIDTH).encode(x=x_enc, y=y_enc) + base = ( + alt.Chart(df_wide, width=PLOT_WIDTH) + .transform_calculate( + date="toDate(datum.target_end_date)", + data_type="'Forecast'", + ) + .encode( + x=x_enc, + y=y_enc, + ) + ) - def band(low: str, high: str, opacity: float) -> alt.Chart: + def band(low: str, high: str, label: str) -> alt.Chart: """ Builds an errorband layer for a quantile. @@ -289,9 +366,9 @@ def band(low: str, high: str, opacity: float) -> alt.Chart: high : str Upper-bound column name in the wide forecast table (e.g., "0.975"). - opacity : float - Fill opacity for the band in the range - [0.0, 1.0]. + label : str + The label in the legend for the confidence + interval (e.g. "97.5% CI"). Returns ------- @@ -299,17 +376,23 @@ def band(low: str, high: str, opacity: float) -> alt.Chart: An Altair layer with the filled band from ``low`` to ``high``, with step interpolation. """ - return base.mark_errorband(opacity=opacity, interpolate="step").encode( - y=alt.Y(f"{low}:Q", title=f"{selected_target}"), - y2=f"{high}:Q", - fill=alt.value("steelblue"), + return ( + base.transform_calculate(legend_label=f"'{label}'") + .mark_errorband(interpolate="step") + .encode( + y=alt.Y(f"{low}:Q", title=f"{selected_target}"), + y2=f"{high}:Q", + color=color_enc, + opacity=alt.value(1.0), + ) ) bands = [ - band("0.025", "0.975", 0.10), - band("0.1", "0.9", 0.20), - band("0.25", "0.75", 0.30), + band(spec["low"], spec["high"], label) + for label, spec in ci_specs.items() + if spec["low"] in df_wide.columns and spec["high"] in df_wide.columns ] + median = base.mark_line(strokeWidth=STROKE_WIDTH, interpolate="step", color="navy") return alt.layer(*bands, median) diff --git a/pyproject.toml b/pyproject.toml index aba453a..0cef114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "streamlit>=1.47.0", "polars>=0.19.37", "streamlit-shortcuts>=1.1.5", + "babel>=2.17.0", ] [project.urls] diff --git a/uv.lock b/uv.lock index f7d313b..9ccdefd 100644 --- a/uv.lock +++ b/uv.lock @@ -48,6 +48,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, +] + [[package]] name = "blinker" version = "1.9.0" @@ -258,6 +267,7 @@ version = "0.0.1" source = { virtual = "." } dependencies = [ { name = "altair" }, + { name = "babel" }, { name = "forecasttools" }, { name = "polars" }, { name = "streamlit" }, @@ -267,6 +277,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "altair", specifier = ">=5.5.0" }, + { name = "babel", specifier = ">=2.17.0" }, { name = "forecasttools", git = "https://github.com/cdcgov/forecasttools-py" }, { name = "polars", specifier = ">=0.19.37" }, { name = "streamlit", specifier = ">=1.47.0" },