1
1
#!/usr/bin/env python3
2
2
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
3
4
- import collections
5
4
import copy
6
5
import logging
7
6
import os
8
7
from collections import defaultdict
9
8
from typing import Any
9
+ from typing import Optional , List , Dict , NamedTuple , Tuple , Iterable
10
10
11
11
import numpy as np
12
12
import torch
17
17
from fastreid .utils .file_io import PathManager
18
18
19
19
20
+ class _IncompatibleKeys (
21
+ NamedTuple (
22
+ # pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
23
+ "IncompatibleKeys" ,
24
+ [
25
+ ("missing_keys" , List [str ]),
26
+ ("unexpected_keys" , List [str ]),
27
+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
28
+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
29
+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
30
+ ("incorrect_shapes" , List [Tuple ]),
31
+ ],
32
+ )
33
+ ):
34
+ pass
35
+
36
+
20
37
class Checkpointer (object ):
21
38
"""
22
39
A checkpointer that can save/load model as well as extra checkpointable
@@ -50,7 +67,9 @@ def __init__(
50
67
self .save_dir = save_dir
51
68
self .save_to_disk = save_to_disk
52
69
53
- def save (self , name : str , ** kwargs : dict ):
70
+ self .path_manager = PathManager
71
+
72
+ def save (self , name : str , ** kwargs : Dict [str , str ]):
54
73
"""
55
74
Dump model and checkpointables to a file.
56
75
Args:
@@ -74,13 +93,15 @@ def save(self, name: str, **kwargs: dict):
74
93
torch .save (data , f )
75
94
self .tag_last_checkpoint (basename )
76
95
77
- def load (self , path : str ) :
96
+ def load (self , path : str , checkpointables : Optional [ List [ str ]] = None ) -> object :
78
97
"""
79
98
Load from the given checkpoint. When path points to network file, this
80
99
function has to be called on all ranks.
81
100
Args:
82
101
path (str): path or url to the checkpoint. If empty, will not load
83
102
anything.
103
+ checkpointables (list): List of checkpointable names to load. If not
104
+ specified (None), will load all the possible checkpointables.
84
105
Returns:
85
106
dict:
86
107
extra data loaded from the checkpoint that has not been
@@ -89,21 +110,25 @@ def load(self, path: str):
89
110
"""
90
111
if not path :
91
112
# no checkpoint provided
92
- self .logger .info (
93
- "No checkpoint found. Training model from scratch"
94
- )
113
+ self .logger .info ("No checkpoint found. Training model from scratch" )
95
114
return {}
96
115
self .logger .info ("Loading checkpoint from {}" .format (path ))
97
116
if not os .path .isfile (path ):
98
- path = PathManager .get_local_path (path )
117
+ path = self . path_manager .get_local_path (path )
99
118
assert os .path .isfile (path ), "Checkpoint {} not found!" .format (path )
100
119
101
120
checkpoint = self ._load_file (path )
102
- self ._load_model (checkpoint )
103
- for key , obj in self .checkpointables .items ():
104
- if key in checkpoint :
121
+ incompatible = self ._load_model (checkpoint )
122
+ if (
123
+ incompatible is not None
124
+ ): # handle some existing subclasses that returns None
125
+ self ._log_incompatible_keys (incompatible )
126
+
127
+ for key in self .checkpointables if checkpointables is None else checkpointables :
128
+ if key in checkpoint : # pyre-ignore
105
129
self .logger .info ("Loading {} from {}" .format (key , path ))
106
- obj .load_state_dict (checkpoint .pop (key ))
130
+ obj = self .checkpointables [key ]
131
+ obj .load_state_dict (checkpoint .pop (key )) # pyre-ignore
107
132
108
133
# return any further checkpoint data
109
134
return checkpoint
@@ -158,7 +183,9 @@ def resume_or_load(self, path: str, *, resume: bool = True):
158
183
"""
159
184
if resume and self .has_checkpoint ():
160
185
path = self .get_checkpoint_file ()
161
- return self .load (path )
186
+ return self .load (path )
187
+ else :
188
+ return self .load (path , checkpointables = [])
162
189
163
190
def tag_last_checkpoint (self , last_filename_basename : str ):
164
191
"""
@@ -199,26 +226,40 @@ def _load_model(self, checkpoint: Any):
199
226
200
227
# work around https://github.com/pytorch/pytorch/issues/24139
201
228
model_state_dict = self .model .state_dict ()
229
+ incorrect_shapes = []
202
230
for k in list (checkpoint_state_dict .keys ()):
203
231
if k in model_state_dict :
204
232
shape_model = tuple (model_state_dict [k ].shape )
205
233
shape_checkpoint = tuple (checkpoint_state_dict [k ].shape )
206
234
if shape_model != shape_checkpoint :
207
- self .logger .warning (
208
- "'{}' has shape {} in the checkpoint but {} in the "
209
- "model! Skipped." .format (
210
- k , shape_checkpoint , shape_model
211
- )
212
- )
235
+ incorrect_shapes .append ((k , shape_checkpoint , shape_model ))
213
236
checkpoint_state_dict .pop (k )
214
237
215
- incompatible = self .model .load_state_dict (
216
- checkpoint_state_dict , strict = False
238
+ incompatible = self .model .load_state_dict (checkpoint_state_dict , strict = False )
239
+ return _IncompatibleKeys (
240
+ missing_keys = incompatible .missing_keys ,
241
+ unexpected_keys = incompatible .unexpected_keys ,
242
+ incorrect_shapes = incorrect_shapes ,
217
243
)
244
+
245
+ def _log_incompatible_keys (self , incompatible : _IncompatibleKeys ) -> None :
246
+ """
247
+ Log information about the incompatible keys returned by ``_load_model``.
248
+ """
249
+ for k , shape_checkpoint , shape_model in incompatible .incorrect_shapes :
250
+ self .logger .warning (
251
+ "Skip loading parameter '{}' to the model due to incompatible "
252
+ "shapes: {} in the checkpoint but {} in the "
253
+ "model! You might want to double check if this is expected." .format (
254
+ k , shape_checkpoint , shape_model
255
+ )
256
+ )
218
257
if incompatible .missing_keys :
219
- self . logger . info (
220
- get_missing_parameters_message ( incompatible .missing_keys )
258
+ missing_keys = _filter_reused_missing_keys (
259
+ self . model , incompatible .missing_keys
221
260
)
261
+ if missing_keys :
262
+ self .logger .info (get_missing_parameters_message (missing_keys ))
222
263
if incompatible .unexpected_keys :
223
264
self .logger .info (
224
265
get_unexpected_parameters_message (incompatible .unexpected_keys )
@@ -297,7 +338,27 @@ def save(self, name: str, **kwargs: Any):
297
338
self .checkpointer .save (name , ** kwargs )
298
339
299
340
300
- def get_missing_parameters_message (keys : list ):
341
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
    """
    Filter "missing keys" to not include keys that have been loaded with
    another name, i.e. the same parameter/buffer object registered under
    several module paths.

    Args:
        model: model whose parameter/buffer aliases are inspected.
        keys: missing keys as reported by ``load_state_dict``.

    Returns:
        list[str]: the subset of ``keys`` that are genuinely missing.
    """
    keyset = set(keys)
    # parameter/buffer object -> all qualified names that point to it
    param_to_names = defaultdict(set)
    for module_prefix, module in _named_modules_with_dup(model):
        for name, param in list(module.named_parameters(recurse=False)) + list(
            module.named_buffers(recurse=False)
        ):
            full_name = (module_prefix + "." if module_prefix else "") + name
            param_to_names[param].add(full_name)
    for names in param_to_names.values():
        # If some alias of a parameter is missing but another alias was
        # loaded, the parameter is not actually missing.
        if any(n in keyset for n in names) and not all(n in keyset for n in names):
            # fix: explicit set operation instead of a list comprehension
            # executed only for its side effects
            keyset.difference_update(names)
    return list(keyset)
359
+
360
+
361
+ def get_missing_parameters_message (keys : List [str ]) -> str :
301
362
"""
302
363
Get a logging-friendly message to report parameter names (keys) that are in
303
364
the model but not found in a checkpoint.
@@ -307,14 +368,14 @@ def get_missing_parameters_message(keys: list):
307
368
str: message.
308
369
"""
309
370
groups = _group_checkpoint_keys (keys )
310
- msg = "Some model parameters are not in the checkpoint:\n "
371
+ msg = "Some model parameters or buffers are not found in the checkpoint:\n "
311
372
msg += "\n " .join (
312
373
" " + colored (k + _group_to_str (v ), "blue" ) for k , v in groups .items ()
313
374
)
314
375
return msg
315
376
316
377
317
- def get_unexpected_parameters_message (keys : list ) :
378
+ def get_unexpected_parameters_message (keys : List [ str ]) -> str :
318
379
"""
319
380
Get a logging-friendly message to report parameter names (keys) that are in
320
381
the checkpoint but not found in the model.
@@ -324,15 +385,14 @@ def get_unexpected_parameters_message(keys: list):
324
385
str: message.
325
386
"""
326
387
groups = _group_checkpoint_keys (keys )
327
- msg = "The checkpoint contains parameters not used by the model:\n "
388
+ msg = "The checkpoint state_dict contains keys that are not used by the model:\n "
328
389
msg += "\n " .join (
329
- " " + colored (k + _group_to_str (v ), "magenta" )
330
- for k , v in groups .items ()
390
+ " " + colored (k + _group_to_str (v ), "magenta" ) for k , v in groups .items ()
331
391
)
332
392
return msg
333
393
334
394
335
- def _strip_prefix_if_present (state_dict : collections . OrderedDict , prefix : str ):
395
+ def _strip_prefix_if_present (state_dict : Dict [ str , Any ], prefix : str ) -> None :
336
396
"""
337
397
Strip the prefix in metadata, if any.
338
398
Args:
@@ -349,7 +409,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
349
409
350
410
# also strip the prefix in metadata, if any..
351
411
try :
352
- metadata = state_dict ._metadata
412
+ metadata = state_dict ._metadata # pyre-ignore
353
413
except AttributeError :
354
414
pass
355
415
else :
@@ -365,7 +425,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
365
425
metadata [newkey ] = metadata .pop (key )
366
426
367
427
368
- def _group_checkpoint_keys (keys : list ) :
428
+ def _group_checkpoint_keys (keys : List [ str ]) -> Dict [ str , List [ str ]] :
369
429
"""
370
430
Group keys based on common prefixes. A prefix is the string up to the final
371
431
"." in each key.
@@ -386,7 +446,7 @@ def _group_checkpoint_keys(keys: list):
386
446
return groups
387
447
388
448
389
- def _group_to_str (group : list ) :
449
+ def _group_to_str (group : List [ str ]) -> str :
390
450
"""
391
451
Format a group of parameter name suffixes into a loggable string.
392
452
Args:
@@ -401,3 +461,18 @@ def _group_to_str(group: list):
401
461
return "." + group [0 ]
402
462
403
463
return ".{" + ", " .join (group ) + "}"
464
+
465
+
466
+ def _named_modules_with_dup (
467
+ model : nn .Module , prefix : str = ""
468
+ ) -> Iterable [Tuple [str , nn .Module ]]:
469
+ """
470
+ The same as `model.named_modules()`, except that it includes
471
+ duplicated modules that have more than one name.
472
+ """
473
+ yield prefix , model
474
+ for name , module in model ._modules .items (): # pyre-ignore
475
+ if module is None :
476
+ continue
477
+ submodule_prefix = prefix + ("." if prefix else "" ) + name
478
+ yield from _named_modules_with_dup (module , submodule_prefix )
0 commit comments