Skip to content

Commit be08279

Browse files
authored
[Fix]: Fix load weight when change num_classes
1 parent 4ecfb1c commit be08279

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tools/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pytorch_lightning as pl
2121
from pytorch_lightning.callbacks import ProgressBar
2222

23-
from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model
23+
from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model, load_model_weight
2424
from nanodet.data.collate import collate_function
2525
from nanodet.data.dataset import build_dataset
2626
from nanodet.trainer.task import TrainingTask
@@ -75,7 +75,7 @@ def main(args):
7575
warnings.warn('Warning! Old .pth checkpoint is deprecated. '
7676
'Convert the checkpoint with tools/convert_old_checkpoint.py ')
7777
ckpt = convert_old_model(ckpt)
78-
task.load_state_dict(ckpt['state_dict'], strict=False)
78+
load_model_weight(task.model, ckpt, logger)
7979
logger.log('Loaded model weight from {}'.format(cfg.schedule.load_model))
8080

8181
model_resume_path = os.path.join(cfg.save_dir, 'model_last.ckpt') if 'resume' in cfg.schedule else None

0 commit comments

Comments
 (0)