22
22
23
23
24
24
def normalize (X ):
25
- ssX = tl .sum (X ** 2 , 0 )
25
+ ssX = tl .sum (X ** 2 , 0 )
26
26
ssX = tl .reshape (ssX , (1 , * tl .shape (ssX )))
27
27
return X / tl .sqrt (ssX )
28
28
@@ -74,8 +74,7 @@ def test_initialize_cmf_invalid_init(rng):
74
74
75
75
76
76
@pytest .mark .parametrize (
77
- "rank" ,
78
- [1 , 2 , 5 ],
77
+ "rank" , [1 , 2 , 5 ],
79
78
)
80
79
def test_initialize_aux (rng , rank ):
81
80
shapes = ((5 , 10 ), (10 , 10 ), (15 , 10 ))
@@ -98,8 +97,7 @@ def test_initialize_aux(rng, rank):
98
97
99
98
100
99
@pytest .mark .parametrize (
101
- "rank" ,
102
- [1 , 2 , 5 ],
100
+ "rank" , [1 , 2 , 5 ],
103
101
)
104
102
def test_initialize_dual (rng , rank ):
105
103
shapes = ((5 , 10 ), (10 , 10 ), (15 , 10 ))
@@ -128,7 +126,7 @@ def test_cmf_reconstruction_error(rng, random_ragged_cmf):
128
126
# Add random noise
129
127
noise = [tl .tensor (rng .standard_normal (size = shape )) for shape in shapes ]
130
128
noisy_matrices = [matrix + n for matrix , n in zip (matrices , noise )]
131
- noise_norm = tl .sqrt (sum (tl .sum (n ** 2 ) for n in noise ))
129
+ noise_norm = tl .sqrt (sum (tl .sum (n ** 2 ) for n in noise ))
132
130
133
131
# Check that the error is equal to the noise magnitude
134
132
error = decomposition ._cmf_reconstruction_error (noisy_matrices , cmf )
@@ -1457,15 +1455,11 @@ def test_cmf_aoadmm(rng, random_ragged_cmf):
1457
1455
1458
1456
# Construct matrices and compute their norm
1459
1457
matrices = nn_cmf .to_matrices ()
1460
- norm_matrices = tl .sqrt (sum (tl .sum (matrix ** 2 ) for matrix in matrices ))
1458
+ norm_matrices = tl .sqrt (sum (tl .sum (matrix ** 2 ) for matrix in matrices ))
1461
1459
1462
1460
# Decompose matrices with cmf_aoadmm with no constraints
1463
1461
out_cmf , (aux , dual ), diagnostics = decomposition .cmf_aoadmm (
1464
- matrices ,
1465
- rank ,
1466
- n_iter_max = 5_000 ,
1467
- return_errors = True ,
1468
- return_admm_vars = True ,
1462
+ matrices , rank , n_iter_max = 5_000 , return_errors = True , return_admm_vars = True ,
1469
1463
)
1470
1464
1471
1465
# Check that reconstruction error is low
@@ -1615,12 +1609,7 @@ def test_parafac2_makes_nn_cmf_unique(rng):
1615
1609
regularized_loss = [float ("inf" )]
1616
1610
for init in range (5 ):
1617
1611
out , diagnostics = decomposition .cmf_aoadmm (
1618
- matrices ,
1619
- rank ,
1620
- n_iter_max = 1_000 ,
1621
- return_errors = True ,
1622
- non_negative = [True , True , True ],
1623
- parafac2 = True ,
1612
+ matrices , rank , n_iter_max = 1_000 , return_errors = True , non_negative = [True , True , True ], parafac2 = True ,
1624
1613
)
1625
1614
1626
1615
if diagnostics .regularized_loss [- 1 ] < regularized_loss [- 1 ] and diagnostics .satisfied_feasibility_condition :
@@ -1650,11 +1639,7 @@ def test_cmf_aoadmm_not_updating_A_works(rng, random_rank5_ragged_cmf):
1650
1639
1651
1640
# Decompose matrices with cmf_aoadmm with no constraints
1652
1641
out_cmf = decomposition .cmf_aoadmm (
1653
- matrices ,
1654
- rank ,
1655
- n_iter_max = 5 ,
1656
- update_A = False ,
1657
- init = (None , (wrong_A_copy , B_is_copy , C_copy )),
1642
+ matrices , rank , n_iter_max = 5 , update_A = False , init = (None , (wrong_A_copy , B_is_copy , C_copy )),
1658
1643
)
1659
1644
1660
1645
out_weights , (out_A , out_B_is , out_C ) = out_cmf
@@ -1677,11 +1662,7 @@ def test_cmf_aoadmm_not_updating_C_works(rng, random_rank5_ragged_cmf):
1677
1662
1678
1663
# Decompose matrices with cmf_aoadmm with no constraints
1679
1664
out_cmf = decomposition .cmf_aoadmm (
1680
- matrices ,
1681
- rank ,
1682
- n_iter_max = 5 ,
1683
- update_C = False ,
1684
- init = (None , (A_copy , B_is_copy , wrong_C_copy )),
1665
+ matrices , rank , n_iter_max = 5 , update_C = False , init = (None , (A_copy , B_is_copy , wrong_C_copy )),
1685
1666
)
1686
1667
1687
1668
out_weights , (out_A , out_B_is , out_C ) = out_cmf
@@ -1703,11 +1684,7 @@ def test_cmf_aoadmm_not_updating_B_is_works(rng, random_rank5_ragged_cmf):
1703
1684
1704
1685
# Decompose matrices with cmf_aoadmm with no constraints
1705
1686
out_cmf = decomposition .cmf_aoadmm (
1706
- matrices ,
1707
- rank ,
1708
- n_iter_max = 5 ,
1709
- update_B_is = False ,
1710
- init = (None , (A_copy , wrong_B_is_copy , C_copy )),
1687
+ matrices , rank , n_iter_max = 5 , update_B_is = False , init = (None , (A_copy , wrong_B_is_copy , C_copy )),
1711
1688
)
1712
1689
1713
1690
out_weights , (out_A , out_B_is , out_C ) = out_cmf
@@ -1732,9 +1709,9 @@ def test_compute_l2_penalty(rng, random_ragged_cmf):
1732
1709
cmf , shapes , rank = random_ragged_cmf
1733
1710
weights , (A , B_is , C ) = cmf
1734
1711
1735
- SS_A = tl .sum (A ** 2 )
1736
- SS_B = sum (tl .sum (B_i ** 2 ) for B_i in B_is )
1737
- SS_C = tl .sum (C ** 2 )
1712
+ SS_A = tl .sum (A ** 2 )
1713
+ SS_B = sum (tl .sum (B_i ** 2 ) for B_i in B_is )
1714
+ SS_C = tl .sum (C ** 2 )
1738
1715
1739
1716
assert decomposition ._compute_l2_penalty (cmf , [0 , 0 , 0 ]) == 0
1740
1717
assert decomposition ._compute_l2_penalty (cmf , [1 , 0 , 0 ]) == pytest .approx (0.5 * SS_A )
@@ -1750,46 +1727,32 @@ def test_l2_penalty_is_included(rng, random_ragged_cmf):
1750
1727
1751
1728
# Decompose matrices with cmf_aoadmm with no constraints
1752
1729
out_cmf , diagnostics = decomposition .cmf_aoadmm (
1753
- matrices ,
1754
- rank ,
1755
- n_iter_max = 5 ,
1756
- return_errors = True ,
1757
- update_B_is = False ,
1730
+ matrices , rank , n_iter_max = 5 , return_errors = True , update_B_is = False ,
1758
1731
)
1759
1732
1760
1733
rel_sse = diagnostics .rec_errors [- 1 ] ** 2
1761
1734
assert diagnostics .regularized_loss [- 1 ] == pytest .approx (0.5 * rel_sse )
1762
1735
1763
1736
out_cmf , diagnostics = decomposition .cmf_aoadmm (
1764
- matrices ,
1765
- rank ,
1766
- n_iter_max = 5 ,
1767
- l2_penalty = 1 ,
1768
- return_errors = True ,
1769
- update_B_is = False ,
1737
+ matrices , rank , n_iter_max = 5 , l2_penalty = 1 , return_errors = True , update_B_is = False ,
1770
1738
)
1771
1739
1772
1740
out_weights , (out_A , out_B_is , out_C ) = out_cmf
1773
1741
rel_sse = diagnostics .rec_errors [- 1 ] ** 2
1774
- SS_A = tl .sum (out_A ** 2 )
1775
- SS_B = sum (tl .sum (out_B_i ** 2 ) for out_B_i in out_B_is )
1776
- SS_C = tl .sum (out_C ** 2 )
1742
+ SS_A = tl .sum (out_A ** 2 )
1743
+ SS_B = sum (tl .sum (out_B_i ** 2 ) for out_B_i in out_B_is )
1744
+ SS_C = tl .sum (out_C ** 2 )
1777
1745
assert diagnostics .regularized_loss [- 1 ] == pytest .approx (0.5 * rel_sse + 0.5 * (SS_A + SS_B + SS_C ))
1778
1746
1779
1747
out_cmf , diagnostics = decomposition .cmf_aoadmm (
1780
- matrices ,
1781
- rank ,
1782
- n_iter_max = 5 ,
1783
- l2_penalty = [1 , 2 , 3 ],
1784
- return_errors = True ,
1785
- update_B_is = False ,
1748
+ matrices , rank , n_iter_max = 5 , l2_penalty = [1 , 2 , 3 ], return_errors = True , update_B_is = False ,
1786
1749
)
1787
1750
1788
1751
out_weights , (out_A , out_B_is , out_C ) = out_cmf
1789
1752
rel_sse = diagnostics .rec_errors [- 1 ] ** 2
1790
- SS_A = tl .sum (out_A ** 2 )
1791
- SS_B = sum (tl .sum (out_B_i ** 2 ) for out_B_i in out_B_is )
1792
- SS_C = tl .sum (out_C ** 2 )
1753
+ SS_A = tl .sum (out_A ** 2 )
1754
+ SS_B = sum (tl .sum (out_B_i ** 2 ) for out_B_i in out_B_is )
1755
+ SS_C = tl .sum (out_C ** 2 )
1793
1756
assert diagnostics .regularized_loss [- 1 ] == pytest .approx (0.5 * rel_sse + 0.5 * (1 * SS_A + 2 * SS_B + 3 * SS_C ))
1794
1757
1795
1758
@@ -1890,11 +1853,11 @@ def test_first_loss_value_is_correct(random_ragged_cmf):
1890
1853
l1_C = tl .sum (tl .abs (C ))
1891
1854
l1_reg_penalty = l1_A + l1_B + l1_C
1892
1855
l2_A = tl .sum ((weights * A ) ** 2 )
1893
- l2_B = sum (tl .sum (B_i ** 2 ) for B_i in B_is )
1894
- l2_C = tl .sum (C ** 2 )
1856
+ l2_B = sum (tl .sum (B_i ** 2 ) for B_i in B_is )
1857
+ l2_C = tl .sum (C ** 2 )
1895
1858
l2_reg_penalty = l2_A + l2_B + l2_C
1896
1859
rec_error = 0
1897
- initial_loss = 0.5 * rec_error ** 2 + 0.5 * 0.1 * l2_reg_penalty + 0.2 * l1_reg_penalty
1860
+ initial_loss = 0.5 * rec_error ** 2 + 0.5 * 0.1 * l2_reg_penalty + 0.2 * l1_reg_penalty
1898
1861
1899
1862
# Check that we get correct output when none of the conditions are met
1900
1863
out_cmf , diagnostics = decomposition .cmf_aoadmm (
@@ -1913,8 +1876,7 @@ def test_first_loss_value_is_correct(random_ragged_cmf):
1913
1876
1914
1877
1915
1878
@pytest .mark .parametrize (
1916
- "n_iter_max" ,
1917
- [- 1 , 0 ],
1879
+ "n_iter_max" , [- 1 , 0 ],
1918
1880
)
1919
1881
def test_cmf_aoadmm_works_with_zero_iteration (random_ragged_cmf , n_iter_max ):
1920
1882
cmf , shapes , rank = random_ragged_cmf
@@ -1960,7 +1922,6 @@ def test_regs_list_is_not_modified(random_ragged_cmf, regs):
1960
1922
assert regs == regs_unmodified
1961
1923
1962
1924
1963
-
1964
1925
@pytest .mark .parametrize ("constant_feasibility_penalty" , ["" , "AB" , "C" ])
1965
1926
def test_constant_feasibility_penalty_fails_with_invalid (random_ragged_cmf , constant_feasibility_penalty ):
1966
1927
@@ -1977,7 +1938,7 @@ def test_constant_feasibility_penalty_fails_with_invalid(random_ragged_cmf, cons
1977
1938
verbose = False ,
1978
1939
non_negative = True ,
1979
1940
parafac2 = True ,
1980
- constant_feasibility_penalty = constant_feasibility_penalty
1941
+ constant_feasibility_penalty = constant_feasibility_penalty ,
1981
1942
)
1982
1943
1983
1944
@@ -1994,5 +1955,5 @@ def test_constant_feasibility_penalty_works_with_valid(random_ragged_cmf, consta
1994
1955
verbose = False ,
1995
1956
non_negative = True ,
1996
1957
parafac2 = True ,
1997
- constant_feasibility_penalty = constant_feasibility_penalty
1958
+ constant_feasibility_penalty = constant_feasibility_penalty ,
1998
1959
)
0 commit comments