@@ -65,6 +65,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
65
65
self .verbose = verbose
66
66
67
67
def solve (self , X , y , datafit , penalty , w_init = None , Xw_init = None ):
68
+ if self .ws_strategy not in ("subdiff" , "fixpoint" ):
69
+ raise ValueError ("ws_strategy must be `subdiff` or `fixpoint`, "
70
+ f"got { self .ws_strategy } ." )
68
71
dtype = X .dtype
69
72
n_samples , n_features = X .shape
70
73
fit_intercept = self .fit_intercept
@@ -206,9 +209,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
206
209
dtype = X .dtype
207
210
raw_hess = datafit .raw_hessian (y , Xw_epoch )
208
211
209
- lipschitz = np .zeros (len (ws ), dtype )
212
+ lipschitz_ws = np .zeros (len (ws ), dtype )
210
213
for idx , j in enumerate (ws ):
211
- lipschitz [idx ] = raw_hess @ X [:, j ] ** 2
214
+ lipschitz_ws [idx ] = raw_hess @ X [:, j ] ** 2
212
215
213
216
# for a less costly stopping criterion, we do not compute the exact gradient,
214
217
# but store each coordinate-wise gradient every time we update one coordinate
@@ -224,12 +227,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
224
227
for cd_iter in range (MAX_CD_ITER ):
225
228
for idx , j in enumerate (ws ):
226
229
# skip when X[:, j] == 0
227
- if lipschitz [idx ] == 0 :
230
+ if lipschitz_ws [idx ] == 0 :
228
231
continue
229
232
230
233
past_grads [idx ] = grad_ws [idx ] + X [:, j ] @ (raw_hess * X_delta_w_ws )
231
234
old_w_idx = w_ws [idx ]
232
- stepsize = 1 / lipschitz [idx ]
235
+ stepsize = 1 / lipschitz_ws [idx ]
233
236
234
237
w_ws [idx ] = penalty .prox_1d (
235
238
old_w_idx - stepsize * past_grads [idx ], stepsize , j )
@@ -253,7 +256,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
253
256
opt = penalty .subdiff_distance (current_w , past_grads , ws )
254
257
elif ws_strategy == "fixpoint" :
255
258
opt = dist_fix_point_cd (
256
- current_w , past_grads , lipschitz , datafit , penalty , ws
259
+ current_w , past_grads , lipschitz_ws , datafit , penalty , ws
257
260
)
258
261
stop_crit = np .max (opt )
259
262
@@ -264,7 +267,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
264
267
break
265
268
266
269
# descent direction
267
- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
270
+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz_ws
268
271
269
272
270
273
# sparse version of _descent_direction
@@ -275,10 +278,10 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
275
278
dtype = X_data .dtype
276
279
raw_hess = datafit .raw_hessian (y , Xw_epoch )
277
280
278
- lipschitz = np .zeros (len (ws ), dtype )
281
+ lipschitz_ws = np .zeros (len (ws ), dtype )
279
282
for idx , j in enumerate (ws ):
280
- # equivalent to: lipschitz [idx] += raw_hess * X[:, j] ** 2
281
- lipschitz [idx ] = _sparse_squared_weighted_norm (
283
+ # equivalent to: lipschitz_ws [idx] += raw_hess * X[:, j] ** 2
284
+ lipschitz_ws [idx ] = _sparse_squared_weighted_norm (
282
285
X_data , X_indptr , X_indices , j , raw_hess )
283
286
284
287
# see _descent_direction() comment
@@ -294,7 +297,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
294
297
for cd_iter in range (MAX_CD_ITER ):
295
298
for idx , j in enumerate (ws ):
296
299
# skip when X[:, j] == 0
297
- if lipschitz [idx ] == 0 :
300
+ if lipschitz_ws [idx ] == 0 :
298
301
continue
299
302
300
303
past_grads [idx ] = grad_ws [idx ]
@@ -303,7 +306,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
303
306
X_data , X_indptr , X_indices , j , X_delta_w_ws , raw_hess )
304
307
305
308
old_w_idx = w_ws [idx ]
306
- stepsize = 1 / lipschitz [idx ]
309
+ stepsize = 1 / lipschitz_ws [idx ]
307
310
308
311
w_ws [idx ] = penalty .prox_1d (
309
312
old_w_idx - stepsize * past_grads [idx ], stepsize , j )
@@ -328,7 +331,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
328
331
opt = penalty .subdiff_distance (current_w , past_grads , ws )
329
332
elif ws_strategy == "fixpoint" :
330
333
opt = dist_fix_point_cd (
331
- current_w , past_grads , lipschitz , datafit , penalty , ws
334
+ current_w , past_grads , lipschitz_ws , datafit , penalty , ws
332
335
)
333
336
stop_crit = np .max (opt )
334
337
@@ -339,7 +342,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
339
342
break
340
343
341
344
# descent direction
342
- return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz
345
+ return w_ws - w_epoch [ws_intercept ], X_delta_w_ws , lipschitz_ws
343
346
344
347
345
348
@njit
0 commit comments