@@ -15,7 +15,9 @@ def _asanyarray(x, **kwargs):
15
15
return x if pd .api .types .is_scalar (x ) else np .asanyarray (x , ** kwargs )
16
16
17
17
18
- def _align_df_categories (df , dtypes ) -> pd .DataFrame :
18
+ def _align_df_categories (
19
+ df , dtypes , has_missing_category , cat_missing_method
20
+ ) -> pd .DataFrame :
19
21
"""Align data types for prediction.
20
22
21
23
This function checks that categorical columns have same categories in the
@@ -26,6 +28,8 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
26
28
----------
27
29
df : pandas.DataFrame
28
30
dtypes : Dict[str, Union[str, type, pandas.core.dtypes.base.ExtensionDtype]]
31
+ has_missing_category : Dict[str, bool]
32
+ missing_method : str
29
33
"""
30
34
if not isinstance (df , pd .DataFrame ):
31
35
raise TypeError (f"Expected `pandas.DataFrame'; got { type (df )} ." )
@@ -50,7 +54,15 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
50
54
else :
51
55
continue
52
56
53
- unseen_categories = set (df [column ].unique ()) - set (dtypes [column ].categories )
57
+ if cat_missing_method == "convert" and not has_missing_category [column ]:
58
+ unseen_categories = set (df [column ].unique ()) - set (
59
+ dtypes [column ].categories
60
+ )
61
+ else :
62
+ unseen_categories = set (df [column ].dropna ().unique ()) - set (
63
+ dtypes [column ].categories
64
+ )
65
+
54
66
if unseen_categories :
55
67
raise ValueError (
56
68
f"Column { column } contains unseen categories: { unseen_categories } ."
0 commit comments