-
Notifications
You must be signed in to change notification settings - Fork 0
Add Legends To Altair Plots #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 28 commits
6bc9e1a
7990fd3
8f16fd0
8cda371
83015c3
2becfc8
61ac0bf
7185da2
4bd04e2
5624d75
1cdba52
043dd46
7ba3e72
2adb8db
bf0074b
a2e9535
8530a62
6c1736c
a8dd896
dffa4d9
90cb0d2
527adba
c8d2ded
fe3e833
cb8b5ab
c823da8
50b57b2
bc0cbc2
80852ad
59d3d20
7d9ba4c
2a911de
8b57127
2be5c6c
53764a5
89b6a47
a5fc900
5e807d1
2c8173d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,18 +11,23 @@ | |
| import datetime | ||
| import logging | ||
| import pathlib | ||
| from itertools import combinations | ||
| from typing import Literal | ||
|
|
||
| import altair as alt | ||
| import colorbrewer | ||
| import forecasttools | ||
| import polars as pl | ||
| import polars.selectors as cs | ||
| import streamlit as st | ||
| from matplotlib.colors import to_hex | ||
| from streamlit.runtime.uploaded_file_manager import UploadedFile | ||
|
|
||
| PLOT_WIDTH = 625 | ||
| STROKE_WIDTH = 2 | ||
| MARKER_SIZE = 55 | ||
| MARKER_SIZE = 65 | ||
| MAX_NUM_CIS = 7 | ||
|
|
||
|
|
||
| type ScaleType = Literal["linear", "log"] | ||
|
|
||
|
|
@@ -132,6 +137,73 @@ def get_initial_window_range( | |
| return (start_date, end_date) | ||
|
|
||
|
|
||
| def wide_quantile_df(forecast_table: pl.DataFrame) -> pl.DataFrame: | ||
O957 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
| Converts long-format quantile forecast table into | ||
| wide format where each quantile is a column. | ||
| """ | ||
| return ( | ||
| forecast_table.filter(pl.col("output_type") == "quantile") | ||
| .pivot( | ||
| on="output_type_id", | ||
| index=cs.exclude("output_type_id", "value"), | ||
| values="value", | ||
| ) | ||
| .rename({"0.5": "median"}) | ||
| ) | ||
|
|
||
|
|
||
| def build_ci_specs_from_df( | ||
| forecast_table: pl.DataFrame, | ||
| ) -> dict[str, dict[str, str]]: | ||
| """ | ||
| Automatically constructs a CI_SPECS-style dict for | ||
| altair legend creation by finding quantile columns in | ||
| a wide forecast table. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| forecast_table : pl.DataFrame | ||
| A Polars DataFrame of forecasts, with a quantile | ||
| forecast columns. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, dict[str, str]] | ||
| Mapping from CI label (e.g. "95% CI") to its | ||
| bounds and color. | ||
| """ | ||
| df_wide = wide_quantile_df(forecast_table) | ||
| quant_vals = sorted( | ||
| float(col) for col in df_wide.columns if col.startswith("0.") | ||
| ) | ||
| ci_pairs = [ | ||
| (low, high) | ||
| for low, high in combinations(quant_vals, 2) | ||
| if low < 0.5 < high and abs(low + high - 1.0) < 1e-6 | ||
| ] | ||
O957 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ci_pairs.sort(key=lambda p: p[1] - p[0], reverse=True) | ||
|
|
||
| labels = [f"{round((high - low) * 100)}% CI" for low, high in ci_pairs] | ||
O957 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| palette_rgb255 = list( | ||
| colorbrewer.Blues[max(3, min(MAX_NUM_CIS, len(labels)))] | ||
|
||
| ) | ||
| palette = [(r / 255, g / 255, b / 255) for r, g, b in palette_rgb255] | ||
|
||
|
|
||
| specs = { | ||
| label: { | ||
| "low": f"{low:.3f}".rstrip("0").rstrip("."), | ||
| "high": f"{high:.3f}".rstrip("0").rstrip("."), | ||
| "color": to_hex(color), | ||
|
||
| } | ||
| for (low, high), label, color in zip( | ||
| ci_pairs, labels, palette, strict=False | ||
| ) | ||
| } | ||
| return specs | ||
|
|
||
|
|
||
| def is_empty_chart(chart: alt.LayerChart) -> bool: | ||
| """ | ||
| Checks if an altair layer is empty. Primarily used for | ||
|
|
@@ -172,6 +244,7 @@ def is_empty_chart(chart: alt.LayerChart) -> bool: | |
| def target_data_chart( | ||
| observed_data_table: pl.DataFrame, | ||
| selected_target: str, | ||
| color_enc: alt.Color, | ||
| scale: ScaleType = "log", | ||
| grid: bool = True, | ||
| ) -> alt.Chart | alt.LayerChart: | ||
|
|
@@ -186,6 +259,9 @@ def target_data_chart( | |
| selected_target : str | ||
| The target for filtering in the forecast and or | ||
| observed hubverse tables. | ||
| color_enc : alt.Color | ||
| An Altair color encoding used for plotting the | ||
| observations color and legend. | ||
| scale : str | ||
| The scale to use for the Y axis during plotting. | ||
| Defaults to logarithmic. | ||
|
|
@@ -203,7 +279,6 @@ def target_data_chart( | |
| x_enc = alt.X( | ||
| "date:T", | ||
| axis=alt.Axis(title="Date", grid=grid, ticks=True, labels=True), | ||
| scale=alt.Scale(type=scale), | ||
| ) | ||
| y_enc = alt.Y( | ||
| "observation:Q", | ||
|
|
@@ -218,10 +293,15 @@ def target_data_chart( | |
| ) | ||
| obs_layer = ( | ||
| alt.Chart(observed_data_table, width=PLOT_WIDTH) | ||
| .mark_point(filled=True, size=MARKER_SIZE, color="limegreen") | ||
| .transform_calculate(legend_label="'Observations'") | ||
| .mark_point( | ||
| filled=True, | ||
| size=MARKER_SIZE, | ||
| ) | ||
| .encode( | ||
| x=x_enc, | ||
| y=y_enc, | ||
| color=color_enc, | ||
| tooltip=[ | ||
| alt.Tooltip("date:T", title="Date"), | ||
| alt.Tooltip("observation:Q", title="Value"), | ||
|
|
@@ -234,6 +314,8 @@ def target_data_chart( | |
| def quantile_forecast_chart( | ||
| forecast_table: pl.DataFrame, | ||
| selected_target: str, | ||
| ci_specs, | ||
| color_enc: alt.Color, | ||
| scale: ScaleType = "log", | ||
| grid: bool = True, | ||
| ) -> alt.LayerChart: | ||
|
|
@@ -250,6 +332,9 @@ def quantile_forecast_chart( | |
| selected_target : str | ||
| The target for filtering in the forecast and or | ||
| observed hubverse tables. | ||
| color_enc : alt.Color | ||
| An Altair color encoding used for plotting the | ||
| quantile bands color and legend. | ||
| scale : str | ||
| The scale to use for the Y axis during plotting. | ||
| Defaults to logarithmic. | ||
|
|
@@ -264,24 +349,26 @@ def quantile_forecast_chart( | |
| """ | ||
| if forecast_table.is_empty(): | ||
| return alt.layer() | ||
| df_wide = ( | ||
| forecast_table.filter(pl.col("output_type") == "quantile") | ||
| .pivot( | ||
| on="output_type_id", | ||
| index=cs.exclude("output_type_id", "value"), | ||
| values="value", | ||
| ) | ||
| .rename({"0.5": "median"}) | ||
| ) | ||
| df_wide = wide_quantile_df(forecast_table) | ||
| x_enc = alt.X("target_end_date:T", title="Date", axis=alt.Axis(grid=grid)) | ||
| y_enc = alt.Y( | ||
| "median:Q", | ||
| axis=alt.Axis(grid=grid), | ||
| scale=alt.Scale(type=scale), | ||
| ) | ||
| base = alt.Chart(df_wide, width=PLOT_WIDTH).encode(x=x_enc, y=y_enc) | ||
| base = ( | ||
| alt.Chart(df_wide, width=PLOT_WIDTH) | ||
| .transform_calculate( | ||
| date="toDate(datum.target_end_date)", | ||
| data_type="'Forecast'", | ||
| ) | ||
| .encode( | ||
| x=x_enc, | ||
| y=y_enc, | ||
| ) | ||
| ) | ||
|
|
||
| def band(low: str, high: str, opacity: float) -> alt.Chart: | ||
| def band(low: str, high: str, label: str) -> alt.Chart: | ||
| """ | ||
| Builds an errorband layer for a quantile. | ||
|
|
||
|
|
@@ -293,27 +380,33 @@ def band(low: str, high: str, opacity: float) -> alt.Chart: | |
| high : str | ||
| Upper-bound column name in the wide forecast | ||
| table (e.g., "0.975"). | ||
| opacity : float | ||
| Fill opacity for the band in the range | ||
| [0.0, 1.0]. | ||
| label : str | ||
| The label in the legend for the confidence | ||
| interval (e.g. "97.5% CI"). | ||
|
|
||
| Returns | ||
| ------- | ||
| alt.Chart | ||
| An Altair layer with the filled band from | ||
| ``low`` to ``high``, with step interpolation. | ||
| """ | ||
| return base.mark_errorband(opacity=opacity, interpolate="step").encode( | ||
| y=alt.Y(f"{low}:Q", title=f"{selected_target}"), | ||
| y2=f"{high}:Q", | ||
| fill=alt.value("steelblue"), | ||
| return ( | ||
| base.transform_calculate(legend_label=f"'{label}'") | ||
| .mark_errorband(interpolate="step") | ||
| .encode( | ||
| y=alt.Y(f"{low}:Q", title=f"{selected_target}"), | ||
| y2=f"{high}:Q", | ||
| color=color_enc, | ||
| opacity=alt.value(1.0), | ||
| ) | ||
| ) | ||
|
|
||
| bands = [ | ||
| band("0.025", "0.975", 0.10), | ||
| band("0.1", "0.9", 0.20), | ||
| band("0.25", "0.75", 0.30), | ||
| band(spec["low"], spec["high"], label) | ||
| for label, spec in ci_specs.items() | ||
| if spec["low"] in df_wide.columns and spec["high"] in df_wide.columns | ||
|
Comment on lines
308
to
+477
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Structure overall feels a bit off. What about
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will try this out.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Get the colors" step may be obviated by https://vega.github.io/vega/docs/schemes/#blues |
||
| ] | ||
|
|
||
| median = base.mark_line( | ||
| strokeWidth=STROKE_WIDTH, interpolate="step", color="navy" | ||
| ) | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than colorbrewer, we should try to take advantage of https://vega.github.io/vega/docs/schemes/#blues