Skip to content

Commit

Permalink
Properly handle missing values when checking for unseen categories
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Jan 23, 2024
1 parent 4d829a4 commit 3a30eb0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
19 changes: 13 additions & 6 deletions src/tabmat/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def from_categorical(
reduced_rank: bool,
missing_method: str = "fail",
missing_name: str = "(MISSING)",
force_convert: bool = False,
add_category_for_nan: bool = False,
) -> "_InteractableCategoricalVector":
"""Create an interactable categorical vector from a pandas categorical."""
categories = list(cat.categories)
Expand All @@ -446,7 +446,7 @@ def from_categorical(
"if cat_missing_method='fail'."
)

if missing_method == "convert" and (-1 in codes or force_convert):
if missing_method == "convert" and (-1 in codes or add_category_for_nan):
codes[codes == -1] = len(categories)
categories.append(missing_name)

Expand Down Expand Up @@ -723,25 +723,32 @@ def encode_contrasts(
order to avoid spanning the intercept.
"""
levels = levels if levels is not None else _state.get("categories")
force_convert = _state.get("force_convert", False)
add_category_for_nan = _state.get("add_category_for_nan", False)

# Check for unseen categories when levels are specified
if levels is not None:
unseen_categories = set(data.dropna().unique()) - set(levels)
if missing_method == "convert" and not add_category_for_nan:
unseen_categories = set(data.unique()) - set(levels)
else:
unseen_categories = set(data.dropna().unique()) - set(levels)

if unseen_categories:
raise ValueError(
f"Column {data.name} contains unseen categories: {unseen_categories}."
)

cat = pandas.Categorical(data._values, categories=levels)
_state["categories"] = cat.categories
_state["force_convert"] = missing_method == "convert" and cat.isna().any()
_state["add_category_for_nan"] = add_category_for_nan or (
missing_method == "convert" and cat.isna().any()
)

return _InteractableCategoricalVector.from_categorical(
cat,
reduced_rank=reduced_rank,
missing_method=missing_method,
missing_name=missing_name,
force_convert=force_convert,
add_category_for_nan=add_category_for_nan,
)


Expand Down
19 changes: 18 additions & 1 deletion tests/test_formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def test_cat_missing_interactions():


@pytest.mark.parametrize(
"cat_missing_method", ["zero", "convert"], ids=["zero", "convert"]
"cat_missing_method", ["zero", "convert", "fail"], ids=["zero", "convert", "fail"]
)
def test_unseen_category(cat_missing_method):
df = pd.DataFrame(
Expand All @@ -768,6 +768,23 @@ def test_unseen_category(cat_missing_method):
result_seen.model_spec.get_model_matrix(df_unseen)


def test_unseen_missing_convert():
    """Reject a missing value that first appears after the model spec is built.

    The spec is created with ``cat_missing_method="convert"`` from a frame
    whose categorical column has no missing entries. Building a model matrix
    from a frame that adds a ``pd.NA`` entry to that column is expected to
    raise a ``ValueError`` mentioning unseen categories.
    """
    complete_frame = pd.DataFrame({"cat_1": pd.Categorical(["a", "b"])})
    frame_with_missing = pd.DataFrame(
        {"cat_1": pd.Categorical(["a", "b", pd.NA])}
    )

    fitted = tm.from_formula(
        "cat_1 - 1",
        complete_frame,
        cat_missing_method="convert",
    )

    with pytest.raises(ValueError, match="contains unseen categories"):
        fitted.model_spec.get_model_matrix(frame_with_missing)


# Tests from formulaic's test suite
# ---------------------------------

Expand Down

0 comments on commit 3a30eb0

Please sign in to comment.