Skip to content

Commit 046d9ff

Browse files
committed
Check for unseen categories
1 parent c448f3d commit 046d9ff

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/glum/_glm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,12 @@ def _convert_from_pandas(
885885
cat_missing_method_after_alignment = self.cat_missing_method
886886

887887
if hasattr(self, "feature_dtypes_"):
888-
df = _align_df_categories(df, self.feature_dtypes_)
888+
df = _align_df_categories(
889+
df,
890+
self.feature_dtypes_,
891+
self.has_missing_category_,
892+
self.cat_missing_method,
893+
)
889894
if self.cat_missing_method == "convert":
890895
df = _add_missing_categories(
891896
df=df,
@@ -2650,7 +2655,6 @@ def _set_up_and_check_fit_args(
26502655

26512656
if isinstance(X, pd.DataFrame):
26522657
if hasattr(self, "formula") and self.formula is not None:
2653-
26542658
lhs, rhs = _parse_formula(
26552659
self.formula, include_intercept=self.fit_intercept
26562660
)
@@ -2705,6 +2709,10 @@ def _set_up_and_check_fit_args(
27052709
# Maybe TODO: expand categorical penalties with formulas
27062710

27072711
self.feature_dtypes_ = X.dtypes.to_dict()
2712+
self.has_missing_category_ = {
2713+
col: (self.cat_missing_method == "convert") and X[col].isna().any()
2714+
for col in self.feature_dtypes_.keys()
2715+
}
27082716

27092717
if any(X.dtypes == "category"):
27102718

src/glum/_util.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ def _asanyarray(x, **kwargs):
1515
return x if pd.api.types.is_scalar(x) else np.asanyarray(x, **kwargs)
1616

1717

18-
def _align_df_categories(df, dtypes) -> pd.DataFrame:
18+
def _align_df_categories(
19+
df, dtypes, has_missing_category, cat_missing_method
20+
) -> pd.DataFrame:
1921
"""Align data types for prediction.
2022
2123
This function checks that categorical columns have same categories in the
@@ -26,6 +28,8 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
2628
----------
2729
df : pandas.DataFrame
2830
dtypes : Dict[str, Union[str, type, pandas.core.dtypes.base.ExtensionDtype]]
31+
has_missing_category : Dict[str, bool]
32+
missing_method : str
2933
"""
3034
if not isinstance(df, pd.DataFrame):
3135
raise TypeError(f"Expected `pandas.DataFrame'; got {type(df)}.")
@@ -50,7 +54,15 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
5054
else:
5155
continue
5256

53-
unseen_categories = set(df[column].unique()) - set(dtypes[column].categories)
57+
if cat_missing_method == "convert" and not has_missing_category[column]:
58+
unseen_categories = set(df[column].unique()) - set(
59+
dtypes[column].categories
60+
)
61+
else:
62+
unseen_categories = set(df[column].dropna().unique()) - set(
63+
dtypes[column].categories
64+
)
65+
5466
if unseen_categories:
5567
raise ValueError(
5668
f"Column {column} contains unseen categories: {unseen_categories}."

0 commit comments

Comments
 (0)