ENH - Automatic support of L2 regularization in Penalties #150


Draft · wants to merge 3 commits into main
71 changes: 71 additions & 0 deletions skglm/penalties/base.py
@@ -1,3 +1,4 @@
from numba import float64
from abc import abstractmethod


@@ -60,3 +61,73 @@ def is_penalized(self, n_features):
    @abstractmethod
    def generalized_support(self, w):
        r"""Return a mask which is True for coefficients in the generalized support."""


def overload_with_l2(klass):
    """Decorate a penalty class to add L2 regularization.

    The resulting penalty reads

    .. math::

        "penalty"(w) + "l2"_"regularization" xx ||w||**2 / 2

    Parameters
    ----------
    klass : Penalty class
        The penalty class to be overloaded with L2 regularization.

    Returns
    -------
    klass : Penalty class
        Penalty overloaded with L2 regularization.
    """
    # keep references to the original methods
    cls_constructor = klass.__init__
    cls_prox_1d = klass.prox_1d
    cls_value = klass.value
    cls_subdiff_distance = klass.subdiff_distance
    cls_params_to_dict = klass.params_to_dict
    cls_get_spec = klass.get_spec

    # implement new methods
    def __init__(self, *args, l2_regularization=0., **kwargs):
        cls_constructor(self, *args, **kwargs)
        self.l2_regularization = l2_regularization

    def prox_1d(self, value, stepsize, j):
        if self.l2_regularization == 0.:
            return cls_prox_1d(self, value, stepsize, j)

        scale = 1 + stepsize * self.l2_regularization
        return cls_prox_1d(self, value / scale, stepsize / scale, j)

    def value(self, w):
        l2_regularization = self.l2_regularization
        if l2_regularization == 0.:
            return cls_value(self, w)

        return cls_value(self, w) + 0.5 * l2_regularization * (w ** 2).sum()

    def subdiff_distance(self, w, grad, ws):
        if self.l2_regularization == 0.:
            return cls_subdiff_distance(self, w, grad, ws)

        return cls_subdiff_distance(self, w, grad + self.l2_regularization * w[ws], ws)

    def get_spec(self):
        return (('l2_regularization', float64), *cls_get_spec(self))

    def params_to_dict(self):
        return dict(l2_regularization=self.l2_regularization,
                    **cls_params_to_dict(self))

    # override methods
    klass.__init__ = __init__
    klass.value = value
    klass.prox_1d = prox_1d
    klass.subdiff_distance = subdiff_distance
    klass.get_spec = get_spec
    klass.params_to_dict = params_to_dict

    return klass
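
For the record, the rescaling in prox_1d above follows from completing the square. A sketch, writing g for the wrapped penalty, t for the stepsize and \lambda_2 for l2_regularization:

.. math::

    \mathrm{prox}_{t (g + \frac{\lambda_2}{2} \|\cdot\|^2)}(x)
    = \operatorname*{argmin}_w \frac{(w - x)^2}{2t} + g(w) + \frac{\lambda_2}{2} w^2
    = \operatorname*{argmin}_w \frac{(w - x/s)^2}{2 (t/s)} + g(w)
    = \mathrm{prox}_{(t/s) g}\left(\frac{x}{s}\right),
    \quad s = 1 + t \lambda_2 .

And a minimal usage sketch (the class name MyL1 is hypothetical; L1 is skglm's existing penalty):

import numpy as np
from skglm.penalties import L1
from skglm.penalties.base import overload_with_l2

# the decorated class gains an ``l2_regularization`` keyword argument
@overload_with_l2
class MyL1(L1):
    pass

pen = MyL1(alpha=0.1, l2_regularization=0.05)
w = np.array([0.3, -0.2, 0.])
# equals the plain L1 value plus 0.5 * l2_regularization * ||w||^2
val = pen.value(w)
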
8 changes: 7 additions & 1 deletion skglm/penalties/separable.py
@@ -2,7 +2,7 @@
from numba import float64
from numba.types import bool_

from skglm.penalties.base import BasePenalty
from skglm.penalties.base import BasePenalty, overload_with_l2
from skglm.utils.prox_funcs import (
    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)

@@ -67,6 +67,12 @@ def alpha_max(self, gradient0):
        return np.max(np.abs(gradient0))


# To add support for L2 regularization, decorate the penalty class
@overload_with_l2
class _TestL1(L1):
    pass


class L1_plus_L2(BasePenalty):
""":math:`ell_1 + ell_2` penalty (aka ElasticNet penalty)."""

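The test added below relies on the corresponding identity (a sketch, not part of the diff): for strength \alpha and ratio \rho, the L1_plus_L2 penalty reads

.. math::

    \alpha \rho \|w\|_1 + \frac{\alpha (1 - \rho)}{2} \|w\|^2 ,

so _TestL1(alpha=alpha * l1_ratio, l2_regularization=alpha * (1 - l1_ratio)) should match L1_plus_L2(alpha, l1_ratio) on value, prox_1d and subdiff_distance.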
38 changes: 38 additions & 0 deletions skglm/tests/test_penalties.py
@@ -14,6 +14,9 @@
from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA
from skglm.utils.data import make_correlated_data

from skglm.penalties.separable import _TestL1
from skglm.utils.jit_compilation import compiled_clone


n_samples = 20
n_features = 10
@@ -118,5 +121,40 @@ def test_nnls(fit_intercept):
    np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_)


def test_overload_with_l2_ElasticNet():
    lmbd = 0.2
    l1_ratio = 0.7

    elastic_net = L1_plus_L2(lmbd, l1_ratio)
    implicit_elastic_net = _TestL1(alpha=lmbd * l1_ratio,
                                   l2_regularization=lmbd * (1 - l1_ratio))

    n_features, ws_size = 5, 3
    stepsize = 0.8

    rng = np.random.RandomState(425)
    w = rng.randn(n_features)
    grad = rng.randn(ws_size)
    ws = rng.choice(n_features, size=ws_size, replace=False)

    x = w[2]
    np.testing.assert_equal(
        elastic_net.value(x),
        implicit_elastic_net.value(x)
    )
    np.testing.assert_equal(
        elastic_net.prox_1d(x, stepsize, 0),
        implicit_elastic_net.prox_1d(x, stepsize, 0)
    )
    np.testing.assert_array_equal(
        elastic_net.subdiff_distance(w, grad, ws),
        implicit_elastic_net.subdiff_distance(w, grad, ws)
    )

    # this raises, as *args and **kwargs are not supported by numba's jitclass
    with pytest.raises(Exception, match=r"VAR_POSITIONAL.*unsupported.*jitclass"):
        compiled_clone(implicit_elastic_net)


if __name__ == "__main__":
    pass