ENH flexible gram solver with penalty and using datafit#16
ENH flexible gram solver with penalty and using datafit#16mathurinm wants to merge 23 commits intoscikit-learn-contrib:mainfrom
Conversation
|
@mathurinm ready for a quick review ;) |
| if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features | ||
| and n_features < 10_000) or solver in ("cd_gram", "fista"): | ||
| # Gram matrix must fit in memory hence the restriction n_features < 1e5 | ||
| if not isinstance(datafit, (Quadratic, Quadratic_32)): |
There was a problem hiding this comment.
I think this bit is unreachable because the check is already performed L155
There was a problem hiding this comment.
I've placed it because the first condition is "isinstance.... OR solver in ...". If the user manually inputs "cd_gram", I think we enter the if statement and I want to catch a wrong datafit, hence L158. Overkill maybe? Should we even expose solver? I think it is convenient for benchmarks.
WDYT?
There was a problem hiding this comment.
Ok I understood, thanks.
Maybe we can indent the first if, breaking line before or solver to make it more visible ?
There was a problem hiding this comment.
I tried to make it more obvious. WDYT?
|
|
||
| coefs : array, shape (n_features, n_alphas) | ||
| Coefficients along the path. | ||
| obj_out : array, shape (n_iter,) |
There was a problem hiding this comment.
do we really return this? or the optimality condition violation instead
There was a problem hiding this comment.
We do return this. See L371.
|
|
||
|
|
||
| @njit | ||
| def prox_vec(penalty, z, stepsize, n_features): |
There was a problem hiding this comment.
arf, I though we had penalty.prox
make this function private, remove n_features (access as z.shape[1])
we need a reflection on solvers, but probably all penalties will need to implement it. We can do so in basepenalty, but I fear looping over all coordinates will be slower than performing it in one step as ST_vec does
There was a problem hiding this comment.
In [16]: %%time
...: out = _prox_vec(pen, z, 0.01)
CPU times: user 28 µs, sys: 1e+03 ns, total: 29 µs
Wall time: 34.1 µs
In [17]: %%time
...: out2 = ST_vec(z, 0.01)
CPU times: user 23 µs, sys: 0 ns, total: 23 µs
Wall time: 25.7 µs
not a big difference, from my experiments. I tried with different thresholds.
There was a problem hiding this comment.
with @QB3 we had an issue a while ago on flashcd with finance where this caused a big overhead. Just to keep it in mind
Co-authored-by: mathurinm <mathurinm@users.noreply.github.com>
Co-authored-by: mathurinm <mathurinm@users.noreply.github.com>
…kglm into gram_penalty_nogroup
…enalty_nogroup
PABannier
left a comment
There was a problem hiding this comment.
Overall LGTM.
Tests are missing for the solvers though, I can write some if needed.
| return_n_iter : bool, optional | ||
| If True, number of iterations along the path are returned. | ||
|
|
||
| solver : ('cd_ws'|'cd_gram'|'fista'), optional |
There was a problem hiding this comment.
FISTA is not a CD solver, it's confusing to expose it to the user like this.
@mathurinm WDYT?
| @njit | ||
| def _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features): | ||
| lc = datafit.lipschitz | ||
| for j in range(n_features): |
There was a problem hiding this comment.
since we have complete access to grad at each iteration, it would be interesting to use a greedy selection rule here: do not pick j cyclically, but instead take j = np.argmax(np.abs(grad))
One "epoch" in this setting would only be the update of n_features coordinates.
|
closing in favor of #59 |
This is a smaller version of #4 : only without groups, but reusing more code and supporting any penalty