Skip to content

Commit 4daf7a6

Browse files
committed
Added tests
1 parent 9646aae commit 4daf7a6

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

tests/test_decomposition.py

+38
Original file line numberDiff line numberDiff line change
@@ -1958,3 +1958,41 @@ def test_regs_list_is_not_modified(random_ragged_cmf, regs):
19581958
)
19591959

19601960
assert regs == regs_unmodified
1961+
1962+
1963+
1964+
@pytest.mark.parametrize("constant_feasibility_penalty", ["", "AB", "C"])
1965+
def test_constant_feasibility_penalty_fails_with_invalid(random_ragged_cmf, constant_feasibility_penalty):
1966+
1967+
cmf, shapes, rank = random_ragged_cmf
1968+
matrices = cmf.to_matrices()
1969+
1970+
# Check that we get correct output when none of the conditions are met
1971+
with pytest.raises(ValueError):
1972+
decomposition.cmf_aoadmm(
1973+
matrices,
1974+
rank,
1975+
n_iter_max=0,
1976+
return_errors=True,
1977+
verbose=False,
1978+
non_negative=True,
1979+
parafac2=True,
1980+
constant_feasibility_penalty=constant_feasibility_penalty
1981+
)
1982+
1983+
1984+
@pytest.mark.parametrize("constant_feasibility_penalty", ["A", "B", True, False, None])
1985+
def test_constant_feasibility_penalty_works_with_valid(random_ragged_cmf, constant_feasibility_penalty):
1986+
cmf, shapes, rank = random_ragged_cmf
1987+
matrices = cmf.to_matrices()
1988+
1989+
decomposition.cmf_aoadmm(
1990+
matrices,
1991+
rank,
1992+
n_iter_max=0,
1993+
return_errors=True,
1994+
verbose=False,
1995+
non_negative=True,
1996+
parafac2=True,
1997+
constant_feasibility_penalty=constant_feasibility_penalty
1998+
)

tests/test_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# MIT License: Copyright (c) 2022, Marie Roald.
22
# See the LICENSE file in the root directory for full license text.
3+
from unittest.mock import patch
34

45
import numpy as np
56
import pytest
@@ -62,6 +63,14 @@ def test_get_svd(rng, svd):
6263
assert allclose(U1TU2 * Vh1Vh2T, np.eye(U1TU2.shape[0]))
6364

6465

66+
def test_get_svd_works_with_old_tensorly():
67+
svds = {"TEST": "SVD"}
68+
with patch("matcouply._utils.tl.SVD_FUNS", svds) as mock:
69+
svd = utils.get_svd("TEST")
70+
71+
assert svd == svds["TEST"]
72+
73+
6574
def test_get_svd_fails_with_invalid_svd_name():
6675
with pytest.raises(ValueError):
6776
utils.get_svd("THIS_IS_NOT_A_VALID_SVD")

0 commit comments

Comments
 (0)