@@ -212,11 +212,12 @@ def __init__(self, data_loaders, labelmap, criterion, lr,
212
212
load_wt = False ,
213
213
model_name = None ,
214
214
optimizer_method = 'adam' ,
215
- use_grayscale = False ):
215
+ use_grayscale = False ,
216
+ lr_step = []):
216
217
217
218
CIFAR10 .__init__ (self , data_loaders , labelmap , criterion , lr , batch_size , evaluator , experiment_name ,
218
219
experiment_dir , n_epochs , eval_interval , feature_extracting , use_pretrained ,
219
- load_wt , model_name , optimizer_method )
220
+ load_wt , model_name , optimizer_method , lr_step = lr_step )
220
221
221
222
if use_grayscale :
222
223
if model_name in ['alexnet' , 'vgg' ]:
@@ -421,7 +422,8 @@ def ETHEC_train_model(arguments):
421
422
load_wt = arguments .resume ,
422
423
model_name = arguments .model ,
423
424
optimizer_method = arguments .optimizer_method ,
424
- use_grayscale = arguments .use_grayscale )
425
+ use_grayscale = arguments .use_grayscale ,
426
+ lr_step = arguments .lr_step )
425
427
ETHEC_trainer .prepare_model ()
426
428
#if arguments.use_2d and arguments.resume:
427
429
# ETHEC_trainer.plot_label_representations()
@@ -458,6 +460,7 @@ def ETHEC_train_model(arguments):
458
460
required = True )
459
461
parser .add_argument ("--level_weights" , help = 'List of weights for each level' , nargs = 4 , default = None , type = float )
460
462
parser .add_argument ("--use_2d" , help = 'Use model with 2d features' , action = 'store_true' )
463
+ parser .add_argument ("--lr_step" , help = 'List of epochs to make multiple lr by 0.1' , nargs = '*' , default = [], type = int )
461
464
args = parser .parse_args ()
462
465
463
466
ETHEC_train_model (args )
0 commit comments