@@ -67,12 +67,8 @@ def __init__(
6767 metadata = metadata ,
6868 )
6969 if isclass (categories ) and issubclass (categories , enum .Enum ):
70- categories = pl .Series (
71- values = [getattr (v , "value" , v ) for v in categories .__members__ .values ()]
72- )
73- elif not isinstance (categories , pl .Series ):
74- categories = pl .Series (values = categories )
75- self .categories = categories
70+ categories = (item .value for item in categories )
71+ self .categories = list (categories )
7672
7773 @property
7874 def dtype (self ) -> pl .DataType :
@@ -81,7 +77,7 @@ def dtype(self) -> pl.DataType:
8177 def validate_dtype (self , dtype : PolarsDataType ) -> bool :
8278 if not isinstance (dtype , pl .Enum ):
8379 return False
84- return self .categories . equals ( dtype .categories )
80+ return self .categories == dtype .categories . to_list ( )
8581
8682 def sqlalchemy_dtype (self , dialect : sa .Dialect ) -> sa_TypeEngine :
8783 category_lengths = [len (c ) for c in self .categories ]
@@ -102,6 +98,6 @@ def pyarrow_dtype(self) -> pa.DataType:
10298 def _sample_unchecked (self , generator : Generator , n : int ) -> pl .Series :
10399 return generator .sample_choice (
104100 n ,
105- choices = self .categories . to_list () ,
101+ choices = self .categories ,
106102 null_probability = self ._null_probability ,
107103 ).cast (self .dtype )
0 commit comments