Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create CatMatrix from codes and categories #389

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Changelog

- Added a new function, :func:`tabmat.from_polars`, to convert a :class:`polars.DataFrame` into a :class:`tabmat.SplitMatrix`.

**Other changes:**

- Allow :class:`CategoricalMatrix` to be initialized directly with indices and categories.

4.0.1 - 2024-06-25
------------------

Expand Down
44 changes: 22 additions & 22 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 31 additions & 24 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,6 @@ def matvec(mat, vec):
import polars as pl


class _Categorical:
"""This class helps us avoid copies while subsetting."""

def __init__(self, indices, categories, dtype):
self.indices = indices
self.categories = categories
self.dtype = dtype


def _is_indexer_full_length(full_length: int, indexer: Union[slice, np.ndarray]):
if isinstance(indexer, np.ndarray):
if (indexer > full_length - 1).any():
Expand All @@ -232,10 +223,7 @@ def _is_polars(x) -> bool:


def _extract_codes_and_categories(cat_vec):
if isinstance(cat_vec, _Categorical):
categories = cat_vec.categories
indices = cat_vec.indices
elif _is_pandas(cat_vec):
if _is_pandas(cat_vec):
categories = cat_vec.categories.to_numpy()
indices = cat_vec.codes
elif _is_pandas(cat_vec.dtype):
Expand Down Expand Up @@ -284,6 +272,9 @@ class CategoricalMatrix(MatrixBase):
cat_vec:
array-like vector of categorical data.

categories: np.ndarray, default None
If provided, cat_vec is assumed to be an array-like vector of indices.

drop_first:
drop the first level of the dummy encoding. This allows a CategoricalMatrix
to be used in an unregularized setting.
Expand All @@ -306,6 +297,7 @@ class CategoricalMatrix(MatrixBase):
def __init__(
self,
cat_vec,
categories: Optional[np.ndarray] = None,
drop_first: bool = False,
dtype: np.dtype = np.float64,
column_name: Optional[str] = None,
Expand All @@ -321,13 +313,21 @@ def __init__(
)

if not hasattr(cat_vec, "dtype"):
cat_vec = np.array(cat_vec) # avoid errors in pd.factorize
cat_vec = np.asarray(cat_vec) # avoid errors in pd.factorize

self._input_dtype = cat_vec.dtype
self._missing_method = cat_missing_method
self._missing_category = cat_missing_name

indices, self.categories = _extract_codes_and_categories(cat_vec)
if categories is not None:
self.categories = categories
indices = np.nan_to_num(cat_vec, nan=-1)
if max(indices) >= len(categories):
raise ValueError("Indices exceed length of categories.")
if min(indices) < -1:
raise ValueError("Indices must be non-negative (or -1 for missing).")
else:
indices, self.categories = _extract_codes_and_categories(cat_vec)

if np.any(indices == -1):
if self._missing_method == "fail":
Expand Down Expand Up @@ -357,7 +357,13 @@ def __init__(
self._has_missings = False

self.drop_first = drop_first
self.indices = indices.astype(np.int32, copy=False)
try:
self.indices = indices.astype(np.int32, copy=False)
except ValueError:
raise ValueError(
"When creating a CategoricalMatrix with indices and categories, "
"indices must be castable to a numpy int32 dtype."
)
self.shape = (len(self.indices), len(self.categories) - int(drop_first))
self.x_csc = None
self.dtype = np.dtype(dtype)
Expand All @@ -382,13 +388,13 @@ def cat(self):
"This property will be removed in the next major release.",
category=DeprecationWarning,
)

if _is_polars(self._input_dtype):
out = self.categories[self.indices].astype("object", copy=False)
out = np.where(self.indices < 0, None, out)
return pl.Series(out, dtype=pl.Enum(self.categories))

return pd.Categorical.from_codes(self.indices, categories=self.categories)
try:
return pd.Categorical.from_codes(self.indices, categories=self.categories)
except NameError:
raise ModuleNotFoundError(
"The `cat` property is provided for backward compatibility and "
"requires pandas to be installed."
)

def recover_orig(self) -> np.ndarray:
"""
Expand Down Expand Up @@ -681,7 +687,8 @@ def __getitem__(self, item):
if isinstance(row, np.ndarray):
row = row.ravel()
return CategoricalMatrix(
_Categorical(self.indices[row], self.categories, self._input_dtype),
self.indices[row],
categories=self.categories,
drop_first=self.drop_first,
dtype=self.dtype,
column_name=self._colname,
Expand Down