Skip to content

Commit 2b57ce2

Browse files
committed
Set default step for GLM and l2 penalization to a better value
Useful for comparison with scikit.
1 parent f34b33a commit 2b57ce2

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tick/base/learner/learner_glm.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from warnings import warn
44

55
import numpy as np
6-
from tick.base import Base
76

7+
from tick.base import Base
88
from tick.base_model import ModelLipschitz
99
from .learner_optim import LearnerOptim
1010

@@ -167,11 +167,20 @@ def fit(self, X: object, y: np.array):
167167
if self.step is None and self.solver in self._solvers_with_step:
168168
if self.solver in self._solvers_with_linesearch:
169169
self._solver_obj.linesearch = True
170-
elif self.solver == 'svrg':
170+
elif self.solver == 'svrg' or self.solver == 'saga':
171+
L = self._model_obj.get_lip_max()
172+
if self.penalty == 'l2':
173+
L += 1. / self.C
174+
mun = min(2 * self._model_obj.n_samples / self.C, L)
175+
self.step = 1. / (2 * L + mun)
176+
else:
177+
self.step = 1. / L
178+
171179
if isinstance(self._model_obj, ModelLipschitz):
172180
self.step = 1. / self._model_obj.get_lip_max()
173181
else:
174-
warn('SVRG step needs to be tuned manually', RuntimeWarning)
182+
warn('SVRG and SAGA steps needs to be tuned manually',
183+
RuntimeWarning)
175184
self.step = 1.
176185
elif self.solver == 'sgd':
177186
warn('SGD step needs to be tuned manually', RuntimeWarning)

0 commit comments

Comments
 (0)