|
6 | 6 |
|
7 | 7 | from skglm.penalties.base import BasePenalty
|
8 | 8 | from skglm.utils.prox_funcs import (
|
9 |
| - BST, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP) |
| 9 | + BST, ST_vec, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP) |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class L2_1(BasePenalty):
|
@@ -382,3 +382,109 @@ def generalized_support(self, w):
|
382 | 382 | gsupp[g] = True
|
383 | 383 |
|
384 | 384 | return gsupp
|
| 385 | + |
| 386 | + |
| 387 | +class WeightedL1GroupL2(BasePenalty): |
| 388 | + r"""Weighted Group L2 penalty, aka sparse group Lasso. |
| 389 | +
|
| 390 | + The penalty reads |
| 391 | +
|
| 392 | + .. math:: |
| 393 | + sum_{g=1}^{n_"groups"} "weights"^1_g ||w_{[g]}|| + |
| 394 | + sum_{j=1}^{n_"features"} "weights"^2_j ||w_{j}|| |
| 395 | +
|
| 396 | + with :math:`w_{[g]}` being the coefficients of the g-th group and |
| 397 | +
|
| 398 | + Attributes |
| 399 | + ---------- |
| 400 | + alpha : float |
| 401 | + The regularization parameter. |
| 402 | +
|
| 403 | + weights_groups : array, shape (n_groups,) |
| 404 | + The penalization weights of the groups. |
| 405 | +
|
| 406 | + weights_features : array, shape (n_features,) |
| 407 | + The penalization weights of the features. |
| 408 | +
|
| 409 | + grp_indices : array, shape (n_features,) |
| 410 | + The group indices stacked contiguously |
| 411 | + ([grp1_indices, grp2_indices, ...]). |
| 412 | +
|
| 413 | + grp_ptr : array, shape (n_groups + 1,) |
| 414 | + The group pointers such that two consecutive elements delimit |
| 415 | + the indices of a group in ``grp_indices``. |
| 416 | +
|
| 417 | + """ |
| 418 | + |
| 419 | + def __init__( |
| 420 | + self, alpha, weights_groups, weights_features, grp_ptr, grp_indices): |
| 421 | + self.alpha = alpha |
| 422 | + self.grp_ptr, self.grp_indices = grp_ptr, grp_indices |
| 423 | + self.weights_groups = weights_groups |
| 424 | + self.weights_features = weights_features |
| 425 | + |
| 426 | + def get_spec(self): |
| 427 | + spec = ( |
| 428 | + ('alpha', float64), |
| 429 | + ('weights_groups', float64[:]), |
| 430 | + ('weights_features', float64[:]), |
| 431 | + ('grp_ptr', int32[:]), |
| 432 | + ('grp_indices', int32[:]), |
| 433 | + ) |
| 434 | + return spec |
| 435 | + |
| 436 | + def params_to_dict(self): |
| 437 | + return dict(alpha=self.alpha, weights_features=self.weights_features, |
| 438 | + weights_groups=self.weights_groups, grp_ptr=self.grp_ptr, |
| 439 | + grp_indices=self.grp_indices) |
| 440 | + |
| 441 | + def value(self, w): |
| 442 | + """Value of penalty at vector ``w``.""" |
| 443 | + grp_ptr, grp_indices = self.grp_ptr, self.grp_indices |
| 444 | + n_grp = len(grp_ptr) - 1 |
| 445 | + |
| 446 | + sum_penalty = 0. |
| 447 | + for g in range(n_grp): |
| 448 | + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] |
| 449 | + w_g = w[grp_g_indices] |
| 450 | + |
| 451 | + sum_penalty += self.weights_groups[g] * norm(w_g) |
| 452 | + sum_penalty += np.sum(self.weights_features * np.abs(w)) |
| 453 | + |
| 454 | + return self.alpha * sum_penalty |
| 455 | + |
| 456 | + def prox_1group(self, value, stepsize, g): |
| 457 | + """Compute the proximal operator of group ``g``.""" |
| 458 | + res = ST_vec(value, self.alpha * stepsize * self.weights_features[g]) |
| 459 | + return BST(res, self.alpha * stepsize * self.weights_groups[g]) |
| 460 | + |
| 461 | + def subdiff_distance(self, w, grad_ws, ws): |
| 462 | + """Compute distance to the subdifferential at ``w`` of negative gradient. |
| 463 | +
|
| 464 | + Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation. |
| 465 | +
|
| 466 | + Note: |
| 467 | + ---- |
| 468 | + ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``. |
| 469 | + """ |
| 470 | + raise NotImplementedError("Too hard for now") |
| 471 | + |
| 472 | + def is_penalized(self, n_groups): |
| 473 | + return np.ones(n_groups, dtype=np.bool_) |
| 474 | + |
| 475 | + def generalized_support(self, w): |
| 476 | + grp_indices, grp_ptr = self.grp_indices, self.grp_ptr |
| 477 | + n_groups = len(grp_ptr) - 1 |
| 478 | + is_penalized = self.is_penalized(n_groups) |
| 479 | + |
| 480 | + gsupp = np.zeros(n_groups, dtype=np.bool_) |
| 481 | + for g in range(n_groups): |
| 482 | + if not is_penalized[g]: |
| 483 | + gsupp[g] = True |
| 484 | + continue |
| 485 | + |
| 486 | + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] |
| 487 | + if np.any(w[grp_g_indices]): |
| 488 | + gsupp[g] = True |
| 489 | + |
| 490 | + return gsupp |
0 commit comments