Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions tick/base/learner/learner_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from warnings import warn

import numpy as np
from tick.base import Base

from tick.base import Base
from tick.base_model import ModelLipschitz
from .learner_optim import LearnerOptim

Expand Down Expand Up @@ -167,11 +167,20 @@ def fit(self, X: object, y: np.array):
if self.step is None and self.solver in self._solvers_with_step:
if self.solver in self._solvers_with_linesearch:
self._solver_obj.linesearch = True
elif self.solver == 'svrg':
elif self.solver == 'svrg' or self.solver == 'saga':
L = self._model_obj.get_lip_max()
if self.penalty == 'l2':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if penalty is elastic-net?

L += 1. / self.C
mun = min(2 * self._model_obj.n_samples / self.C, L)
self.step = 1. / (2 * L + mun)
else:
self.step = 1. / L

if isinstance(self._model_obj, ModelLipschitz):
self.step = 1. / self._model_obj.get_lip_max()
else:
warn('SVRG step needs to be tuned manually', RuntimeWarning)
warn('SVRG and SAGA steps needs to be tuned manually',
RuntimeWarning)
self.step = 1.
elif self.solver == 'sgd':
warn('SGD step needs to be tuned manually', RuntimeWarning)
Expand Down