Skip to content

Commit c81c34a

Browse files
Copilotjooyoungseo
andcommitted
fix: handle horizontal vs vertical bar plot orientation in subplot iteration
Co-authored-by: jooyoungseo <[email protected]>
1 parent 9fea5f1 commit c81c34a

File tree

1 file changed

+67
-15
lines changed

1 file changed

+67
-15
lines changed

maidr/core/plot/barplot.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from matplotlib.axes import Axes
44
from matplotlib.container import BarContainer
55

6-
from maidr.core.enum import PlotType
6+
from maidr.core.enum import MaidrKey, PlotType
77
from maidr.core.plot import MaidrPlot
88
from maidr.exception import ExtractionError
99
from maidr.util.mixin import (
@@ -18,44 +18,96 @@ def __init__(self, ax: Axes) -> None:
1818
super().__init__(ax, PlotType.BAR)
1919

2020
def _extract_plot_data(self) -> list:
21+
"""
22+
Extract plot data for bar plots.
23+
24+
For vertical bar plots, categories are on X-axis and values on Y-axis.
25+
For horizontal bar plots, categories are on Y-axis and values on X-axis.
26+
27+
Returns
28+
-------
29+
list
30+
List of dictionaries containing x and y data points.
31+
"""
2132
plot = self.extract_container(self.ax, BarContainer, include_all=True)
2233
data = self._extract_bar_container_data(plot)
23-
levels = self.extract_level(self.ax)
34+
35+
# Extract appropriate axis labels based on bar orientation
36+
if plot and plot[0].orientation == "vertical":
37+
# For vertical bars: categories on X-axis, values on Y-axis
38+
levels = self.extract_level(self.ax, MaidrKey.X)
39+
else:
40+
# For horizontal bars: categories on Y-axis, values on X-axis
41+
levels = self.extract_level(self.ax, MaidrKey.Y)
42+
43+
# Handle the case where levels might be None or empty
44+
if levels is None or data is None:
45+
if data is None:
46+
raise ExtractionError(self.type, plot)
47+
# If levels is None but data exists, create default labels
48+
levels = [f"Item {i+1}" for i in range(len(data))]
49+
2450
formatted_data = []
2551
combined_data = list(
2652
zip(levels, data)
27-
if plot[0].orientation == "vertical"
28-
else zip(data, levels) # type: ignore
53+
if plot and plot[0].orientation == "vertical"
54+
else zip(data, levels)
2955
)
30-
if combined_data: # type: ignore
31-
for x, y in combined_data: # type: ignore
56+
57+
if combined_data:
58+
for x, y in combined_data:
3259
formatted_data.append({"x": x, "y": y})
3360
return formatted_data
61+
62+
# If no formatted data could be created, raise an error
3463
if len(formatted_data) == 0:
3564
raise ExtractionError(self.type, plot)
36-
if data is None:
37-
raise ExtractionError(self.type, plot)
3865

3966
return data
4067

4168
def _extract_bar_container_data(
4269
self, plot: list[BarContainer] | None
4370
) -> list | None:
71+
"""
72+
Extract bar container data with proper orientation handling.
73+
74+
Parameters
75+
----------
76+
plot : list[BarContainer] | None
77+
List of bar containers from the plot.
78+
79+
Returns
80+
-------
81+
list | None
82+
List of bar heights/widths, or None if extraction fails.
83+
"""
4484
if plot is None:
4585
return None
4686

4787
# Since v0.13, Seaborn has transitioned from using `list[Patch]` to
4888
# `list[BarContainers] for plotting bar plots.
4989
# So, extract data correspondingly based on the level.
5090
# Flatten all the `list[BarContainer]` to `list[Patch]`.
51-
plot = [patch for container in plot for patch in container.patches]
52-
level = self.extract_level(self.ax)
53-
if len(level) == 0: # type: ignore
54-
level = ["" for _ in range(len(plot))] # type: ignore
91+
plot_patches = [patch for container in plot for patch in container.patches]
92+
93+
# Extract appropriate axis labels based on bar orientation
94+
if plot[0].orientation == "vertical":
95+
# For vertical bars: categories on X-axis
96+
level = self.extract_level(self.ax, MaidrKey.X)
97+
else:
98+
# For horizontal bars: categories on Y-axis
99+
level = self.extract_level(self.ax, MaidrKey.Y)
100+
101+
if level is None or len(level) == 0:
102+
level = ["" for _ in range(len(plot_patches))]
55103

56-
if len(plot) != len(level):
104+
if len(plot_patches) != len(level):
57105
return None
58106

59-
self._elements.extend(plot)
107+
self._elements.extend(plot_patches)
60108

61-
return [float(patch.get_height()) for patch in plot]
109+
# For horizontal bars, use width; for vertical bars, use height
110+
if plot[0].orientation == "horizontal":
111+
return [float(patch.get_width()) for patch in plot_patches]
112+
else:
113+
return [float(patch.get_height()) for patch in plot_patches]

0 commit comments

Comments
 (0)