|
1 |
| -from abc import abstractmethod |
| 1 | +from abc import abstractmethod, ABC |
| 2 | +from skglm.utils.validation import check_attrs |
2 | 3 |
|
3 | 4 |
|
4 |
| -class BaseSolver(): |
5 |
| - """Base class for solvers.""" |
| 5 | +class BaseSolver(ABC): |
| 6 | + """Base class for solvers. |
| 7 | +
|
| 8 | + Attributes |
| 9 | + ---------- |
| 10 | + _datafit_required_attr : list |
| 11 | + List of attributes that must be implemented in Datafit. |
| 12 | +
|
| 13 | + _penalty_required_attr : list |
| 14 | + List of attributes that must be implemented in Penalty. |
| 15 | +
|
| 16 | + Notes |
| 17 | + ----- |
| 18 | + For required attributes, if an attribute is given as a list of attributes |
| 19 | + it means at least one of them should be implemented. |
| 20 | + For instance, if |
| 21 | +
|
| 22 | + _datafit_required_attr = ( |
| 23 | + "get_global_lipschitz", |
| 24 | + ("gradient", "gradient_scalar") |
| 25 | + ) |
| 26 | +
|
| 27 | + it mean datafit must implement the methods ``get_global_lipschitz`` |
| 28 | + and (``gradient`` or ``gradient_scaler``). |
| 29 | + """ |
| 30 | + |
| 31 | + _datafit_required_attr: list |
| 32 | + _penalty_required_attr: list |
6 | 33 |
|
7 | 34 | @abstractmethod
|
8 |
| - def solve(self, X, y, datafit, penalty, w_init, Xw_init): |
| 35 | + def _solve(self, X, y, datafit, penalty, w_init, Xw_init): |
9 | 36 | """Solve an optimization problem.
|
10 | 37 |
|
11 | 38 | Parameters
|
@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
|
39 | 66 | stop_crit : float
|
40 | 67 | Value of stopping criterion at convergence.
|
41 | 68 | """
|
| 69 | + |
| 70 | + def custom_checks(self, X, y, datafit, penalty): |
| 71 | + """Ensure the solver is suited for the `datafit` + `penalty` problem. |
| 72 | +
|
| 73 | + This method includes extra checks to perform |
| 74 | + aside from checking attributes compatibility. |
| 75 | +
|
| 76 | + Parameters |
| 77 | + ---------- |
| 78 | + X : array, shape (n_samples, n_features) |
| 79 | + Training data. |
| 80 | +
|
| 81 | + y : array, shape (n_samples,) |
| 82 | + Target values. |
| 83 | +
|
| 84 | + datafit : instance of BaseDatafit |
| 85 | + Datafit. |
| 86 | +
|
| 87 | + penalty : instance of BasePenalty |
| 88 | + Penalty. |
| 89 | + """ |
| 90 | + pass |
| 91 | + |
| 92 | + def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None, |
| 93 | + *, run_checks=True): |
| 94 | + """Solve the optimization problem after validating its compatibility. |
| 95 | +
|
| 96 | + A proxy of ``_solve`` method that implicitly ensures the compatibility |
| 97 | + of ``datafit`` and ``penalty`` with the solver. |
| 98 | +
|
| 99 | + Examples |
| 100 | + -------- |
| 101 | + >>> ... |
| 102 | + >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty) |
| 103 | + """ |
| 104 | + if run_checks: |
| 105 | + self._validate(X, y, datafit, penalty) |
| 106 | + |
| 107 | + return self._solve(X, y, datafit, penalty, w_init, Xw_init) |
| 108 | + |
| 109 | + def _validate(self, X, y, datafit, penalty): |
| 110 | + # execute: `custom_checks` then check attributes |
| 111 | + self.custom_checks(X, y, datafit, penalty) |
| 112 | + |
| 113 | + # do not check for sparse support here, make the check at the solver level |
| 114 | + # some solvers like ProxNewton don't require methods for sparse support |
| 115 | + check_attrs(datafit, self, self._datafit_required_attr) |
| 116 | + check_attrs(penalty, self, self._penalty_required_attr) |
0 commit comments