Skip to content

Commit 2066f43

Browse files
authored
Treat boolean columns as numeric (#380)
1 parent 24bd564 commit 2066f43

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
Changelog
88
=========
99

10-
4.0.1 - 2024-06-25
10+
4.1.0 - unreleased
1111
------------------
1212

1313
**New feature:**
1414

1515
- Added a new function, :func:`tabmat.from_polars`, to convert a :class:`polars.DataFrame` into a :class:`tabmat.SplitMatrix`.
1616

17+
4.0.1 - 2024-06-25
18+
------------------
19+
1720
**Other changes:**
1821

1922
- Removed reference to the ``.A`` attribute and replaced it with ``.toarray()``.

src/tabmat/constructor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@
2828
pd = None
2929

3030

31+
def _is_boolean(series, engine: str):
32+
if engine == "pandas":
33+
return pd.api.types.is_bool_dtype(series)
34+
elif engine == "polars":
35+
return series.dtype.is_(pl.Boolean)
36+
else:
37+
raise ValueError(f"Unknown engine: {engine}")
38+
39+
3140
def _is_numeric(series, engine: str):
3241
if engine == "pandas":
3342
return pd.api.types.is_numeric_dtype(series)
@@ -154,6 +163,15 @@ def _from_dataframe(
154163
mxcolidx += cat.shape[1]
155164
elif cat_position == "end":
156165
indices.append(np.arange(cat.shape[1]))
166+
elif _is_boolean(coldata, engine):
167+
if (coldata != False).mean() <= sparse_threshold: # noqa E712
168+
sparse_dfidx.append(dfcolidx)
169+
sparse_tmidx.append(mxcolidx)
170+
mxcolidx += 1
171+
else:
172+
dense_dfidx.append(dfcolidx)
173+
dense_tmidx.append(mxcolidx)
174+
mxcolidx += 1
157175
elif _is_numeric(coldata, engine):
158176
if (coldata != 0).mean() <= sparse_threshold:
159177
sparse_dfidx.append(dfcolidx)

0 commit comments

Comments
 (0)