Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Sep 23, 2024
1 parent beb2ee3 commit d0528b7
Showing 1 changed file with 16 additions and 48 deletions.
64 changes: 16 additions & 48 deletions tests/test_categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,61 +208,29 @@ def test_categorical_indexing(drop_first, missing, cat_missing_method):
np.testing.assert_allclose(mat[:, [0, 1]].toarray(), expected)


@pytest.mark.parametrize("input_type", ["pandas", "polars", "pyarrow"])
def test_extract_codes_and_categories(input_type):
@pytest.mark.parametrize(
"input_type", ["pandas.Categorical", "pandas", "polars", "pyarrow", "list"]
)
@pytest.mark.parametrize("narwhals_input", [True, False])
def test_extract_codes_and_categories(input_type, narwhals_input):
cat_vec = pd.Series(["a", "b", "c", pd.NA, "b", "a", "d"], dtype="category")
if input_type == "polars":
if input_type == "pandas.Categorical":
cat_vec = pd.Categorical(cat_vec)
elif input_type == "polars":
cat_vec = pl.Series(cat_vec)
elif input_type == "pyarrow":
cat_vec = pyarrow.chunked_array([cat_vec])
elif input_type == "list":
cat_vec = cat_vec.astype("object")

nw_vec = nw.from_native(cat_vec, series_only=True)
if narwhals_input:
if input_type in ["list", "pandas.Categorical"]:
pytest.skip("Narwhals doesn't support list or pandas.Categorical inputs")
cat_vec = nw.from_native(cat_vec, series_only=True)

indices, categories, namespace = _extract_codes_and_categories(nw_vec)
indices, categories = _extract_codes_and_categories(cat_vec)
np.testing.assert_array_equal(indices, np.array([0, 1, 2, -1, 1, 0, 3]))
np.testing.assert_array_equal(categories, np.array(["a", "b", "c", "d"]))
assert namespace.__name__ == input_type


@pytest.mark.parametrize(
"input_type", ["pandas_vec", "pandas", "polars", "pyarrow", "numpy"]
)
def test_cat_property(input_type):
cat_vec = pd.Categorical(["a", "b", "c", pd.NA, "b", "a", "d"])
if input_type == "pandas_vec":
cat_in = cat_vec
elif input_type == "pandas":
cat_in = pd.Series(cat_vec)
elif input_type == "polars":
cat_in = pl.Series(pd.Series(cat_vec))
elif input_type == "pyarrow":
cat_in = pyarrow.chunked_array([pd.Series(cat_vec)])
elif input_type == "numpy":
cat_in = cat_vec.to_numpy()

cat_out = CategoricalMatrix(cat_in, cat_missing_method="zero").cat

if input_type == "pandas" or input_type == "pandas_vec":
assert isinstance(cat_out, pd.Categorical)
np.testing.assert_array_equal(cat_out.codes, cat_vec.codes)
np.testing.assert_array_equal(cat_out.categories, cat_vec.categories)
elif input_type == "polars":
assert isinstance(cat_out, pl.Series)
assert isinstance(cat_out.dtype, pl.Categorical)
polars.testing.assert_series_equal(cat_in, cat_out)
elif input_type == "pyarrow":
assert isinstance(cat_out, pyarrow.ChunkedArray)
assert isinstance(cat_out.type, pyarrow.DictionaryType)
np.testing.assert_array_equal(
cat_out.cast(pyarrow.string()).to_numpy(),
cat_out.cast(pyarrow.string()).to_numpy(),
)
elif input_type == "numpy":
assert isinstance(cat_out, dict)
np.testing.assert_array_equal(cat_out["indices"], cat_vec.codes)
np.testing.assert_array_equal(
cat_out["categories"], cat_vec.categories.to_numpy()
)


def test_polars_non_contiguous_codes():
Expand All @@ -271,5 +239,5 @@ def test_polars_non_contiguous_codes():
_ = pl.Series(["beagle", "poodle", "labrador"], dtype=pl.Categorical)
cat_series = pl.Series(str_series, dtype=pl.Categorical)

indices, categories, _ = _extract_codes_and_categories(cat_series)
indices, categories = _extract_codes_and_categories(cat_series)
np.testing.assert_array_equal(str_series, categories[indices].tolist())

0 comments on commit d0528b7

Please sign in to comment.