diff --git a/maidr/core/plot/barplot.py b/maidr/core/plot/barplot.py index 6e4dddd..d978035 100644 --- a/maidr/core/plot/barplot.py +++ b/maidr/core/plot/barplot.py @@ -3,7 +3,7 @@ from matplotlib.axes import Axes from matplotlib.container import BarContainer -from maidr.core.enum import PlotType +from maidr.core.enum import MaidrKey, PlotType from maidr.core.plot import MaidrPlot from maidr.exception import ExtractionError from maidr.util.mixin import ( @@ -18,29 +18,69 @@ def __init__(self, ax: Axes) -> None: super().__init__(ax, PlotType.BAR) def _extract_plot_data(self) -> list: + """ + Extract plot data for bar plots. + + For vertical bar plots, categories are on X-axis and values on Y-axis. + For horizontal bar plots, categories are on Y-axis and values on X-axis. + + Returns + ------- + list + List of dictionaries containing x and y data points. + """ plot = self.extract_container(self.ax, BarContainer, include_all=True) data = self._extract_bar_container_data(plot) - levels = self.extract_level(self.ax) + + # Extract appropriate axis labels based on bar orientation + if plot and plot[0].orientation == "vertical": + # For vertical bars: categories on X-axis, values on Y-axis + levels = self.extract_level(self.ax, MaidrKey.X) + else: + # For horizontal bars: categories on Y-axis, values on X-axis + levels = self.extract_level(self.ax, MaidrKey.Y) + + # Handle the case where levels might be None or empty + if levels is None or data is None: + if data is None: + raise ExtractionError(self.type, plot) + # If levels is None but data exists, create default labels + levels = [f"Item {i+1}" for i in range(len(data))] + formatted_data = [] combined_data = list( zip(levels, data) - if plot[0].orientation == "vertical" - else zip(data, levels) # type: ignore + if plot and plot[0].orientation == "vertical" + else zip(data, levels) ) - if combined_data: # type: ignore - for x, y in combined_data: # type: ignore + + if combined_data: + for x, y in combined_data: formatted_data.append({"x": x, "y": y}) return formatted_data + + # If no formatted data could be created, raise an error if len(formatted_data) == 0: raise ExtractionError(self.type, plot) - if data is None: - raise ExtractionError(self.type, plot) return data def _extract_bar_container_data( self, plot: list[BarContainer] | None ) -> list | None: + """ + Extract bar container data with proper orientation handling. + + Parameters + ---------- + plot : list[BarContainer] | None + List of bar containers from the plot. + + Returns + ------- + list | None + List of bar heights/widths, or None if extraction fails. + """ if plot is None: return None @@ -48,14 +88,26 @@ def _extract_bar_container_data( # `list[BarContainers] for plotting bar plots. # So, extract data correspondingly based on the level. # Flatten all the `list[BarContainer]` to `list[Patch]`. - plot = [patch for container in plot for patch in container.patches] - level = self.extract_level(self.ax) - if len(level) == 0: # type: ignore - level = ["" for _ in range(len(plot))] # type: ignore + plot_patches = [patch for container in plot for patch in container.patches] + + # Extract appropriate axis labels based on bar orientation + if plot[0].orientation == "vertical": + # For vertical bars: categories on X-axis + level = self.extract_level(self.ax, MaidrKey.X) + else: + # For horizontal bars: categories on Y-axis + level = self.extract_level(self.ax, MaidrKey.Y) + + if level is None or len(level) == 0: + level = ["" for _ in range(len(plot_patches))] - if len(plot) != len(level): + if len(plot_patches) != len(level): return None - self._elements.extend(plot) + self._elements.extend(plot_patches) - return [float(patch.get_height()) for patch in plot] + # For horizontal bars, use width; for vertical bars, use height + if plot[0].orientation == "horizontal": + return [float(patch.get_width()) for patch in plot_patches] + else: + return [float(patch.get_height()) for patch in plot_patches]