Skip to content

Commit

Permalink
Check for unseen categories
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Jan 23, 2024
1 parent c448f3d commit 046d9ff
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,12 @@ def _convert_from_pandas(
cat_missing_method_after_alignment = self.cat_missing_method

if hasattr(self, "feature_dtypes_"):
df = _align_df_categories(df, self.feature_dtypes_)
df = _align_df_categories(
df,
self.feature_dtypes_,
self.has_missing_category_,
self.cat_missing_method,
)
if self.cat_missing_method == "convert":
df = _add_missing_categories(
df=df,
Expand Down Expand Up @@ -2650,7 +2655,6 @@ def _set_up_and_check_fit_args(

if isinstance(X, pd.DataFrame):
if hasattr(self, "formula") and self.formula is not None:

lhs, rhs = _parse_formula(
self.formula, include_intercept=self.fit_intercept
)
Expand Down Expand Up @@ -2705,6 +2709,10 @@ def _set_up_and_check_fit_args(
# Maybe TODO: expand categorical penalties with formulas

self.feature_dtypes_ = X.dtypes.to_dict()
self.has_missing_category_ = {
col: (self.cat_missing_method == "convert") and X[col].isna().any()
for col in self.feature_dtypes_.keys()
}

if any(X.dtypes == "category"):

Expand Down
16 changes: 14 additions & 2 deletions src/glum/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def _asanyarray(x, **kwargs):
return x if pd.api.types.is_scalar(x) else np.asanyarray(x, **kwargs)


def _align_df_categories(df, dtypes) -> pd.DataFrame:
def _align_df_categories(
df, dtypes, has_missing_category, cat_missing_method
) -> pd.DataFrame:
"""Align data types for prediction.
This function checks that categorical columns have the same categories in the
Expand All @@ -26,6 +28,8 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
----------
df : pandas.DataFrame
dtypes : Dict[str, Union[str, type, pandas.core.dtypes.base.ExtensionDtype]]
has_missing_category : Dict[str, bool]
cat_missing_method : str
"""
if not isinstance(df, pd.DataFrame):
raise TypeError(f"Expected `pandas.DataFrame'; got {type(df)}.")
Expand All @@ -50,7 +54,15 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
else:
continue

unseen_categories = set(df[column].unique()) - set(dtypes[column].categories)
if cat_missing_method == "convert" and not has_missing_category[column]:
unseen_categories = set(df[column].unique()) - set(
dtypes[column].categories
)
else:
unseen_categories = set(df[column].dropna().unique()) - set(
dtypes[column].categories
)

if unseen_categories:
raise ValueError(
f"Column {column} contains unseen categories: {unseen_categories}."
Expand Down

0 comments on commit 046d9ff

Please sign in to comment.