Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use narwhals to support Polars, cuDF, Modin, etc. #388

Merged
merged 40 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
75b6505
Add dependencies to pixi.toml
stanmart Sep 2, 2024
618b583
Pixi-ize pre-commit
stanmart Sep 2, 2024
d82ca13
Add pixi tasks
stanmart Sep 2, 2024
7c4df48
Update CI
stanmart Sep 2, 2024
15dfb30
Fix build dependencies
stanmart Sep 2, 2024
35a3c2d
update lockfile
stanmart Sep 2, 2024
415ff89
Fix doctest
stanmart Sep 2, 2024
2e5ceac
Try to fix readthedocs
stanmart Sep 2, 2024
6e6c4c5
Use latest pixi on conda-forge
stanmart Sep 2, 2024
ff09902
Find some minimum versions
stanmart Sep 2, 2024
5529d20
Bump minimum formulaic version
stanmart Sep 3, 2024
1e0d892
Find minimum numpy version
stanmart Sep 3, 2024
f94f98b
Make polars a test dependency
stanmart Sep 3, 2024
4fda475
Update lockfile
stanmart Sep 3, 2024
2007569
Fix typing issues
stanmart Sep 3, 2024
bee91e1
Fix benchmarks
stanmart Sep 3, 2024
0e7fb36
Update contributing docs
stanmart Sep 3, 2024
2a23b81
Make ruff happy
stanmart Sep 3, 2024
4f7a47c
Remove unnecessary pre-commit option from CI
stanmart Sep 10, 2024
f89a57a
first try
MarcAntoineSchmidtQC Sep 11, 2024
e33db3d
Added deprecation, docstring
MarcAntoineSchmidtQC Sep 12, 2024
56b7dbe
replace from_pandas and from_polars
MarcAntoineSchmidtQC Sep 12, 2024
1f8dd90
keep sorting
MarcAntoineSchmidtQC Sep 12, 2024
e02b241
Merge remote-tracking branch 'origin/main' into narwhals
MarcAntoineSchmidtQC Sep 12, 2024
6aed762
add narwhals to conda recipe
MarcAntoineSchmidtQC Sep 12, 2024
a9ea4eb
bump minimum narwhals version
MarcAntoineSchmidtQC Sep 12, 2024
bf1d303
added narwhals to setup.py
MarcAntoineSchmidtQC Sep 12, 2024
0c9df08
Changelog
MarcAntoineSchmidtQC Sep 12, 2024
86a6ebe
Fix categoricals with non-numpy-or-pandas input
stanmart Sep 13, 2024
c6c5f6b
Fix categoricals from numpy/list input
stanmart Sep 13, 2024
42ef8ac
Remove unnecessary import
stanmart Sep 13, 2024
34c4789
Merge branch 'main' into narwhals
stanmart Sep 13, 2024
1d58498
Merge fix from #387
stanmart Sep 13, 2024
64d10cd
Bump minimum narwhals version
stanmart Sep 13, 2024
beb2ee3
Merge branch 'main' into narwhals
stanmart Sep 23, 2024
d0528b7
Update tests
stanmart Sep 23, 2024
621391a
Remove unnecessary argument
stanmart Sep 23, 2024
b41bf2b
Simplify `_extract_codes_and_categories`
stanmart Sep 23, 2024
267c321
Make the check work with the new changes
stanmart Sep 23, 2024
c53d490
Import narwhals' stable v1 API
stanmart Sep 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Changelog

**New feature:**

- Added a new function, :func:`tabmat.from_polars`, to convert a :class:`polars.DataFrame` into a :class:`tabmat.SplitMatrix`.
- Added a new function, :func:`tabmat.from_df`, to convert any dataframe supported by narwhals into a :class:`tabmat.SplitMatrix`.

**Other changes:**

Expand Down
1 change: 1 addition & 0 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ requirements:
- {{ pin_compatible('numpy') }}
- formulaic>=0.6
- scipy
- narwhals

test:
requires:
Expand Down
56 changes: 56 additions & 0 deletions pixi.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ formulaic = ">=0.6.4"
numpy = ">=1.24.0"
pandas = ">=1.4.4"
scipy = ">=1.7.3"
narwhals = ">=1.4.1"

[feature.dev.dependencies]
ipython = "*"
Expand Down Expand Up @@ -154,6 +155,7 @@ numpy = "1.24.0"
pandas = "1.4.4"
scipy = "1.7.3"
formulaic = "0.6.4"
narwhals = "1.4.1"

[environments]
default = ["dev", "test"]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
],
package_dir={"": "src"},
packages=find_packages(where="src"),
install_requires=["formulaic>=0.6", "numpy", "scipy"],
install_requires=["formulaic>=0.6", "narwhals", "numpy", "scipy"],
python_requires=">=3.9",
ext_modules=cythonize(
ext_modules,
Expand Down
4 changes: 2 additions & 2 deletions src/tabmat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.metadata

from .categorical_matrix import CategoricalMatrix
from .constructor import from_csc, from_formula, from_pandas, from_polars
from .constructor import from_csc, from_df, from_formula, from_pandas
from .dense_matrix import DenseMatrix
from .matrix_base import MatrixBase
from .sparse_matrix import SparseMatrix
Expand All @@ -23,7 +23,7 @@
"from_csc",
"from_formula",
"from_pandas",
"from_polars",
"from_df",
"as_tabmat",
"hstack",
]
84 changes: 61 additions & 23 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def matvec(mat, vec):
import warnings
from typing import Optional, Union

import narwhals.stable.v1 as nw
import numpy as np
from scipy import sparse as sps

Expand Down Expand Up @@ -195,8 +196,13 @@ def matvec(mat, vec):

if importlib.util.find_spec("pandas"):
import pandas as pd
else:
pd = None # type: ignore

if importlib.util.find_spec("polars"):
import polars as pl
else:
pl = None # type: ignore


def _is_indexer_full_length(full_length: int, indexer: Union[slice, np.ndarray]):
Expand All @@ -210,35 +216,68 @@ def _is_indexer_full_length(full_length: int, indexer: Union[slice, np.ndarray])
return len(range(*indexer.indices(full_length))) == full_length


def _is_pandas(x) -> bool:
if importlib.util.find_spec("pandas"):
return isinstance(x, (pd.Categorical, pd.CategoricalDtype))
return False


def _is_polars(x) -> bool:
    """Tell whether ``x`` is a polars series or a polars Categorical/Enum dtype.

    Returns ``False`` without touching ``pl`` when no ``polars`` module spec
    can be found.
    """
    if not importlib.util.find_spec("polars"):
        return False
    polars_types = (pl.Series, pl.Categorical, pl.Enum)
    return isinstance(x, polars_types)
def _factorize(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"A dumber version of pandas.factorize for when pandas is not available."
na_mask = (x == None) | (x != x) # noqa: E711 # The second part is for NaNs
categories, indices_nona = np.unique(x[~na_mask], return_inverse=True)
indices = np.full(x.shape, -1, dtype=np.int32)
indices[~na_mask] = indices_nona
return indices, categories


def _extract_codes_and_categories(cat_vec):
if _is_pandas(cat_vec):
categories = cat_vec.categories.to_numpy()
def _extract_codes_and_categories_pandas(cat_vec) -> tuple[np.ndarray, np.ndarray]:
if isinstance(cat_vec, pd.Categorical):
categories = cat_vec.categories
indices = cat_vec.codes
elif _is_pandas(cat_vec.dtype):
categories = cat_vec.cat.categories.to_numpy()
elif isinstance(cat_vec.dtype, pd.CategoricalDtype):
categories = cat_vec.cat.categories
indices = cat_vec.cat.codes.to_numpy()
elif _is_polars(cat_vec):
if not _is_polars(cat_vec.dtype):
cat_vec = cat_vec.cast(pl.Categorical)
categories = cat_vec.cat.to_local().cat.get_categories().to_numpy()
indices = cat_vec.cat.to_local().to_physical().fill_null(-1).to_numpy()
else:
indices, categories = pd.factorize(cat_vec, sort=True)
return indices, categories.to_numpy()


def _extract_codes_and_categories_polars(cat_vec) -> tuple[np.ndarray, np.ndarray]:
    """Extract integer codes and categories from a polars series.

    A series whose dtype is neither ``pl.Categorical`` nor ``pl.Enum`` is cast
    to ``pl.Categorical`` first.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(indices, categories)``, both converted to numpy. Nulls in the
        physical representation are filled with ``-1``.
    """
    if not isinstance(cat_vec.dtype, (pl.Categorical, pl.Enum)):
        cat_vec = cat_vec.cast(pl.Categorical)
    # Convert to a local categorical once and reuse it for both outputs
    # instead of repeating the conversion for categories and for codes.
    local_vec = cat_vec.cat.to_local()
    categories = local_vec.cat.get_categories().to_numpy()
    indices = local_vec.to_physical().fill_null(-1).to_numpy()
    return indices, categories


def _extract_codes_and_categories_numpy(cat_vec) -> tuple[np.ndarray, np.ndarray]:
if pd:
indices, categories = pd.factorize(cat_vec, sort=True)
else:
indices, categories = _factorize(cat_vec)
return indices, categories


def _extract_codes_and_categories(cat_vec) -> tuple[np.ndarray, np.ndarray]:
    """
    Extract codes and categories from a series or vector.

    The input can be any series supported by narwhals, or an object that
    can be converted to a numpy array. Pandas and polars inputs are special
    cased for performance considerations.

    Dispatch order: native pandas ``Series``/``Categorical``, then native
    polars ``Series``, then the generic numpy path. On the generic path a
    narwhals series is cast to ``nw.String`` before conversion to numpy;
    anything else goes through ``np.asarray``.
    """
    native_vec = nw.to_native(cat_vec, strict=False)

    if pd and isinstance(native_vec, (pd.Series, pd.Categorical)):
        return _extract_codes_and_categories_pandas(native_vec)

    if pl and isinstance(native_vec, pl.Series):
        return _extract_codes_and_categories_polars(native_vec)

    # Neither pandas nor polars: try to view the input as a narwhals series.
    maybe_series = nw.from_native(cat_vec, series_only=True, strict=False)
    if isinstance(maybe_series, nw.Series):
        values = maybe_series.cast(nw.String).to_numpy()
    else:
        values = np.asarray(cat_vec)
    return _extract_codes_and_categories_numpy(values)


def _row_col_indexing(
arr: Union[np.ndarray, sps.spmatrix],
rows: Optional[np.ndarray],
Expand Down Expand Up @@ -315,7 +354,6 @@ def __init__(
if not hasattr(cat_vec, "dtype"):
cat_vec = np.asarray(cat_vec) # avoid errors in pd.factorize

self._input_dtype = cat_vec.dtype
self._missing_method = cat_missing_method
self._missing_category = cat_missing_name

Expand Down Expand Up @@ -388,9 +426,9 @@ def cat(self):
"This property will be removed in the next major release.",
category=DeprecationWarning,
)
try:
if pd:
return pd.Categorical.from_codes(self.indices, categories=self.categories)
except NameError:
else:
raise ModuleNotFoundError(
"The `cat` property is provided for backward compatibility and "
"requires pandas to be installed."
Expand Down
Loading