Skip to content

Commit

Permalink
Check sandwich dimensions (#393)
Browse files Browse the repository at this point in the history
* Add additional compatibility checks for matvec and sandwich

* Don't be strict about matvec dtypes

* Revert test

* Add changelog entry
  • Loading branch information
stanmart authored Sep 26, 2024
1 parent 1e9579a commit 94ad89e
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Changelog
**Other changes:**

- Allow :class:`CategoricalMatrix` to be initialized directly with indices and categories.
- Added checks for dimension and ``dtype`` mismatch in :meth:`MatrixBasesandwich.sandwich`.

**Bug fix:**

Expand Down
2 changes: 2 additions & 0 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def matvec(mat, vec):
_check_indexer,
check_matvec_dimensions,
check_matvec_out_shape,
check_sandwich_compatible,
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
setup_restrictions,
Expand Down Expand Up @@ -584,6 +585,7 @@ def sandwich(
matrix without making a copy.
"""
d = np.asarray(d)
check_sandwich_compatible(self, d)
rows = set_up_rows_or_cols(rows, self.shape[0])
if self.drop_first or self._has_missings:
res_diag = sandwich_categorical_complex(
Expand Down
2 changes: 2 additions & 0 deletions src/tabmat/dense_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_check_indexer,
check_matvec_dimensions,
check_matvec_out_shape,
check_sandwich_compatible,
check_transpose_matvec_out_shape,
setup_restrictions,
)
Expand Down Expand Up @@ -143,6 +144,7 @@ def sandwich(
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
d = np.asarray(d)
check_sandwich_compatible(self, d)
rows, cols = setup_restrictions(self.shape, rows, cols)
return dense_sandwich(self._array, d, rows, cols)

Expand Down
9 changes: 2 additions & 7 deletions src/tabmat/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_check_indexer,
check_matvec_dimensions,
check_matvec_out_shape,
check_sandwich_compatible,
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
setup_restrictions,
Expand Down Expand Up @@ -179,13 +180,7 @@ def sandwich(
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
d = np.asarray(d)
if not self.dtype == d.dtype:
raise TypeError(
f"""self and d need to be of same dtype, either np.float64
or np.float32. self is of type {self.dtype}, while d is of type
{d.dtype}."""
)

check_sandwich_compatible(self, d)
rows, cols = setup_restrictions(self.shape, rows, cols, dtype=self.idx_dtype)
return sparse_sandwich(self, self.array_csr, d, rows, cols)

Expand Down
12 changes: 8 additions & 4 deletions src/tabmat/split_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from .sparse_matrix import SparseMatrix
from .standardized_mat import StandardizedMatrix
from .util import (
check_matvec_dimensions,
check_matvec_out_shape,
check_sandwich_compatible,
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
)
Expand Down Expand Up @@ -326,9 +328,8 @@ def sandwich(
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
if np.shape(d) != (self.shape[0],):
raise ValueError
d = np.asarray(d)
check_sandwich_compatible(self, d)

subset_cols_indices, subset_cols, n_cols = self._split_col_subsets(cols)

Expand Down Expand Up @@ -377,9 +378,10 @@ def matvec(
) -> np.ndarray:
"""Perform self[:, cols] @ other[cols]."""
assert not isinstance(v, sps.spmatrix)
v = np.asarray(v)
check_matvec_dimensions(self, v, transpose=False)
check_matvec_out_shape(self, out)

v = np.asarray(v)
if v.shape[0] != self.shape[1]:
raise ValueError(f"shapes {self.shape} and {v.shape} not aligned")

Expand Down Expand Up @@ -435,9 +437,11 @@ def transpose_matvec(
= sum_{j in rows} sum_{mat in self.matrices} 1(cols[i] in mat)
self[j, cols[i]] v[j]
"""
check_transpose_matvec_out_shape(self, out)

v = np.asarray(v)
check_matvec_dimensions(self, v, transpose=True)
check_transpose_matvec_out_shape(self, out)

subset_cols_indices, subset_cols, n_cols = self._split_col_subsets(cols)

out_shape = [n_cols] + list(v.shape[1:])
Expand Down
8 changes: 2 additions & 6 deletions src/tabmat/standardized_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .sparse_matrix import SparseMatrix
from .util import (
check_matvec_dimensions,
check_sandwich_compatible,
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
setup_restrictions,
Expand Down Expand Up @@ -128,12 +129,7 @@ def sandwich(
"""Perform a sandwich product: X.T @ diag(d) @ X."""
if not hasattr(d, "dtype"):
d = np.asarray(d)
if not self.mat.dtype == d.dtype:
raise TypeError(
f"""self.mat and d need to be of same dtype, either
np.float64 or np.float32. This matrix is of type {self.mat.dtype},
while d is of type {d.dtype}."""
)
check_sandwich_compatible(self, d)

if rows is not None or cols is not None:
setup_rows, setup_cols = setup_restrictions(self.shape, rows, cols)
Expand Down
15 changes: 15 additions & 0 deletions src/tabmat/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ def check_matvec_dimensions(mat, vec: np.ndarray, transpose: bool) -> None:
)


def check_sandwich_compatible(mat, d: np.ndarray):
"""Assert that the dimensions and dtypes for the sandwich product are compatible."""
if mat.shape[0] != d.shape[0]:
raise ValueError(
f"shapes {mat.shape} and {d.shape} not aligned: "
f"{mat.shape[0]} (dim 0) != {d.shape[0]} (dim 0)"
)
if not mat.dtype == d.dtype:
raise TypeError(
f"""self and d need to be of same dtype, either np.float64
or np.float32. self is of type {mat.dtype}, while d is of type
{d.dtype}."""
)


def _check_indexer(indexer):
"""Check that the indexer is valid, and transform it to a canonical format."""
if not isinstance(indexer, tuple):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,28 @@ def test_matvec_dimension_mismatch_raises(mat, rows, cols):
mat.transpose_matvec(too_long_transpose, rows=rows, cols=cols)


@pytest.mark.parametrize("mat", get_matrices())
@pytest.mark.parametrize("cols", [None, [], [1], np.array([0, 1])])
@pytest.mark.parametrize("rows", [None, [], [1], np.array([0, 2])])
def test_sandwich_dimension_mismatch_raises(mat, rows, cols):
too_short = np.ones(mat.shape[0] - 1, dtype=mat.dtype)
just_right = np.ones(mat.shape[0], dtype=mat.dtype)
too_long = np.ones(mat.shape[0] + 1, dtype=mat.dtype)
mat.sandwich(just_right, cols=cols)
with pytest.raises(ValueError, match="not aligned"):
mat.sandwich(too_short, cols=cols)
with pytest.raises(ValueError, match="not aligned"):
mat.sandwich(too_long, cols=cols)


@pytest.mark.parametrize("mat", get_matrices())
def test_sandwich_dtype_mismatch_raises(mat):
with pytest.raises(TypeError, match="same dtype"):
mat.astype(np.float64).sandwich(np.ones(mat.shape[0], dtype=np.float32))
with pytest.raises(TypeError, match="same dtype"):
mat.astype(np.float32).sandwich(np.ones(mat.shape[0], dtype=np.float64))


@pytest.mark.parametrize("mat", get_matrices())
@pytest.mark.parametrize("i", [1, -2])
def test_getcol(mat: Union[tm.MatrixBase, tm.StandardizedMatrix], i):
Expand Down

0 comments on commit 94ad89e

Please sign in to comment.