diff --git a/skglm/penalties/base.py b/skglm/penalties/base.py
index b45254b71..f9b68bb1f 100644
--- a/skglm/penalties/base.py
+++ b/skglm/penalties/base.py
@@ -1,3 +1,4 @@
+from numba import float64
 from abc import abstractmethod
 
 
@@ -60,3 +61,73 @@ def is_penalized(self, n_features):
     @abstractmethod
     def generalized_support(self, w):
         r"""Return a mask which is True for coefficients in the generalized support."""
+
+
+def overload_with_l2(klass):
+    """Decorate a penalty class to add L2 regularization.
+
+    The resulting penalty reads
+
+    .. math::
+
+        "penalty"(w) + "l2"_"regularization" xx ||w||**2 / 2
+
+    Parameters
+    ----------
+    klass : Penalty class
+        The penalty class to be overloaded with L2 regularization.
+
+    Returns
+    -------
+    klass : Penalty class
+        Penalty overloaded with L2 regularization.
+    """
+    # keep ref to original methods
+    cls_constructor = klass.__init__
+    cls_prox_1d = klass.prox_1d
+    cls_value = klass.value
+    cls_subdiff_distance = klass.subdiff_distance
+    cls_params_to_dict = klass.params_to_dict
+    cls_get_spec = klass.get_spec
+
+    # implement new methods
+    def __init__(self, *args, l2_regularization=0., **kwargs):
+        cls_constructor(self, *args, **kwargs)
+        self.l2_regularization = l2_regularization
+
+    def prox_1d(self, value, stepsize, j):
+        if self.l2_regularization == 0.:
+            return cls_prox_1d(self, value, stepsize, j)
+
+        scale = 1 + stepsize * self.l2_regularization
+        return cls_prox_1d(self, value / scale, stepsize / scale, j)
+
+    def value(self, w):
+        l2_regularization = self.l2_regularization
+        if l2_regularization == 0.:
+            return cls_value(self, w)
+
+        return cls_value(self, w) + l2_regularization * 0.5 * w ** 2
+
+    def subdiff_distance(self, w, grad, ws):
+        if self.l2_regularization == 0.:
+            return cls_subdiff_distance(self, w, grad, ws)
+
+        return cls_subdiff_distance(self, w, grad + self.l2_regularization * w[ws], ws)
+
+    def get_spec(self):
+        return (('l2_regularization', float64), *cls_get_spec(self))
+
+    def params_to_dict(self):
+        return dict(l2_regularization=self.l2_regularization,
+                    **cls_params_to_dict(self))
+
+    # override methods
+    klass.__init__ = __init__
+    klass.value = value
+    klass.prox_1d = prox_1d
+    klass.subdiff_distance = subdiff_distance
+    klass.get_spec = get_spec
+    klass.params_to_dict = params_to_dict
+
+    return klass
diff --git a/skglm/penalties/separable.py b/skglm/penalties/separable.py
index 2c1429a87..4901dbd98 100644
--- a/skglm/penalties/separable.py
+++ b/skglm/penalties/separable.py
@@ -2,7 +2,7 @@
 from numba import float64
 from numba.types import bool_
 
-from skglm.penalties.base import BasePenalty
+from skglm.penalties.base import BasePenalty, overload_with_l2
 from skglm.utils.prox_funcs import (
     ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
 
@@ -67,6 +67,12 @@ def alpha_max(self, gradient0):
         return np.max(np.abs(gradient0))
 
 
+# To add support of L2 regularization, one needs to decorate the penalty
+@overload_with_l2
+class _TestL1(L1):
+    pass
+
+
 class L1_plus_L2(BasePenalty):
     """:math:`ell_1 + ell_2` penalty (aka ElasticNet penalty)."""
 
diff --git a/skglm/tests/test_penalties.py b/skglm/tests/test_penalties.py
index cafeb9d03..3a5b2dcf4 100644
--- a/skglm/tests/test_penalties.py
+++ b/skglm/tests/test_penalties.py
@@ -14,6 +14,9 @@
 from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA
 from skglm.utils.data import make_correlated_data
 
+from skglm.penalties.separable import _TestL1
+from skglm.utils.jit_compilation import compiled_clone
+
 
 n_samples = 20
 n_features = 10
@@ -118,5 +121,40 @@ def test_nnls(fit_intercept):
     np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_)
 
 
+def test_overload_with_l2_ElasticNet():
+    lmbd = 0.2
+    l1_ratio = 0.7
+
+    elastic_net = L1_plus_L2(lmbd, l1_ratio)
+    implicit_elastic_net = _TestL1(alpha=lmbd * l1_ratio,
+                                   l2_regularization=lmbd * (1 - l1_ratio))
+
+    n_feautures, ws_size = 5, 3
+    stepsize = 0.8
+
+    rng = np.random.RandomState(425)
+    w = rng.randn(n_feautures)
+    grad = rng.randn(ws_size)
+    ws = rng.choice(n_feautures, size=ws_size, replace=False)
+
+    x = w[2]
+    np.testing.assert_equal(
+        elastic_net.value(x),
+        implicit_elastic_net.value(x)
+    )
+    np.testing.assert_equal(
+        elastic_net.prox_1d(x, stepsize, 0),
+        implicit_elastic_net.prox_1d(x, stepsize, 0)
+    )
+    np.testing.assert_array_equal(
+        elastic_net.subdiff_distance(w, grad, ws),
+        implicit_elastic_net.subdiff_distance(w, grad, ws)
+    )
+
+    # This will raise an error as *args and **kwargs are not supported in numba
+    with pytest.raises(Exception, match=r"VAR_POSITIONAL.*unsupported.*jitclass"):
+        compiled_clone(implicit_elastic_net)
+
+
 if __name__ == "__main__":
     pass