From 046d9ff9ad5558409fa02021269731cec1e6f8c4 Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Tue, 23 Jan 2024 03:00:13 +0100 Subject: [PATCH] Check for unseen categories --- src/glum/_glm.py | 12 ++++++++++-- src/glum/_util.py | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/glum/_glm.py b/src/glum/_glm.py index 62afc68d..43f6516e 100644 --- a/src/glum/_glm.py +++ b/src/glum/_glm.py @@ -885,7 +885,12 @@ def _convert_from_pandas( cat_missing_method_after_alignment = self.cat_missing_method if hasattr(self, "feature_dtypes_"): - df = _align_df_categories(df, self.feature_dtypes_) + df = _align_df_categories( + df, + self.feature_dtypes_, + self.has_missing_category_, + self.cat_missing_method, + ) if self.cat_missing_method == "convert": df = _add_missing_categories( df=df, @@ -2650,7 +2655,6 @@ def _set_up_and_check_fit_args( if isinstance(X, pd.DataFrame): if hasattr(self, "formula") and self.formula is not None: - lhs, rhs = _parse_formula( self.formula, include_intercept=self.fit_intercept ) @@ -2705,6 +2709,10 @@ def _set_up_and_check_fit_args( # Maybe TODO: expand categorical penalties with formulas self.feature_dtypes_ = X.dtypes.to_dict() + self.has_missing_category_ = { + col: (self.cat_missing_method == "convert") and X[col].isna().any() + for col in self.feature_dtypes_.keys() + } if any(X.dtypes == "category"): diff --git a/src/glum/_util.py b/src/glum/_util.py index ce734540..f5c463ff 100644 --- a/src/glum/_util.py +++ b/src/glum/_util.py @@ -15,7 +15,9 @@ def _asanyarray(x, **kwargs): return x if pd.api.types.is_scalar(x) else np.asanyarray(x, **kwargs) -def _align_df_categories(df, dtypes) -> pd.DataFrame: +def _align_df_categories( + df, dtypes, has_missing_category, cat_missing_method +) -> pd.DataFrame: """Align data types for prediction. This function checks that categorical columns have same categories in the @@ -26,6 +28,8 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame: ---------- df : pandas.DataFrame dtypes : Dict[str, Union[str, type, pandas.core.dtypes.base.ExtensionDtype]] + has_missing_category : Dict[str, bool] + missing_method : str """ if not isinstance(df, pd.DataFrame): raise TypeError(f"Expected `pandas.DataFrame'; got {type(df)}.") @@ -50,7 +54,15 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame: else: continue - unseen_categories = set(df[column].unique()) - set(dtypes[column].categories) + if cat_missing_method == "convert" and not has_missing_category[column]: + unseen_categories = set(df[column].unique()) - set( + dtypes[column].categories + ) + else: + unseen_categories = set(df[column].dropna().unique()) - set( + dtypes[column].categories + ) + if unseen_categories: raise ValueError( f"Column {column} contains unseen categories: {unseen_categories}."