Skip to content

Commit ccc6344

Browse files
FIX ProxNewton solver with fixpoint strategy (#259)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent 9682660 commit ccc6344

File tree

4 files changed

+34
-29
lines changed

4 files changed

+34
-29
lines changed

skglm/solvers/anderson_cd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
184184
opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws)
185185
elif self.ws_strategy == "fixpoint":
186186
opt_ws = dist_fix_point_cd(
187-
w[:n_features], grad_ws, lipschitz, datafit, penalty, ws
187+
w[:n_features], grad_ws, lipschitz[ws], datafit, penalty, ws
188188
)
189189

190190
stop_crit_in = np.max(opt_ws)

skglm/solvers/common.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
@njit
6-
def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
6+
def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
77
"""Compute the violation of the fixed point iterate scheme.
88
99
Parameters
@@ -14,16 +14,16 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
1414
grad_ws : array, shape (ws_size,)
1515
Gradient restricted to the working set.
1616
17-
lipschitz : array, shape (n_features,)
18-
Coordinatewise gradient Lipschitz constants.
17+
lipschitz_ws : array, shape (len(ws),)
18+
Coordinatewise gradient Lipschitz constants, restricted to working set.
1919
2020
datafit: instance of BaseDatafit
2121
Datafit.
2222
2323
penalty: instance of BasePenalty
2424
Penalty.
2525
26-
ws : array, shape (ws_size,)
26+
ws : array, shape (len(ws),)
2727
The working set.
2828
2929
Returns
@@ -34,10 +34,10 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
3434
dist = np.zeros(ws.shape[0], dtype=w.dtype)
3535

3636
for idx, j in enumerate(ws):
37-
if lipschitz[j] == 0.:
37+
if lipschitz_ws[idx] == 0.:
3838
continue
3939

40-
step_j = 1 / lipschitz[j]
40+
step_j = 1 / lipschitz_ws[idx]
4141
dist[idx] = np.abs(
4242
w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j)
4343
)

skglm/solvers/multitask_bcd.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
6666
if self.ws_strategy == "subdiff":
6767
opt = penalty.subdiff_distance(W, grad, all_feats)
6868
elif self.ws_strategy == "fixpoint":
69-
opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats)
69+
opt = dist_fix_point_bcd(
70+
W, grad, lipschitz, datafit, penalty, all_feats
71+
)
7072
stop_crit = np.max(opt)
7173
if self.verbose:
7274
print(f"Stopping criterion max violation: {stop_crit:.2e}")
@@ -151,7 +153,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
151153
opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
152154
elif self.ws_strategy == "fixpoint":
153155
opt_ws = dist_fix_point_bcd(
154-
W, grad_ws, lipschitz, datafit, penalty, ws
156+
W, grad_ws, lipschitz[ws], datafit, penalty, ws
155157
)
156158

157159
stop_crit_in = np.max(opt_ws)
@@ -231,27 +233,27 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)
231233

232234

233235
@njit
234-
def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
236+
def dist_fix_point_bcd(W, grad_ws, lipschitz_ws, datafit, penalty, ws):
235237
"""Compute the violation of the fixed point iterate schema.
236238
237239
Parameters
238240
----------
239241
W : array, shape (n_features, n_tasks)
240242
Coefficient matrix.
241243
242-
grad_ws : array, shape (ws_size, n_tasks)
244+
grad_ws : array, shape (len(ws), n_tasks)
243245
Gradient restricted to the working set.
244246
245247
datafit: instance of BaseMultiTaskDatafit
246248
Datafit.
247249
248-
lipschitz : array, shape (n_features,)
249-
Blockwise gradient Lipschitz constants.
250+
lipschitz_ws : array, shape (len(ws),)
251+
Blockwise gradient Lipschitz constants, restricted to working set.
250252
251253
penalty: instance of BasePenalty
252254
Penalty.
253255
254-
ws : array, shape (ws_size,)
256+
ws : array, shape (len(ws),)
255257
The working set.
256258
257259
Returns
@@ -262,10 +264,10 @@ def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
262264
dist = np.zeros(ws.shape[0])
263265

264266
for idx, j in enumerate(ws):
265-
if lipschitz[j] == 0.:
267+
if lipschitz_ws[idx] == 0.:
266268
continue
267269

268-
step_j = 1 / lipschitz[j]
270+
step_j = 1 / lipschitz_ws[idx]
269271
dist[idx] = norm(
270272
W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j)
271273
)

skglm/solvers/prox_newton.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
6565
self.verbose = verbose
6666

6767
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
68+
if self.ws_strategy not in ("subdiff", "fixpoint"):
69+
raise ValueError("ws_strategy must be `subdiff` or `fixpoint`, "
70+
f"got {self.ws_strategy}.")
6871
dtype = X.dtype
6972
n_samples, n_features = X.shape
7073
fit_intercept = self.fit_intercept
@@ -206,9 +209,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
206209
dtype = X.dtype
207210
raw_hess = datafit.raw_hessian(y, Xw_epoch)
208211

209-
lipschitz = np.zeros(len(ws), dtype)
212+
lipschitz_ws = np.zeros(len(ws), dtype)
210213
for idx, j in enumerate(ws):
211-
lipschitz[idx] = raw_hess @ X[:, j] ** 2
214+
lipschitz_ws[idx] = raw_hess @ X[:, j] ** 2
212215

213216
# for a less costly stopping criterion, we do not compute the exact gradient,
214217
# but store each coordinate-wise gradient every time we update one coordinate
@@ -224,12 +227,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
224227
for cd_iter in range(MAX_CD_ITER):
225228
for idx, j in enumerate(ws):
226229
# skip when X[:, j] == 0
227-
if lipschitz[idx] == 0:
230+
if lipschitz_ws[idx] == 0:
228231
continue
229232

230233
past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws)
231234
old_w_idx = w_ws[idx]
232-
stepsize = 1 / lipschitz[idx]
235+
stepsize = 1 / lipschitz_ws[idx]
233236

234237
w_ws[idx] = penalty.prox_1d(
235238
old_w_idx - stepsize * past_grads[idx], stepsize, j)
@@ -253,7 +256,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
253256
opt = penalty.subdiff_distance(current_w, past_grads, ws)
254257
elif ws_strategy == "fixpoint":
255258
opt = dist_fix_point_cd(
256-
current_w, past_grads, lipschitz, datafit, penalty, ws
259+
current_w, past_grads, lipschitz_ws, datafit, penalty, ws
257260
)
258261
stop_crit = np.max(opt)
259262

@@ -264,7 +267,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
264267
break
265268

266269
# descent direction
267-
return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
270+
return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws
268271

269272

270273
# sparse version of _descent_direction
@@ -275,10 +278,10 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
275278
dtype = X_data.dtype
276279
raw_hess = datafit.raw_hessian(y, Xw_epoch)
277280

278-
lipschitz = np.zeros(len(ws), dtype)
281+
lipschitz_ws = np.zeros(len(ws), dtype)
279282
for idx, j in enumerate(ws):
280-
# equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2
281-
lipschitz[idx] = _sparse_squared_weighted_norm(
283+
# equivalent to: lipschitz_ws[idx] += raw_hess * X[:, j] ** 2
284+
lipschitz_ws[idx] = _sparse_squared_weighted_norm(
282285
X_data, X_indptr, X_indices, j, raw_hess)
283286

284287
# see _descent_direction() comment
@@ -294,7 +297,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
294297
for cd_iter in range(MAX_CD_ITER):
295298
for idx, j in enumerate(ws):
296299
# skip when X[:, j] == 0
297-
if lipschitz[idx] == 0:
300+
if lipschitz_ws[idx] == 0:
298301
continue
299302

300303
past_grads[idx] = grad_ws[idx]
@@ -303,7 +306,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
303306
X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess)
304307

305308
old_w_idx = w_ws[idx]
306-
stepsize = 1 / lipschitz[idx]
309+
stepsize = 1 / lipschitz_ws[idx]
307310

308311
w_ws[idx] = penalty.prox_1d(
309312
old_w_idx - stepsize * past_grads[idx], stepsize, j)
@@ -328,7 +331,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
328331
opt = penalty.subdiff_distance(current_w, past_grads, ws)
329332
elif ws_strategy == "fixpoint":
330333
opt = dist_fix_point_cd(
331-
current_w, past_grads, lipschitz, datafit, penalty, ws
334+
current_w, past_grads, lipschitz_ws, datafit, penalty, ws
332335
)
333336
stop_crit = np.max(opt)
334337

@@ -339,7 +342,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
339342
break
340343

341344
# descent direction
342-
return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
345+
return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws
343346

344347

345348
@njit

0 commit comments

Comments
 (0)