Commit 7e9a477
fixup finetune problem
Summary: support finetuning from another model with a different number of classes, and simplify the calling convention (#325). Closes #325.
1 parent f496193 commit 7e9a477
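
The user-facing effect of this commit is that finetuning no longer needs a `--finetune` flag; it is driven by the config alone. A minimal sketch of the intended workflow, assuming the usual `fastreid.config.get_cfg` helper (the config file and weight paths below are placeholders, not part of this commit):

```python
# Hypothetical finetuning run after this commit: no --finetune flag needed.
from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")  # placeholder config
cfg.MODEL.WEIGHTS = "logs/dukemtmc/model_final.pth"  # model trained on another dataset
cfg.OUTPUT_DIR = "logs/market1501-finetune"

trainer = DefaultTrainer(cfg)
# resume=False: load only the weights from cfg.MODEL.WEIGHTS and start at iteration 0;
# classifier weights whose shapes differ (different number of classes) are skipped.
trainer.resume_or_load(resume=False)
trainer.train()
```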

File tree

3 files changed (+114 −41 lines)

fastreid/engine/defaults.py

+7-8
@@ -44,11 +44,6 @@ def default_argument_parser():
     """
     parser = argparse.ArgumentParser(description="fastreid Training")
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
-    parser.add_argument(
-        "--finetune",
-        action="store_true",
-        help="whether to attempt to finetune from the trained model",
-    )
     parser.add_argument(
         "--resume",
         action="store_true",
@@ -244,8 +239,13 @@ def __init__(self, cfg):
 
     def resume_or_load(self, resume=True):
         """
-        If `resume==True`, and last checkpoint exists, resume from it.
-        Otherwise, load a model specified by the config.
+        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+        a `last_checkpoint` file), resume from it. Resuming means loading all
+        available states (e.g. optimizer and scheduler) and updating the iteration counter
+        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+        Otherwise, this is considered an independent training. The method will load model
+        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
+        from iteration 0.
         Args:
             resume (bool): whether to do resume or not
         """
@@ -468,7 +468,6 @@ def auto_scale_hyperparams(cfg, data_loader):
         because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
         so we need to convert specific hyper-param to training iterations.
         """
-
         cfg = cfg.clone()
         frozen = cfg.is_frozen()
         cfg.defrost()
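
The rewritten docstring is the core behavioural contract: which checkpoint gets loaded is now selected purely by the `resume` flag and the config. A short sketch of the two modes, using only names from this diff:

```python
trainer = DefaultTrainer(cfg)

# resume=True: pick up the last checkpoint recorded in cfg.OUTPUT_DIR, restoring
# optimizer, scheduler and the iteration counter; cfg.MODEL.WEIGHTS is ignored.
trainer.resume_or_load(resume=True)

# resume=False (or no last checkpoint yet): treat this as an independent run,
# load only the model weights from cfg.MODEL.WEIGHTS and start from iteration 0.
trainer.resume_or_load(resume=False)
```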

fastreid/utils/checkpoint.py

+107-32
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
-import collections
 import copy
 import logging
 import os
 from collections import defaultdict
 from typing import Any
+from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
 
 import numpy as np
 import torch
@@ -17,6 +17,23 @@
 from fastreid.utils.file_io import PathManager
 
 
+class _IncompatibleKeys(
+    NamedTuple(
+        # pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
+        "IncompatibleKeys",
+        [
+            ("missing_keys", List[str]),
+            ("unexpected_keys", List[str]),
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            ("incorrect_shapes", List[Tuple]),
+        ],
+    )
+):
+    pass
+
+
 class Checkpointer(object):
     """
     A checkpointer that can save/load model as well as extra checkpointable
@@ -50,7 +67,9 @@ def __init__(
         self.save_dir = save_dir
         self.save_to_disk = save_to_disk
 
-    def save(self, name: str, **kwargs: dict):
+        self.path_manager = PathManager
+
+    def save(self, name: str, **kwargs: Dict[str, str]):
         """
         Dump model and checkpointables to a file.
         Args:
@@ -74,13 +93,15 @@ def save(self, name: str, **kwargs: dict):
             torch.save(data, f)
         self.tag_last_checkpoint(basename)
 
-    def load(self, path: str):
+    def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
         """
         Load from the given checkpoint. When path points to network file, this
         function has to be called on all ranks.
         Args:
             path (str): path or url to the checkpoint. If empty, will not load
                 anything.
+            checkpointables (list): List of checkpointable names to load. If not
+                specified (None), will load all the possible checkpointables.
         Returns:
             dict:
                 extra data loaded from the checkpoint that has not been
@@ -89,21 +110,25 @@ def load(self, path: str):
         """
         if not path:
             # no checkpoint provided
-            self.logger.info(
-                "No checkpoint found. Training model from scratch"
-            )
+            self.logger.info("No checkpoint found. Training model from scratch")
             return {}
         self.logger.info("Loading checkpoint from {}".format(path))
         if not os.path.isfile(path):
-            path = PathManager.get_local_path(path)
+            path = self.path_manager.get_local_path(path)
             assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
 
         checkpoint = self._load_file(path)
-        self._load_model(checkpoint)
-        for key, obj in self.checkpointables.items():
-            if key in checkpoint:
+        incompatible = self._load_model(checkpoint)
+        if (
+            incompatible is not None
+        ):  # handle some existing subclasses that returns None
+            self._log_incompatible_keys(incompatible)
+
+        for key in self.checkpointables if checkpointables is None else checkpointables:
+            if key in checkpoint:  # pyre-ignore
                 self.logger.info("Loading {} from {}".format(key, path))
-                obj.load_state_dict(checkpoint.pop(key))
+                obj = self.checkpointables[key]
+                obj.load_state_dict(checkpoint.pop(key))  # pyre-ignore
 
         # return any further checkpoint data
         return checkpoint
@@ -158,7 +183,9 @@ def resume_or_load(self, path: str, *, resume: bool = True):
         """
         if resume and self.has_checkpoint():
             path = self.get_checkpoint_file()
-        return self.load(path)
+            return self.load(path)
+        else:
+            return self.load(path, checkpointables=[])
 
     def tag_last_checkpoint(self, last_filename_basename: str):
         """
@@ -199,26 +226,40 @@ def _load_model(self, checkpoint: Any):
 
         # work around https://github.com/pytorch/pytorch/issues/24139
         model_state_dict = self.model.state_dict()
+        incorrect_shapes = []
         for k in list(checkpoint_state_dict.keys()):
             if k in model_state_dict:
                 shape_model = tuple(model_state_dict[k].shape)
                 shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                 if shape_model != shape_checkpoint:
-                    self.logger.warning(
-                        "'{}' has shape {} in the checkpoint but {} in the "
-                        "model! Skipped.".format(
-                            k, shape_checkpoint, shape_model
-                        )
-                    )
+                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                     checkpoint_state_dict.pop(k)
 
-        incompatible = self.model.load_state_dict(
-            checkpoint_state_dict, strict=False
+        incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
+        return _IncompatibleKeys(
+            missing_keys=incompatible.missing_keys,
+            unexpected_keys=incompatible.unexpected_keys,
+            incorrect_shapes=incorrect_shapes,
         )
+
+    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
+        """
+        Log information about the incompatible keys returned by ``_load_model``.
+        """
+        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
+            self.logger.warning(
+                "Skip loading parameter '{}' to the model due to incompatible "
+                "shapes: {} in the checkpoint but {} in the "
+                "model! You might want to double check if this is expected.".format(
+                    k, shape_checkpoint, shape_model
+                )
+            )
         if incompatible.missing_keys:
-            self.logger.info(
-                get_missing_parameters_message(incompatible.missing_keys)
+            missing_keys = _filter_reused_missing_keys(
+                self.model, incompatible.missing_keys
             )
+            if missing_keys:
+                self.logger.info(get_missing_parameters_message(missing_keys))
         if incompatible.unexpected_keys:
             self.logger.info(
                 get_unexpected_parameters_message(incompatible.unexpected_keys)
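
This hunk is what makes finetuning across a different number of classes work: mismatched tensors are dropped from the incoming state_dict, collected in `incorrect_shapes`, and reported once by `_log_incompatible_keys` instead of raising. A toy illustration; the module layout and paths are made up for the example:

```python
from torch import nn
from fastreid.utils.checkpoint import Checkpointer

def toy_model(num_classes):
    # The "backbone" shapes match across datasets, the classifier head does not.
    return nn.Sequential(nn.Linear(16, 32), nn.Linear(32, num_classes))

src = toy_model(751)  # e.g. trained with 751 identities
Checkpointer(src, "/tmp").save("src_model_final")  # writes /tmp/src_model_final.pth

dst = toy_model(702)  # finetune target with 702 identities
# The 751-way classifier has an incompatible shape: it is skipped and reported,
# while the matching "backbone" weights are loaded as usual.
Checkpointer(dst).load("/tmp/src_model_final.pth")
```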
@@ -297,7 +338,27 @@ def save(self, name: str, **kwargs: Any):
         self.checkpointer.save(name, **kwargs)
 
 
-def get_missing_parameters_message(keys: list):
+def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
+    """
+    Filter "missing keys" to not include keys that have been loaded with another name.
+    """
+    keyset = set(keys)
+    param_to_names = defaultdict(set)  # param -> names that points to it
+    for module_prefix, module in _named_modules_with_dup(model):
+        for name, param in list(module.named_parameters(recurse=False)) + list(
+            module.named_buffers(recurse=False)  # pyre-ignore
+        ):
+            full_name = (module_prefix + "." if module_prefix else "") + name
+            param_to_names[param].add(full_name)
+    for names in param_to_names.values():
+        # if one name appears missing but its alias exists, then this
+        # name is not considered missing
+        if any(n in keyset for n in names) and not all(n in keyset for n in names):
+            [keyset.remove(n) for n in names if n in keyset]
+    return list(keyset)
+
+
+def get_missing_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the model but not found in a checkpoint.
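
`_filter_reused_missing_keys` exists because one parameter can be registered under several names (weight tying). If the checkpoint stores only one of the aliases, the others would otherwise show up in the "missing keys" report even though the shared tensor was loaded. A toy illustration; the tied model is made up for the example:

```python
from torch import nn

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(8, 8, bias=False)
        self.head = nn.Linear(8, 8, bias=False)
        self.head.weight = self.embed.weight  # one tensor, two state_dict names

model = TiedToy()
state = {"embed.weight": nn.Linear(8, 8, bias=False).weight.detach()}
# strict=False loading reports "head.weight" as missing, yet the shared tensor was
# filled through "embed.weight"; _filter_reused_missing_keys() drops such aliases.
print(model.load_state_dict(state, strict=False).missing_keys)  # ['head.weight']
```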
@@ -307,14 +368,14 @@ def get_missing_parameters_message(keys: list):
         str: message.
     """
     groups = _group_checkpoint_keys(keys)
-    msg = "Some model parameters are not in the checkpoint:\n"
+    msg = "Some model parameters or buffers are not found in the checkpoint:\n"
     msg += "\n".join(
         " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
     )
     return msg
 
 
-def get_unexpected_parameters_message(keys: list):
+def get_unexpected_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the checkpoint but not found in the model.
@@ -324,15 +385,14 @@ def get_unexpected_parameters_message(keys: list):
         str: message.
     """
     groups = _group_checkpoint_keys(keys)
-    msg = "The checkpoint contains parameters not used by the model:\n"
+    msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
     msg += "\n".join(
-        " " + colored(k + _group_to_str(v), "magenta")
-        for k, v in groups.items()
+        " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
     )
     return msg
 
 
-def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
+def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
     """
     Strip the prefix in metadata, if any.
     Args:
@@ -349,7 +409,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
 
     # also strip the prefix in metadata, if any..
    try:
-        metadata = state_dict._metadata
+        metadata = state_dict._metadata  # pyre-ignore
    except AttributeError:
        pass
    else:
@@ -365,7 +425,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
             metadata[newkey] = metadata.pop(key)
 
 
-def _group_checkpoint_keys(keys: list):
+def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
     """
     Group keys based on common prefixes. A prefix is the string up to the final
     "." in each key.
@@ -386,7 +446,7 @@ def _group_checkpoint_keys(keys: list):
     return groups
 
 
-def _group_to_str(group: list):
+def _group_to_str(group: List[str]) -> str:
     """
     Format a group of parameter name suffixes into a loggable string.
     Args:
@@ -401,3 +461,18 @@ def _group_to_str(group: list):
         return "." + group[0]
 
     return ".{" + ", ".join(group) + "}"
+
+
+def _named_modules_with_dup(
+    model: nn.Module, prefix: str = ""
+) -> Iterable[Tuple[str, nn.Module]]:
+    """
+    The same as `model.named_modules()`, except that it includes
+    duplicated modules that have more than one name.
+    """
+    yield prefix, model
+    for name, module in model._modules.items():  # pyre-ignore
+        if module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        yield from _named_modules_with_dup(module, submodule_prefix)
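
The helper is needed because `nn.Module.named_modules()` deduplicates module objects, so a block reused under two attribute names is visited only once and its second set of parameter names would never be discovered as aliases by the filter above. A small illustration with a made-up shared module:

```python
from torch import nn

shared = nn.Linear(4, 4)
model = nn.Module()
model.branch_a = shared
model.branch_b = shared  # same module object, registered under a second name

print([name for name, _ in model.named_modules()])
# ['', 'branch_a'] -- 'branch_b' is deduplicated by named_modules().
# _named_modules_with_dup(model) also yields ('branch_b', shared), so both
# 'branch_a.weight' and 'branch_b.weight' are found when collecting aliases.
```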

tools/train_net.py

-1
@@ -40,7 +40,6 @@ def main(args):
         return res
 
     trainer = DefaultTrainer(cfg)
-    if args.finetune: Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS)  # load trained model to funetune
 
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
