33from  matplotlib .axes  import  Axes 
44from  matplotlib .container  import  BarContainer 
55
6- from  maidr .core .enum  import  PlotType 
6+ from  maidr .core .enum  import  MaidrKey ,  PlotType 
77from  maidr .core .plot  import  MaidrPlot 
88from  maidr .exception  import  ExtractionError 
99from  maidr .util .mixin  import  (
@@ -18,44 +18,96 @@ def __init__(self, ax: Axes) -> None:
1818        super ().__init__ (ax , PlotType .BAR )
1919
2020    def  _extract_plot_data (self ) ->  list :
21+         """ 
22+         Extract plot data for bar plots. 
23+          
24+         For vertical bar plots, categories are on X-axis and values on Y-axis. 
25+         For horizontal bar plots, categories are on Y-axis and values on X-axis. 
26+          
27+         Returns 
28+         ------- 
29+         list 
30+             List of dictionaries containing x and y data points. 
31+         """ 
2132        plot  =  self .extract_container (self .ax , BarContainer , include_all = True )
2233        data  =  self ._extract_bar_container_data (plot )
23-         levels  =  self .extract_level (self .ax )
34+         
35+         # Extract appropriate axis labels based on bar orientation 
36+         if  plot  and  plot [0 ].orientation  ==  "vertical" :
37+             # For vertical bars: categories on X-axis, values on Y-axis 
38+             levels  =  self .extract_level (self .ax , MaidrKey .X )
39+         else :
40+             # For horizontal bars: categories on Y-axis, values on X-axis 
41+             levels  =  self .extract_level (self .ax , MaidrKey .Y )
42+         
43+         # Handle the case where levels might be None or empty 
44+         if  levels  is  None  or  data  is  None :
45+             if  data  is  None :
46+                 raise  ExtractionError (self .type , plot )
47+             # If levels is None but data exists, create default labels 
48+             levels  =  [f"Item { i + 1 }   for  i  in  range (len (data ))]
49+         
2450        formatted_data  =  []
2551        combined_data  =  list (
2652            zip (levels , data )
27-             if  plot [0 ].orientation  ==  "vertical" 
28-             else  zip (data , levels )   # type: ignore 
53+             if  plot   and   plot [0 ].orientation  ==  "vertical" 
54+             else  zip (data , levels )
2955        )
30-         if  combined_data :  # type: ignore 
31-             for  x , y  in  combined_data :  # type: ignore 
56+         
57+         if  combined_data :
58+             for  x , y  in  combined_data :
3259                formatted_data .append ({"x" : x , "y" : y })
3360            return  formatted_data 
61+         
62+         # If no formatted data could be created, raise an error 
3463        if  len (formatted_data ) ==  0 :
3564            raise  ExtractionError (self .type , plot )
36-         if  data  is  None :
37-             raise  ExtractionError (self .type , plot )
3865
3966        return  data 
4067
4168    def  _extract_bar_container_data (
4269        self , plot : list [BarContainer ] |  None 
4370    ) ->  list  |  None :
71+         """ 
72+         Extract bar container data with proper orientation handling. 
73+          
74+         Parameters 
75+         ---------- 
76+         plot : list[BarContainer] | None 
77+             List of bar containers from the plot. 
78+              
79+         Returns 
80+         ------- 
81+         list | None 
82+             List of bar heights/widths, or None if extraction fails. 
83+         """ 
4484        if  plot  is  None :
4585            return  None 
4686
4787        # Since v0.13, Seaborn has transitioned from using `list[Patch]` to 
4888        # `list[BarContainers] for plotting bar plots. 
4989        # So, extract data correspondingly based on the level. 
5090        # Flatten all the `list[BarContainer]` to `list[Patch]`. 
51-         plot  =  [patch  for  container  in  plot  for  patch  in  container .patches ]
52-         level  =  self .extract_level (self .ax )
53-         if  len (level ) ==  0 :  # type: ignore 
54-             level  =  [""  for  _  in  range (len (plot ))]  # type: ignore 
91+         plot_patches  =  [patch  for  container  in  plot  for  patch  in  container .patches ]
92+         
93+         # Extract appropriate axis labels based on bar orientation 
94+         if  plot [0 ].orientation  ==  "vertical" :
95+             # For vertical bars: categories on X-axis 
96+             level  =  self .extract_level (self .ax , MaidrKey .X )
97+         else :
98+             # For horizontal bars: categories on Y-axis 
99+             level  =  self .extract_level (self .ax , MaidrKey .Y )
100+             
101+         if  level  is  None  or  len (level ) ==  0 :
102+             level  =  [""  for  _  in  range (len (plot_patches ))]
55103
56-         if  len (plot ) !=  len (level ):
104+         if  len (plot_patches ) !=  len (level ):
57105            return  None 
58106
59-         self ._elements .extend (plot )
107+         self ._elements .extend (plot_patches )
60108
61-         return  [float (patch .get_height ()) for  patch  in  plot ]
109+         # For horizontal bars, use width; for vertical bars, use height 
110+         if  plot [0 ].orientation  ==  "horizontal" :
111+             return  [float (patch .get_width ()) for  patch  in  plot_patches ]
112+         else :
113+             return  [float (patch .get_height ()) for  patch  in  plot_patches ]
0 commit comments