Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not fail when an estimator misses class members that are new in v3 #757

Merged
merged 8 commits into from
Jan 31, 2024
25 changes: 15 additions & 10 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,16 +883,16 @@ def _convert_from_pandas(
if hasattr(self, "X_model_spec_"):
return self.X_model_spec_.get_model_matrix(df, context=context)

cat_missing_method_after_alignment = self.cat_missing_method
cat_missing_method_after_alignment = getattr(self, "cat_missing_method", "fail")

if hasattr(self, "feature_dtypes_"):
df = _align_df_categories(
df,
self.feature_dtypes_,
self.has_missing_category_,
self.cat_missing_method,
getattr(self, "has_missing_category_", {}),
cat_missing_method_after_alignment,
)
if self.cat_missing_method == "convert":
if cat_missing_method_after_alignment == "convert":
df = _add_missing_categories(
df=df,
dtypes=self.feature_dtypes_,
Expand All @@ -906,7 +906,9 @@ def _convert_from_pandas(
X = tm.from_pandas(
df,
drop_first=self.drop_first,
categorical_format=self.categorical_format,
categorical_format=getattr( # convention prior to v3
self, "categorical_format", "{name}__{category}"
),
cat_missing_method=cat_missing_method_after_alignment,
)

Expand Down Expand Up @@ -1629,7 +1631,7 @@ def wald_test(
)
if num_lhs_specs != 1:
raise ValueError(
"Exactly one of R, features terms or formula must be specified. "
"Exactly one of R, features, terms or formula must be specified. "
f"Received {num_lhs_specs} specifications."
)

Expand Down Expand Up @@ -2724,7 +2726,8 @@ def _set_up_and_check_fit_args(

self.feature_dtypes_ = X.dtypes.to_dict()
self.has_missing_category_ = {
col: (self.cat_missing_method == "convert") and X[col].isna().any()
col: (getattr(self, "cat_missing_method", "fail") == "convert")
and X[col].isna().any()
for col, dtype in self.feature_dtypes_.items()
if isinstance(dtype, pd.CategoricalDtype)
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that, because self.has_missing_category_ is created here, we don't need to use getattr(self, "has_missing_category_", ...) in 2789 and 2792 anymore.

Expand Down Expand Up @@ -2784,9 +2787,11 @@ def _expand_categorical_penalties(
X = tm.from_pandas(
X,
drop_first=self.drop_first,
categorical_format=self.categorical_format,
cat_missing_method=self.cat_missing_method,
cat_missing_name=self.cat_missing_name,
categorical_format=getattr( # convention prior to v3
self, "categorical_format", "{name}__{category}"
),
cat_missing_method=getattr(self, "cat_missing_method", "fail"),
cat_missing_name=getattr(self, "cat_missing_name", "(MISSING)"),
)

if y is None:
Expand Down
Loading