 from torch import optim
 from torch.optim.lr_scheduler import _LRScheduler
 
-AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle', 'reduce_on_plateau', 'onecycle'}
+AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle',
+            'reduce_on_plateau_delayed', 'reduce_on_plateau', 'onecycle'}
 
 def build_lr_scheduler(optimizer, lr_scheduler, base_scheduler, **kwargs):
     if lr_scheduler == 'warmup':
@@ -37,7 +38,8 @@ def _build_scheduler(optimizer,
                      max_lr=0.1,
                      patience=5,
                      lr_decay_factor=100,
-                     pct_start=0.3):
+                     pct_start=0.3,
+                     epoch_delay=0):
 
     init_learning_rate = [param_group['lr'] for param_group in optimizer.param_groups]
     if lr_scheduler not in AVAI_SCH:
@@ -96,7 +98,23 @@ def _build_scheduler(optimizer,
         lb_lr = [lr / lr_decay_factor for lr in init_learning_rate]
         epoch_threshold = max(int(max_epoch * 0.75) - warmup, 1)  # 75% of training minus warmup epochs
         scheduler = ReduceLROnPlateauV2(optimizer, epoch_threshold, factor=gamma, patience=patience,
-                                        threshold=2e-4, verbose=True, min_lr=lb_lr)
+                                        threshold=2e-4, verbose=True, min_lr=min_lr)
+    elif lr_scheduler == 'reduce_on_plateau_delayed':
+        if epoch_delay < 0:
+            raise ValueError(f'epoch_delay = {epoch_delay} should be greater than or equal to zero')
+
+        if max_epoch < epoch_delay:
+            raise ValueError(f'max_epoch param = {max_epoch} should be greater than or equal to'
+                             f' epoch_delay param = {epoch_delay}')
+
+        if epoch_delay < warmup:
+            raise ValueError(f'warmup param = {warmup} should be less than or equal to'
+                             f' epoch_delay param = {epoch_delay}')
+        # 75% of training minus skipped epochs,
+        # e.g. max_epoch=100, epoch_delay=20 -> epoch_threshold = max(75 - 20, 1) = 55
+        epoch_threshold = max(int(max_epoch * 0.75) - epoch_delay, 1)
+        scheduler = ReduceLROnPlateauV2Delayed(optimizer, epoch_threshold, epoch_delay, factor=gamma,
+                                               patience=patience, threshold=2e-4, verbose=True, min_lr=min_lr)
     else:
         raise ValueError('Unknown scheduler: {}'.format(lr_scheduler))
 
@@ -275,3 +293,31 @@ class OneCycleLR(optim.lr_scheduler.OneCycleLR):
     @property
     def warmup_finished(self):
         return self.last_epoch >= self._schedule_phases[0]['end_step']
+
+
+class ReduceLROnPlateauV2Delayed(ReduceLROnPlateauV2):
+    """
+    ReduceLROnPlateauV2 scheduler that starts working only after the
+    number of epochs specified by the epoch_delay parameter. Useful when
+    a compression algorithm is being applied, to prevent the LR from
+    dropping before the model is fully compressed. The warmup period is
+    included in epoch_delay.
+    """
+    def __init__(self,
+                 optimizer: optim.Optimizer,
+                 epoch_threshold: int,
+                 epoch_delay: int,
+                 **kwargs) -> None:
+
+        super().__init__(optimizer, epoch_threshold, **kwargs)
+        self._epoch_delay = epoch_delay
+
+    def step(self, metrics, epoch=None):
+        # While fewer than self._epoch_delay epochs have passed,
+        # just advance the epoch counter without touching the LR.
+        if self.last_epoch <= self._epoch_delay:
+            if epoch is None:
+                epoch = self.last_epoch + 1
+            self.last_epoch = epoch
+        else:
+            super().step(metrics, epoch=epoch)
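
For context, a minimal sketch of how a delayed plateau scheduler of this shape behaves in a training loop. `ReduceLROnPlateauV2` itself is not part of this diff, so the stand-in below subclasses the stock `torch.optim.lr_scheduler.ReduceLROnPlateau` instead; the `step` override mirrors the one above, and the model, optimizer, and loss values are invented for illustration.

```python
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


class DelayedPlateau(ReduceLROnPlateau):
    """Hypothetical stand-in for ReduceLROnPlateauV2Delayed built on the stock scheduler."""

    def __init__(self, optimizer, epoch_delay, **kwargs):
        super().__init__(optimizer, **kwargs)
        self._epoch_delay = epoch_delay

    def step(self, metrics, epoch=None):
        # Until the delay has elapsed, only advance the epoch counter;
        # plateau tracking (and any LR drop) starts afterwards.
        if self.last_epoch <= self._epoch_delay:
            self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        else:
            super().step(metrics, epoch=epoch)


model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = DelayedPlateau(optimizer, epoch_delay=3, factor=0.1, patience=0, threshold=2e-4)

for epoch, val_loss in enumerate([1.0] * 6):  # flat loss: a plateau from the start
    scheduler.step(val_loss)
    # LR holds at 0.1 while the delay elapses; once tracking begins,
    # the flat loss trips patience=0 and the LR decays to 0.01.
    print(epoch, optimizer.param_groups[0]['lr'])
```

Note the `<=` comparison: even with `epoch_delay = 0` the very first `step` call is skipped, since `last_epoch` starts at 0.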