
Commit 4cd6524

add lr_step
1 parent 1eeeb13 commit 4cd6524

File tree

3 files changed: +16 -6 lines changed


network/ethec_experiments.py

+6 -3
@@ -212,11 +212,12 @@ def __init__(self, data_loaders, labelmap, criterion, lr,
                  load_wt=False,
                  model_name=None,
                  optimizer_method='adam',
-                 use_grayscale=False):
+                 use_grayscale=False,
+                 lr_step=[]):

         CIFAR10.__init__(self, data_loaders, labelmap, criterion, lr, batch_size, evaluator, experiment_name,
                          experiment_dir, n_epochs, eval_interval, feature_extracting, use_pretrained,
-                         load_wt, model_name, optimizer_method)
+                         load_wt, model_name, optimizer_method, lr_step=lr_step)

         if use_grayscale:
             if model_name in ['alexnet', 'vgg']:
@@ -421,7 +422,8 @@ def ETHEC_train_model(arguments):
                                    load_wt=arguments.resume,
                                    model_name=arguments.model,
                                    optimizer_method=arguments.optimizer_method,
-                                   use_grayscale=arguments.use_grayscale)
+                                   use_grayscale=arguments.use_grayscale,
+                                   lr_step=arguments.lr_step)
     ETHEC_trainer.prepare_model()
     #if arguments.use_2d and arguments.resume:
     #    ETHEC_trainer.plot_label_representations()
@@ -458,6 +460,7 @@ def ETHEC_train_model(arguments):
                         required=True)
     parser.add_argument("--level_weights", help='List of weights for each level', nargs=4, default=None, type=float)
     parser.add_argument("--use_2d", help='Use model with 2d features', action='store_true')
+    parser.add_argument("--lr_step", help='List of epochs at which to multiply the lr by 0.1', nargs='*', default=[], type=int)
     args = parser.parse_args()

     ETHEC_train_model(args)
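
Aside on the new flag's behavior: with nargs='*' and type=int, --lr_step accepts zero or more epoch indices and always yields a Python list. A minimal sketch, reproducing only the add_argument call from the diff above; the example invocations are hypothetical, not from the commit:

import argparse

# Same argument definition as the diff above.
parser = argparse.ArgumentParser()
parser.add_argument("--lr_step", help='List of epochs at which to multiply the lr by 0.1',
                    nargs='*', default=[], type=int)

print(parser.parse_args([]).lr_step)                         # [] -- no decay steps
print(parser.parse_args(['--lr_step', '30', '60']).lr_step)  # [30, 60]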

network/experiment.py

+6 -1
@@ -23,7 +23,7 @@
 class Experiment:

     def __init__(self, model, dataloaders, criterion, classes, experiment_name, n_epochs, eval_interval, batch_size,
-                 exp_dir, load_wt, evaluator):
+                 exp_dir, load_wt, evaluator, lr_step=[]):
         self.epoch = 0
         self.exp_dir = exp_dir
         self.load_wt = load_wt
@@ -52,6 +52,8 @@ def __init__(self, model, dataloaders, criterion, classes, experiment_name, n_ep

         self.writer = SummaryWriter(log_dir=os.path.join(self.log_dir, 'tensorboard'))

+        self.lr_step = lr_step
+
     @staticmethod
     def make_dir_if_non_existent(dir):
         if not os.path.exists(dir):
@@ -152,6 +154,7 @@ def pass_samples(self, phase, save_to_tensorboard=True):

     def run_model(self, optimizer):
         self.optimizer = optimizer
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.lr_step, gamma=0.1)

         if self.load_wt:
             self.find_existing_weights()
@@ -174,6 +177,8 @@
             self.pass_samples(phase='test')
             self.eval.disable_plotting()

+            scheduler.step()
+
         time_elapsed = time.time() - since
         print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
         print('Best val score: {:4f}'.format(self.best_score))
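
For context on the scheduler wired in above: torch.optim.lr_scheduler.MultiStepLR multiplies the optimizer's learning rate by gamma at each milestone epoch, so the default lr_step=[] leaves the rate constant. A self-contained sketch; the toy parameter, SGD optimizer, and milestones [2, 4] are illustrative assumptions, only the scheduler construction mirrors the diff:

import torch

# Toy parameter and optimizer; values are illustrative, not from the commit.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.01)

# As in run_model(): multiply lr by 0.1 at each milestone epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

for epoch in range(6):
    # ... train/val/test passes would run here ...
    print(epoch, optimizer.param_groups[0]['lr'])
    scheduler.step()

# Prints lr = 0.01, 0.01, 0.001, 0.001, 0.0001, 0.0001: a 10x drop takes
# effect at epoch 2 and another at epoch 4.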

network/finetuner.py

+4 -2
@@ -98,7 +98,8 @@ def __init__(self, data_loaders, labelmap, criterion, lr,
                  use_pretrained=True,
                  load_wt=False,
                  model_name=None,
-                 optimizer_method='adam'):
+                 optimizer_method='adam',
+                 lr_step=[]):

         self.classes = labelmap.classes
         self.n_classes = labelmap.n_classes
@@ -111,6 +112,7 @@ def __init__(self, data_loaders, labelmap, criterion, lr,
         self.optimal_thresholds = np.zeros(self.n_classes)
         self.optimizer_method = optimizer_method
         self.labelmap = labelmap
+        self.lr_step = lr_step

         if model_name == 'alexnet':
             model = models.alexnet(pretrained=use_pretrained)
@@ -127,7 +129,7 @@

         Experiment.__init__(self, model, data_loaders, criterion, self.classes, experiment_name, n_epochs,
                             eval_interval,
-                            batch_size, experiment_dir, load_wt, evaluator)
+                            batch_size, experiment_dir, load_wt, evaluator, lr_step=self.lr_step)
         self.model_name = model_name

     def prepare_model(self, loading=False):
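
One small caveat on the new signatures: lr_step=[] is a mutable default argument, evaluated once at definition time and shared across calls. Nothing in this commit mutates the list, so the code is correct as written, but the conventional defensive idiom looks like this (the class name is a stand-in, other parameters omitted):

class Finetuner:  # stand-in for the real class
    def __init__(self, lr_step=None):
        # Build the default at call time so instances never share one list.
        self.lr_step = [] if lr_step is None else lr_step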
