Skip to content

Commit b41bf2b

Browse files
committed
Simplify _extract_codes_and_categories
1 parent 621391a commit b41bf2b

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

src/tabmat/categorical_matrix.py

Lines changed: 11 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ def _extract_codes_and_categories_polars(cat_vec) -> tuple[np.ndarray, np.ndarra
246246

247247

248248
def _extract_codes_and_categories_numpy(cat_vec) -> tuple[np.ndarray, np.ndarray]:
249-
if pd and pd.__version__ >= "1.5.0":
249+
if pd:
250250
indices, categories = pd.factorize(cat_vec, sort=True)
251251
else:
252252
indices, categories = _factorize(cat_vec)
@@ -261,25 +261,18 @@ def _extract_codes_and_categories(cat_vec) -> tuple[np.ndarray, np.ndarray]:
261261
can be converted to a numpy array. Pandas and polars inputs are special
262262
cased for performance considerations.
263263
"""
264-
if pd and isinstance(cat_vec, pd.Categorical):
265-
# Narwhals can only handle series, not bare Categoricals
266-
cat_vec = pd.Series(cat_vec, copy=False)
267-
# We convert to narwhals first so we handle narwhalized and non-narwhalized
268-
# pandas and polars inputs the same way.
269-
cat_vec = nw.from_native(cat_vec, series_only=True, strict=False)
270-
271-
if isinstance(cat_vec, nw.Series):
272-
package = nw.get_native_namespace(cat_vec).__name__
273-
else:
274-
package = None
264+
cat_vec_native = nw.to_native(cat_vec, strict=False)
275265

276-
if package == "pandas":
277-
return _extract_codes_and_categories_pandas(nw.to_native(cat_vec))
278-
elif package == "polars":
279-
return _extract_codes_and_categories_polars(nw.to_native(cat_vec))
266+
if pd and isinstance(cat_vec_native, (pd.Series, pd.Categorical)):
267+
return _extract_codes_and_categories_pandas(cat_vec_native)
268+
elif pl and isinstance(cat_vec_native, pl.Series):
269+
return _extract_codes_and_categories_polars(cat_vec_native)
280270
else:
281-
if isinstance(cat_vec, nw.Series):
282-
cat_vec = cat_vec.cast(nw.String).to_numpy()
271+
if isinstance(
272+
cat_vec_narwhals := nw.from_native(cat_vec, series_only=True, strict=False),
273+
nw.Series,
274+
):
275+
cat_vec = cat_vec_narwhals.cast(nw.String).to_numpy()
283276
else:
284277
cat_vec = np.asarray(cat_vec)
285278
return _extract_codes_and_categories_numpy(cat_vec)

0 commit comments

Comments
 (0)