Skip to content

Commit

Permalink
Fix bug in SplitMatrix matvec (#154)
Browse files Browse the repository at this point in the history
* Add failing test.

* Deal with the case where there is no dense matrix.

* Update tests/test_split_matrix.py

Co-authored-by: Marc-Antoine Schmidt <[email protected]>
  • Loading branch information
jtilly and MarcAntoineSchmidtQC authored Nov 12, 2021
1 parent 4de2faa commit 57d8f59
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
Changelog
=========

3.0.6 - 2021-11-11
------------------

**Bug fix**

- We fixed a bug in :meth:`tabmat.SplitMatrix.matvec`, where incorrect matrix-vector products were computed when a ``SplitMatrix`` did not contain any dense components.


3.0.5 - 2021-11-05
------------------

Expand Down Expand Up @@ -184,7 +192,7 @@ We are trying to make releases for Windows.
- Fix a bug in ``matvec`` for categorical components when the number of categories exceeds the number of rows.


0.0.6 - 2020-08-03
------------------

See git history.
3 changes: 2 additions & 1 deletion src/tabmat/split_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,15 @@ def matvec(
# as the target for storing the final output. This reduces the number
# of output arrays allocated from 2 to 1.
is_matrix_dense = [isinstance(m, DenseMatrix) for m in self.matrices]
dense_matrix_idx = np.argmax(is_matrix_dense)
if np.any(is_matrix_dense):
dense_matrix_idx = np.argmax(is_matrix_dense)
sub_cols = subset_cols[dense_matrix_idx]
idx = self.indices[dense_matrix_idx]
mat = self.matrices[dense_matrix_idx]
in_vec = v[idx, ...]
out = np.asarray(mat.matvec(in_vec, sub_cols, out), dtype=out_dtype)
else:
dense_matrix_idx = -1
out = _prepare_out_array(out, out_shape, out_dtype)

for i, (sub_cols, idx, mat) in enumerate(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_split_matrix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import pytest
import scipy.sparse as sps

import tabmat as tm
from tabmat import from_pandas
from tabmat.constructor import _split_sparse_and_dense_parts
from tabmat.dense_matrix import DenseMatrix
from tabmat.ext.sparse import csr_dense_sandwich
Expand Down Expand Up @@ -237,3 +239,15 @@ def test_init_from_1d():

res = SplitMatrix([m1, m2])
assert res.shape == (10, 3)


@pytest.mark.parametrize("n_rows", [5, 10, 25])
def test_matvec(n_rows):
    """Check ``matvec`` on a matrix built purely from categorical columns.

    A frame of ``n_rows`` x 2 randomly drawn category labels is converted
    with ``from_pandas(..., cat_threshold=0)`` and multiplied by a vector
    of ones whose length equals the matrix's column count. Every entry of
    the product is asserted to equal the number of categorical columns.

    NOTE(review): presumably each categorical column contributes exactly
    one active indicator per row, which is why the expected value is the
    column count -- confirm against the ``tabmat`` categorical encoding.
    """
    # Fixed seed so the drawn categories are reproducible across runs.
    np.random.seed(1234)
    n_features = 2
    levels = [f"cat_{k}" for k in range(5)]

    draws = np.random.choice(levels, size=(n_rows, n_features))
    frame = pd.DataFrame(draws).astype("category")
    design = from_pandas(frame, cat_threshold=0)

    # One coefficient of 1 per column of the resulting matrix.
    ones = np.array([1] * design.shape[1])
    product = design.matvec(ones)

    np.testing.assert_allclose(product, n_features)

0 comments on commit 57d8f59

Please sign in to comment.