diff --git a/tick/base/learner/learner_glm.py b/tick/base/learner/learner_glm.py index 1a43d03ff..c688e5e93 100644 --- a/tick/base/learner/learner_glm.py +++ b/tick/base/learner/learner_glm.py @@ -3,8 +3,8 @@ from warnings import warn import numpy as np -from tick.base import Base +from tick.base import Base from tick.base_model import ModelLipschitz from .learner_optim import LearnerOptim @@ -167,11 +167,20 @@ def fit(self, X: object, y: np.array): if self.step is None and self.solver in self._solvers_with_step: if self.solver in self._solvers_with_linesearch: self._solver_obj.linesearch = True - elif self.solver == 'svrg': + elif self.solver == 'svrg' or self.solver == 'saga': + L = self._model_obj.get_lip_max() + if self.penalty == 'l2': + L += 1. / self.C + mun = min(2 * self._model_obj.n_samples / self.C, L) + self.step = 1. / (2 * L + mun) + else: + self.step = 1. / L + if isinstance(self._model_obj, ModelLipschitz): self.step = 1. / self._model_obj.get_lip_max() else: - warn('SVRG step needs to be tuned manually', RuntimeWarning) + warn('SVRG and SAGA steps needs to be tuned manually', + RuntimeWarning) self.step = 1. elif self.solver == 'sgd': warn('SGD step needs to be tuned manually', RuntimeWarning)