
Commit d54bcbe

hhaAndroid, lkm2835, Czm369, and ohwi authored and committed
Support to collect the best models (open-mmlab#6560)
* Fix mosaic repr typo (open-mmlab#6523)
* Include mmflow in readme (open-mmlab#6545)
* Include mmflow in readme
* Include mmflow in README_zh-CN
* Add mmflow url into the document menu in docs/conf.py and docs_zh-CN/conf.py.
* Make OHEM work with seesaw loss (open-mmlab#6514)
* update
* support gather best model

Co-authored-by: Kyungmin Lee <[email protected]>
Co-authored-by: Czm369 <[email protected]>
Co-authored-by: ohwi <[email protected]>
1 parent d113f5d commit d54bcbe

1 file changed, 23 insertions(+), 7 deletions(-)

.dev_scripts/gather_models.py
@@ -53,6 +53,14 @@ def get_final_epoch(config):
     return cfg.runner.max_epochs


+def get_best_epoch(exp_dir):
+    best_epoch_full_path = list(
+        sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
+    best_epoch_model_path = best_epoch_full_path.split('/')[-1]
+    best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0]
+    return best_epoch_model_path, int(best_epoch)
+
+
 def get_real_epoch(config):
     cfg = mmcv.Config.fromfile('./configs/' + config)
     epoch = cfg.runner.max_epochs
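Note (not part of the commit): get_best_epoch assumes the evaluation hook has saved a checkpoint matching best_*.pth whose filename ends with the epoch number. A minimal sketch of the parsing, using a hypothetical filename:

    # hypothetical checkpoint path; only the trailing '_<epoch>.pth' part matters here
    path = 'work_dirs/retinanet_r50_fpn_1x_coco/best_bbox_mAP_12.pth'
    name = path.split('/')[-1]                      # 'best_bbox_mAP_12.pth'
    epoch = int(name.split('_')[-1].split('.')[0])  # 12
    print(name, epoch)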
@@ -160,6 +168,10 @@ def parse_args():
         help='root path of benchmarked models to be gathered')
     parser.add_argument(
         'out', type=str, help='output path of gathered models to be stored')
+    parser.add_argument(
+        '--best',
+        action='store_true',
+        help='whether to gather the best model.')

     args = parser.parse_args()
     return args
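Note (not part of the commit): with --best, the script gathers each experiment's best_*.pth checkpoint instead of the final epoch_{max_epochs}.pth; without it, behaviour is unchanged. A hypothetical invocation (both paths are placeholders):

    python .dev_scripts/gather_models.py work_dirs/ gathered_models/ --best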
@@ -187,10 +199,13 @@ def main():
     for used_config in used_configs:
         exp_dir = osp.join(models_root, used_config)
         # check whether the exps is finished
-        final_epoch = get_final_epoch(used_config)
-        final_model = 'epoch_{}.pth'.format(final_epoch)
-        model_path = osp.join(exp_dir, final_model)
+        if args.best is True:
+            final_model, final_epoch = get_best_epoch(exp_dir)
+        else:
+            final_epoch = get_final_epoch(used_config)
+            final_model = 'epoch_{}.pth'.format(final_epoch)

+        model_path = osp.join(exp_dir, final_model)
         # skip if the model is still training
         if not osp.exists(model_path):
             continue
@@ -221,6 +236,7 @@ def main():
                 results=model_performance,
                 epochs=final_epoch,
                 model_time=model_time,
+                final_model=final_model,
                 log_json_path=osp.split(log_json_path)[-1]))

     # publish model for each checkpoint
@@ -234,7 +250,7 @@ def main():
         model_name += '_' + model['model_time']
         publish_model_path = osp.join(model_publish_dir, model_name)
         trained_model_path = osp.join(models_root, model['config'],
-                                      'epoch_{}.pth'.format(model['epochs']))
+                                      model['final_model'])

         # convert model
         final_model_path = process_checkpoint(trained_model_path,
@@ -254,9 +270,9 @@ def main():
         config_path = osp.join(
             'configs',
             config_path) if 'configs' not in config_path else config_path
-        target_cconfig_path = osp.split(config_path)[-1]
-        shutil.copy(config_path,
-                    osp.join(model_publish_dir, target_cconfig_path))
+        target_config_path = osp.split(config_path)[-1]
+        shutil.copy(config_path, osp.join(model_publish_dir,
+                                          target_config_path))

         model['model_path'] = final_model_path
         publish_model_infos.append(model)
