From cd5ceb7819128c6aa507796bc466aa1d5d54033e Mon Sep 17 00:00:00 2001 From: ursk Date: Mon, 25 Mar 2024 14:39:01 -0700 Subject: [PATCH] Simplify AutoBNN plotting util PiperOrigin-RevId: 618963756 --- .../experimental/autobnn/training_util.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index 36e6560a3b..60bce255bb 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -288,6 +288,7 @@ def plot_results( y_train: Optional[jax.Array] = None, diagnostics: Optional[Dict[str, jax.Array]] = None, log_scale: bool = False, + show_particles: bool = True, left_limit: int = 24*7*2, right_limit: int = 24*7*2, ) -> plt.Figure: @@ -307,44 +308,41 @@ def plot_results( else: fig, res_ax = plt.subplots(figsize=(16, 3), constrained_layout=True) - for idx, p in enumerate(preds): - res_ax.plot( - dates_preds, - p, - 'k-', - alpha=0.1, - label='Particle predictions' if idx == 0 else None, - ) + if show_particles: + for idx, p in enumerate(preds): + res_ax.plot( + dates_preds, + p, + 'k-', + alpha=0.1, + label='Particle predictions' if idx == 0 else None, + ) color = 'steelblue' if p50 is not None: res_ax.plot( - dates_preds, p50, '-', lw=5, color=color, label='Prediction') + dates_preds, p50, '-', lw=2.5, color=color, label='Prediction') if p97_5 is not None and p2_5 is not None: res_ax.plot(dates_preds, p97_5, '-', - lw=3, color=color, label='Upper/lower bound') - res_ax.plot(dates_preds, p2_5, '-', lw=3, color=color) + lw=1.5, color=color, label='Upper/lower bound') + res_ax.plot(dates_preds, p2_5, '-', lw=1.5, color=color) res_ax.fill_between( dates_preds, p2_5, p97_5, color=color, alpha=0.2 ) - data_kwargs = {'ms': 7, 'mec': 'k', 'mew': 2} if dates_train is not None and y_train is not None: res_ax.plot( dates_train, y_train, - 'o', - mfc='red', - label='Train data', - **data_kwargs) + 'k-', + label='Ground truth data', + ) if dates_test is not None and y_test is not None: res_ax.plot( dates_test, y_test, - 'o', - mfc='green', - label='Test data', - **data_kwargs) + 'k-', + ) res_ax.set_title('Predictions') res_ax.legend() left_limit = min(len(dates_preds) - len(dates_test), left_limit) @@ -353,6 +351,7 @@ def plot_results( # TODO(ursk): Rather than modifying xlim, don't plot invisible points at all. res_ax.set_xlim([dates_preds[first_test_point-left_limit], dates_preds[first_test_point+right_limit]]) + res_ax.axvline(dates_preds[first_test_point], linestyle='--') return fig