diff --git a/livelossplot/outputs/matplotlib_plot.py b/livelossplot/outputs/matplotlib_plot.py index 0197acf..a7025d6 100644 --- a/livelossplot/outputs/matplotlib_plot.py +++ b/livelossplot/outputs/matplotlib_plot.py @@ -59,6 +59,8 @@ def send(self, logger: MainLogger): max_rows = math.ceil((len(log_groups) + len(self.extra_plots)) / self.max_cols) fig, axes = plt.subplots(max_rows, self.max_cols) + if not isinstance(axes, np.ndarray): + axes = np.array([[axes]]) axes = axes.reshape(-1, self.max_cols) self._before_plots(fig, axes, len(log_groups))