Skip to content

Commit

Permalink
Do not fail when an estimator misses class members that are new in v3 (
Browse files Browse the repository at this point in the history
…#757)

* do not fail on missing class members that are new in v3

* simplify

* convert

* shorten the comment

* simplify

* don't use getattr unnecessarily

* cosmetics

* fix unrelated typo
  • Loading branch information
MatthiasSchmidtblaicherQC authored Jan 31, 2024
1 parent b185fe4 commit 7e86e3f
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,16 +883,16 @@ def _convert_from_pandas(
if hasattr(self, "X_model_spec_"):
return self.X_model_spec_.get_model_matrix(df, context=context)

cat_missing_method_after_alignment = self.cat_missing_method
cat_missing_method_after_alignment = getattr(self, "cat_missing_method", "fail")

if hasattr(self, "feature_dtypes_"):
df = _align_df_categories(
df,
self.feature_dtypes_,
self.has_missing_category_,
self.cat_missing_method,
getattr(self, "has_missing_category_", {}),
cat_missing_method_after_alignment,
)
if self.cat_missing_method == "convert":
if cat_missing_method_after_alignment == "convert":
df = _add_missing_categories(
df=df,
dtypes=self.feature_dtypes_,
Expand All @@ -906,7 +906,9 @@ def _convert_from_pandas(
X = tm.from_pandas(
df,
drop_first=self.drop_first,
categorical_format=self.categorical_format,
categorical_format=getattr( # convention prior to v3
self, "categorical_format", "{name}__{category}"
),
cat_missing_method=cat_missing_method_after_alignment,
)

Expand Down Expand Up @@ -1629,7 +1631,7 @@ def wald_test(
)
if num_lhs_specs != 1:
raise ValueError(
"Exactly one of R, features terms or formula must be specified. "
"Exactly one of R, features, terms or formula must be specified. "
f"Received {num_lhs_specs} specifications."
)

Expand Down Expand Up @@ -2724,7 +2726,8 @@ def _set_up_and_check_fit_args(

self.feature_dtypes_ = X.dtypes.to_dict()
self.has_missing_category_ = {
col: (self.cat_missing_method == "convert") and X[col].isna().any()
col: (getattr(self, "cat_missing_method", "fail") == "convert")
and X[col].isna().any()
for col, dtype in self.feature_dtypes_.items()
if isinstance(dtype, pd.CategoricalDtype)
}
Expand Down Expand Up @@ -2784,9 +2787,11 @@ def _expand_categorical_penalties(
X = tm.from_pandas(
X,
drop_first=self.drop_first,
categorical_format=self.categorical_format,
cat_missing_method=self.cat_missing_method,
cat_missing_name=self.cat_missing_name,
categorical_format=getattr( # convention prior to v3
self, "categorical_format", "{name}__{category}"
),
cat_missing_method=getattr(self, "cat_missing_method", "fail"),
cat_missing_name=getattr(self, "cat_missing_name", "(MISSING)"),
)

if y is None:
Expand Down

0 comments on commit 7e86e3f

Please sign in to comment.