Skip to content

Commit

Permalink
Simplify _extract_codes_and_categories
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Sep 23, 2024
1 parent 621391a commit b41bf2b
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _extract_codes_and_categories_polars(cat_vec) -> tuple[np.ndarray, np.ndarra


def _extract_codes_and_categories_numpy(cat_vec) -> tuple[np.ndarray, np.ndarray]:
if pd and pd.__version__ >= "1.5.0":
if pd:
indices, categories = pd.factorize(cat_vec, sort=True)
else:
indices, categories = _factorize(cat_vec)
Expand All @@ -261,25 +261,18 @@ def _extract_codes_and_categories(cat_vec) -> tuple[np.ndarray, np.ndarray]:
can be converted to a numpy array. Pandas and polars inputs are special
cased for performance considerations.
"""
if pd and isinstance(cat_vec, pd.Categorical):
# Narwhals can only handle series, not bare Categoricals
cat_vec = pd.Series(cat_vec, copy=False)
# We convert to narwhals first so we handle narwhalized and non-narwhalized
# pandas and polars inputs the same way.
cat_vec = nw.from_native(cat_vec, series_only=True, strict=False)

if isinstance(cat_vec, nw.Series):
package = nw.get_native_namespace(cat_vec).__name__
else:
package = None
cat_vec_native = nw.to_native(cat_vec, strict=False)

if package == "pandas":
return _extract_codes_and_categories_pandas(nw.to_native(cat_vec))
elif package == "polars":
return _extract_codes_and_categories_polars(nw.to_native(cat_vec))
if pd and isinstance(cat_vec_native, (pd.Series, pd.Categorical)):
return _extract_codes_and_categories_pandas(cat_vec_native)
elif pl and isinstance(cat_vec_native, pl.Series):
return _extract_codes_and_categories_polars(cat_vec_native)
else:
if isinstance(cat_vec, nw.Series):
cat_vec = cat_vec.cast(nw.String).to_numpy()
if isinstance(
cat_vec_narwhals := nw.from_native(cat_vec, series_only=True, strict=False),
nw.Series,
):
cat_vec = cat_vec_narwhals.cast(nw.String).to_numpy()
else:
cat_vec = np.asarray(cat_vec)
return _extract_codes_and_categories_numpy(cat_vec)
Expand Down

0 comments on commit b41bf2b

Please sign in to comment.