diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..3d834f71b2 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -1,25 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. import io -import logging import os -import os.path as osp -import pkgutil import re -from collections import OrderedDict, namedtuple +import logging +import pkgutil +import os.path as osp +from codecs import encode from importlib import import_module from tempfile import TemporaryDirectory from typing import Callable, Dict, Optional +from collections import OrderedDict, namedtuple import torch +import numpy as np +from numpy.dtypes import Float64DType, Int64DType +from numpy.core.multiarray import scalar, _reconstruct import mmengine from mmengine.dist import get_dist_info from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file -from mmengine.logging import print_log +from mmengine.logging import print_log, HistoryBuffer from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import (apply_to, deprecated_function, digit_version, - mkdir_or_exist) +from mmengine.utils import apply_to, deprecated_function, digit_version, mkdir_or_exist from mmengine.utils.dl_utils import load_url # `MMENGINE_HOME` is the highest priority directory to save checkpoints @@ -28,17 +31,33 @@ # Note that `XDG_CACHE_HOME` defines the base directory relative to which # user-specific non-essential data files should be stored. If `XDG_CACHE_HOME` # is either not set or empty, a default equal to `~/.cache` should be used. -ENV_MMENGINE_HOME = 'MMENGINE_HOME' -ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' -DEFAULT_CACHE_DIR = '~/.cache' +ENV_MMENGINE_HOME = "MMENGINE_HOME" +ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" +DEFAULT_CACHE_DIR = "~/.cache" + +# allowlist these globals so that it can be loaded from torch.load +torch.serialization.add_safe_globals( + [ + getattr, + encode, + np.dtype, + Float64DType, + Int64DType, + np.ndarray, + scalar, + _reconstruct, + HistoryBuffer, + ] +) class _IncompatibleKeys( - namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): + namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]) +): def __repr__(self): if not self.missing_keys and not self.unexpected_keys: - return '' + return "" return super().__repr__() __str__ = __repr__ @@ -48,8 +67,9 @@ def _get_mmengine_home(): mmengine_home = os.path.expanduser( os.getenv( ENV_MMENGINE_HOME, - os.path.join( - os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmengine"), + ) + ) mkdir_or_exist(mmengine_home) return mmengine_home @@ -76,25 +96,30 @@ def load_state_dict(module, state_dict, strict=False, logger=None): err_msg = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata # use _load_from_state_dict to enable checkpoint version control - def load(module, local_state_dict, prefix=''): + def load(module, local_state_dict, prefix=""): # recursively check parallel module in case that the model has a # complicated structure, e.g., nn.Module(nn.Module(DDP)) if is_model_wrapper(module) or isinstance(module, BaseTTAModel): module = module.module - local_metadata = {} if metadata is None else metadata.get( - prefix[:-1], {}) - module._load_from_state_dict(local_state_dict, prefix, local_metadata, - True, missing_keys, unexpected_keys, - err_msg) + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + err_msg, + ) for name, child in module._modules.items(): if child is not None: - child_prefix = prefix + name + '.' + child_prefix = prefix + name + "." child_state_dict = { k: v for k, v in local_state_dict.items() @@ -104,35 +129,35 @@ def load(module, local_state_dict, prefix=''): # Note that the hook can modify missing_keys and unexpected_keys. incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) - if hasattr(module, '_load_state_dict_post_hooks'): + if hasattr(module, "_load_state_dict_post_hooks"): for hook in module._load_state_dict_post_hooks.values(): out = hook(module, incompatible_keys) assert out is None, ( - 'Hooks registered with ' - '``register_load_state_dict_post_hook`` are not expected ' - 'to return new values, if incompatible_keys need to be ' - 'modified, it should be done inplace.') + "Hooks registered with " + "``register_load_state_dict_post_hook`` are not expected " + "to return new values, if incompatible_keys need to be " + "modified, it should be done inplace." + ) load(module, state_dict) load = None # break load->load reference cycle # ignore "num_batches_tracked" of BN layers - missing_keys = [ - key for key in missing_keys if 'num_batches_tracked' not in key - ] + missing_keys = [key for key in missing_keys if "num_batches_tracked" not in key] if unexpected_keys: - err_msg.append('unexpected key in source ' - f'state_dict: {", ".join(unexpected_keys)}\n') + err_msg.append( + "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n' + ) if missing_keys: err_msg.append( - f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + f'missing keys in source state_dict: {", ".join(missing_keys)}\n' + ) rank, _ = get_dist_info() if len(err_msg) > 0 and rank == 0: - err_msg.insert( - 0, 'The model and loaded state dict do not match exactly\n') - err_msg = '\n'.join(err_msg) + err_msg.insert(0, "The model and loaded state dict do not match exactly\n") + err_msg = "\n".join(err_msg) if strict: raise RuntimeError(err_msg) else: @@ -141,19 +166,19 @@ def load(module, local_state_dict, prefix=''): def get_torchvision_models(): import torchvision - if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): + + if digit_version(torchvision.__version__) < digit_version("0.13.0a0"): model_urls = dict() # When the version of torchvision is lower than 0.13, the model url is # not declared in `torchvision.model.__init__.py`, so we need to # iterate through `torchvision.models.__path__` to get the url for each # model. - for _, name, ispkg in pkgutil.walk_packages( - torchvision.models.__path__): + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): if ispkg: continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') + _zoo = import_module(f"torchvision.models.{name}") + if hasattr(_zoo, "model_urls"): + _urls = getattr(_zoo, "model_urls") model_urls.update(_urls) else: # Since torchvision bumps to v0.13, the weight loading logic, @@ -162,12 +187,13 @@ def get_torchvision_models(): # torchvision version>=0.13.0, new URLs will be added. Users can get # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. - json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json') + json_path = osp.join(mmengine.__path__[0], "hub/torchvision_0.12.json") model_urls = mmengine.load(json_path) - if digit_version(torchvision.__version__) < digit_version('0.14.0a0'): + if digit_version(torchvision.__version__) < digit_version("0.14.0a0"): weights_list = [ - cls for cls_name, cls in torchvision.models.__dict__.items() - if cls_name.endswith('_Weights') + cls + for cls_name, cls in torchvision.models.__dict__.items() + if cls_name.endswith("_Weights") ] else: weights_list = [ @@ -181,16 +207,16 @@ def get_torchvision_models(): # classes, such as `MNASNet0_75_Weights` does not have any urls in # torchvision 0.13.0 and cannot be iterated. Here we simply check # `DEFAULT` attribute to ensure the class is not empty. - if not hasattr(cls, 'DEFAULT'): + if not hasattr(cls, "DEFAULT"): continue # Since `cls.DEFAULT` can not be accessed by iterating cls, we set # default urls explicitly. cls_name = cls.__name__ - cls_key = cls_name.replace('_Weights', '').lower() - model_urls[f'{cls_key}.default'] = cls.DEFAULT.url + cls_key = cls_name.replace("_Weights", "").lower() + model_urls[f"{cls_key}.default"] = cls.DEFAULT.url for weight_enum in cls: - cls_key = cls_name.replace('_Weights', '').lower() - cls_key = f'{cls_key}.{weight_enum.name.lower()}' + cls_key = cls_name.replace("_Weights", "").lower() + cls_key = f"{cls_key}.{weight_enum.name.lower()}" model_urls[cls_key] = weight_enum.url return model_urls @@ -198,10 +224,10 @@ def get_torchvision_models(): def get_external_models(): mmengine_home = _get_mmengine_home() - default_json_path = osp.join(mmengine.__path__[0], 'hub/openmmlab.json') + default_json_path = osp.join(mmengine.__path__[0], "hub/openmmlab.json") default_urls = load_file(default_json_path) assert isinstance(default_urls, dict) - external_json_path = osp.join(mmengine_home, 'open_mmlab.json') + external_json_path = osp.join(mmengine_home, "open_mmlab.json") if osp.exists(external_json_path): external_urls = load_file(external_json_path) assert isinstance(external_urls, dict) @@ -211,14 +237,14 @@ def get_external_models(): def get_mmcls_models(): - mmcls_json_path = osp.join(mmengine.__path__[0], 'hub/mmcls.json') + mmcls_json_path = osp.join(mmengine.__path__[0], "hub/mmcls.json") mmcls_urls = load_file(mmcls_json_path) return mmcls_urls def get_deprecated_model_names(): - deprecate_json_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json') + deprecate_json_path = osp.join(mmengine.__path__[0], "hub/deprecated.json") deprecate_urls = load_file(deprecate_json_path) assert isinstance(deprecate_urls, dict) @@ -226,15 +252,15 @@ def get_deprecated_model_names(): def _process_mmcls_checkpoint(checkpoint): - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: # Some checkpoints converted from 3rd-party repo don't # have the "state_dict" key. state_dict = checkpoint new_state_dict = OrderedDict() for k, v in state_dict.items(): - if k.startswith('backbone.'): + if k.startswith("backbone."): new_state_dict[k[9:]] = v new_checkpoint = dict(state_dict=new_state_dict) @@ -257,11 +283,13 @@ def _register_scheme(cls, prefixes, loader, force=False): cls._schemes[prefix] = loader else: raise KeyError( - f'{prefix} is already registered as a loader backend, ' - 'add "force=True" if you want to override it') + f"{prefix} is already registered as a loader backend, " + 'add "force=True" if you want to override it' + ) # sort, longer prefixes take priority cls._schemes = OrderedDict( - sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True) + ) @classmethod def register_scheme(cls, prefixes, loader=None, force=False): @@ -308,7 +336,7 @@ def _get_checkpoint_loader(cls, path): return cls._schemes[p] @classmethod - def load_checkpoint(cls, filename, map_location=None, logger='current'): + def load_checkpoint(cls, filename, map_location=None, logger="current"): """Load checkpoint through URL scheme path. Args: @@ -324,13 +352,13 @@ def load_checkpoint(cls, filename, map_location=None, logger='current'): checkpoint_loader = cls._get_checkpoint_loader(filename) class_name = checkpoint_loader.__name__ print_log( - f'Loads checkpoint by {class_name[10:]} backend from path: ' - f'{filename}', - logger=logger) + f"Loads checkpoint by {class_name[10:]} backend from path: " f"{filename}", + logger=logger, + ) return checkpoint_loader(filename, map_location) -@CheckpointLoader.register_scheme(prefixes='') +@CheckpointLoader.register_scheme(prefixes="") def load_from_local(filename, map_location): """Load checkpoint by local file path. @@ -343,16 +371,13 @@ def load_from_local(filename, map_location): """ filename = osp.expanduser(filename) if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + raise FileNotFoundError(f"{filename} can not be found.") + checkpoint = torch.load(filename, map_location=map_location, weights_only=True) return checkpoint -@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) -def load_from_http(filename, - map_location=None, - model_dir=None, - progress=os.isatty(0)): +@CheckpointLoader.register_scheme(prefixes=("http://", "https://")) +def load_from_http(filename, map_location=None, model_dir=None, progress=os.isatty(0)): """Load checkpoint through HTTP or HTTPS scheme path. In distributed setting, this function only download checkpoint at local rank 0. @@ -369,10 +394,8 @@ def load_from_http(filename, rank, world_size = get_dist_info() if rank == 0: checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) + filename, model_dir=model_dir, map_location=map_location, progress=progress + ) if world_size > 1: torch.distributed.barrier() if rank > 0: @@ -380,11 +403,12 @@ def load_from_http(filename, filename, model_dir=model_dir, map_location=map_location, - progress=progress) + progress=progress, + ) return checkpoint -@CheckpointLoader.register_scheme(prefixes='pavi://') +@CheckpointLoader.register_scheme(prefixes="pavi://") def load_from_pavi(filename, map_location=None): """Load checkpoint through the file path prefixed with pavi. In distributed setting, this function download ckpt at all ranks to different temporary @@ -398,27 +422,28 @@ def load_from_pavi(filename, map_location=None): Returns: dict or OrderedDict: The loaded checkpoint. """ - assert filename.startswith('pavi://'), \ - f'Expected filename startswith `pavi://`, but get {filename}' + assert filename.startswith( + "pavi://" + ), f"Expected filename startswith `pavi://`, but get {filename}" model_path = filename[7:] try: from pavi import modelcloud except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') + raise ImportError("Please install pavi to load checkpoint from modelcloud.") model = modelcloud.get(model_path) with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) + checkpoint = torch.load( + downloaded_file, map_location=map_location, weights_only=True + ) return checkpoint -@CheckpointLoader.register_scheme( - prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) -def load_from_ceph(filename, map_location=None, backend='petrel'): +@CheckpointLoader.register_scheme(prefixes=[r"(\S+\:)?s3://", r"(\S+\:)?petrel://"]) +def load_from_ceph(filename, map_location=None, backend="petrel"): """Load checkpoint through the file path prefixed with s3. In distributed setting, this function download ckpt at all ranks to different temporary directories. @@ -432,14 +457,13 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Returns: dict or OrderedDict: The loaded checkpoint. """ - file_backend = get_file_backend( - filename, backend_args={'backend': backend}) + file_backend = get_file_backend(filename, backend_args={"backend": backend}) with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) + checkpoint = torch.load(buffer, map_location=map_location, weights_only=True) return checkpoint -@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) +@CheckpointLoader.register_scheme(prefixes=("modelzoo://", "torchvision://")) def load_from_torchvision(filename, map_location=None): """Load checkpoint through the file path prefixed with modelzoo or torchvision. @@ -453,19 +477,20 @@ def load_from_torchvision(filename, map_location=None): dict or OrderedDict: The loaded checkpoint. """ model_urls = get_torchvision_models() - if filename.startswith('modelzoo://'): + if filename.startswith("modelzoo://"): print_log( 'The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead', - logger='current', - level=logging.WARNING) + logger="current", + level=logging.WARNING, + ) model_name = filename[11:] else: model_name = filename[14:] return load_from_http(model_urls[model_name], map_location=map_location) -@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) +@CheckpointLoader.register_scheme(prefixes=("open-mmlab://", "openmmlab://")) def load_from_openmmlab(filename, map_location=None): """Load checkpoint through the file path prefixed with open-mmlab or openmmlab. @@ -481,34 +506,35 @@ def load_from_openmmlab(filename, map_location=None): """ model_urls = get_external_models() - prefix_str = 'open-mmlab://' + prefix_str = "open-mmlab://" if filename.startswith(prefix_str): model_name = filename[13:] else: model_name = filename[12:] - prefix_str = 'openmmlab://' + prefix_str = "openmmlab://" deprecated_urls = get_deprecated_model_names() if model_name in deprecated_urls: print_log( - f'{prefix_str}{model_name} is deprecated in favor ' - f'of {prefix_str}{deprecated_urls[model_name]}', - logger='current', - level=logging.WARNING) + f"{prefix_str}{model_name} is deprecated in favor " + f"of {prefix_str}{deprecated_urls[model_name]}", + logger="current", + level=logging.WARNING, + ) model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url - if model_url.startswith(('http://', 'https://')): + if model_url.startswith(("http://", "https://")): checkpoint = load_from_http(model_url, map_location=map_location) else: filename = osp.join(_get_mmengine_home(), model_url) if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + raise FileNotFoundError(f"{filename} can not be found.") + checkpoint = torch.load(filename, map_location=map_location, weights_only=True) return checkpoint -@CheckpointLoader.register_scheme(prefixes='mmcls://') +@CheckpointLoader.register_scheme(prefixes="mmcls://") def load_from_mmcls(filename, map_location=None): """Load checkpoint through the file path prefixed with mmcls. @@ -522,8 +548,7 @@ def load_from_mmcls(filename, map_location=None): model_urls = get_mmcls_models() model_name = filename[8:] - checkpoint = load_from_http( - model_urls[model_name], map_location=map_location) + checkpoint = load_from_http(model_urls[model_name], map_location=map_location) checkpoint = _process_mmcls_checkpoint(checkpoint) return checkpoint @@ -565,41 +590,36 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None): checkpoint = _load_checkpoint(filename, map_location=map_location) - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint - if not prefix.endswith('.'): - prefix += '.' + if not prefix.endswith("."): + prefix += "." prefix_len = len(prefix) state_dict = { - k[prefix_len:]: v - for k, v in state_dict.items() if k.startswith(prefix) + k[prefix_len:]: v for k, v in state_dict.items() if k.startswith(prefix) } - assert state_dict, f'{prefix} is not in the pretrained model' + assert state_dict, f"{prefix} is not in the pretrained model" return state_dict -def _load_checkpoint_to_model(model, - checkpoint, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): +def _load_checkpoint_to_model( + model, checkpoint, strict=False, logger=None, revise_keys=[(r"^module\.", "")] +): # get state_dict from checkpoint - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint # strip prefix of state_dict - metadata = getattr(state_dict, '_metadata', OrderedDict()) + metadata = getattr(state_dict, "_metadata", OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) + state_dict = OrderedDict({re.sub(p, r, k): v for k, v in state_dict.items()}) # Keep metadata in state_dict state_dict._metadata = metadata @@ -608,12 +628,14 @@ def _load_checkpoint_to_model(model, return checkpoint -def load_checkpoint(model, - filename, - map_location=None, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): +def load_checkpoint( + model, + filename, + map_location=None, + strict=False, + logger=None, + revise_keys=[(r"^module\.", "")], +): """Load checkpoint from a file or URI. Args: @@ -636,11 +658,9 @@ def load_checkpoint(model, checkpoint = _load_checkpoint(filename, map_location, logger) # OrderedDict is a subclass of dict if not isinstance(checkpoint, dict): - raise RuntimeError( - f'No state_dict found in checkpoint file {filename}') + raise RuntimeError(f"No state_dict found in checkpoint file {filename}") - return _load_checkpoint_to_model(model, checkpoint, strict, logger, - revise_keys) + return _load_checkpoint_to_model(model, checkpoint, strict, logger, revise_keys) def weights_to_cpu(state_dict): @@ -653,18 +673,18 @@ def weights_to_cpu(state_dict): OrderedDict: Model weights on GPU. """ # stash metadata to put in state_dict later - metadata = getattr(state_dict, '_metadata', OrderedDict()) - state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), - lambda x: x.cpu()) + metadata = getattr(state_dict, "_metadata", OrderedDict()) + state_dict = apply_to(state_dict, lambda x: hasattr(x, "cpu"), lambda x: x.cpu()) state_dict._metadata = metadata return state_dict @deprecated_function( - since='0.3.0', - removed_in='0.5.0', - instructions='`_save_to_state_dict` will be deprecated in the future, ' - 'please use `nn.Module._save_to_state_dict` directly.') + since="0.3.0", + removed_in="0.5.0", + instructions="`_save_to_state_dict` will be deprecated in the future, " + "please use `nn.Module._save_to_state_dict` directly.", +) def _save_to_state_dict(module, destination, prefix, keep_vars): """Saves module state to `destination` dictionary. @@ -686,7 +706,7 @@ def _save_to_state_dict(module, destination, prefix, keep_vars): destination[prefix + name] = buf if keep_vars else buf.detach() -def get_state_dict(module, destination=None, prefix='', keep_vars=False): +def get_state_dict(module, destination=None, prefix="", keep_vars=False): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are @@ -715,13 +735,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = dict( - version=module._version) + destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version) module._save_to_state_dict(destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: - get_state_dict( - child, destination, prefix + name + '.', keep_vars=keep_vars) + get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: @@ -729,10 +747,7 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): return destination -def save_checkpoint(checkpoint, - filename, - file_client_args=None, - backend_args=None): +def save_checkpoint(checkpoint, filename, file_client_args=None, backend_args=None): """Save checkpoint to file. Args: @@ -750,23 +765,25 @@ def save_checkpoint(checkpoint, print_log( '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', - logger='current', - level=logging.WARNING) + logger="current", + level=logging.WARNING, + ) if backend_args is not None: raise ValueError( '"file_client_args" and "backend_args" cannot be set ' - 'at the same time.') + "at the same time." + ) - if filename.startswith('pavi://'): + if filename.startswith("pavi://"): if file_client_args is not None or backend_args is not None: raise ValueError( '"file_client_args" or "backend_args" should be "None" if ' - 'filename starts with "pavi://"') + 'filename starts with "pavi://"' + ) try: from pavi import exception, modelcloud except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') + raise ImportError("Please install pavi to load checkpoint from modelcloud.") model_path = filename[7:] root = modelcloud.Folder() model_dir, model_name = osp.split(model_path) @@ -776,15 +793,14 @@ def save_checkpoint(checkpoint, model = root.create_training_model(model_dir) with TemporaryDirectory() as tmp_dir: checkpoint_file = osp.join(tmp_dir, model_name) - with open(checkpoint_file, 'wb') as f: + with open(checkpoint_file, "wb") as f: torch.save(checkpoint, f) f.flush() model.create_file(checkpoint_file, name=model_name) else: file_client = FileClient.infer_client(file_client_args, filename) if file_client_args is None: - file_backend = get_file_backend( - filename, backend_args=backend_args) + file_backend = get_file_backend(filename, backend_args=backend_args) else: file_backend = file_client @@ -804,12 +820,12 @@ def find_latest_checkpoint(path: str) -> Optional[str]: Returns: str or None: File path of the latest checkpoint. """ - save_file = osp.join(path, 'last_checkpoint') + save_file = osp.join(path, "last_checkpoint") last_saved: Optional[str] if os.path.exists(save_file): with open(save_file) as f: last_saved = f.read().strip() else: - print_log('Did not find last_checkpoint to be resumed.') + print_log("Did not find last_checkpoint to be resumed.") last_saved = None return last_saved diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..822756ef25 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -403,7 +403,8 @@ def run_iter(self, idx, data_batch: Sequence[dict]): with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - outputs, self.val_loss = _update_losses(outputs, self.val_loss) + if isinstance(outputs, list): + outputs, self.val_loss = _update_losses(outputs, self.val_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -486,7 +487,8 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: with autocast(enabled=self.fp16): outputs = self.runner.model.test_step(data_batch) - outputs, self.test_loss = _update_losses(outputs, self.test_loss) + if isinstance(outputs, list): + outputs, self.test_loss = _update_losses(outputs, self.test_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py index 7f7f1a087d..219769a118 100644 --- a/mmengine/utils/dl_utils/hub.py +++ b/mmengine/utils/dl_utils/hub.py @@ -48,7 +48,7 @@ def _legacy_zip_load(filename, model_dir, map_location): f.extractall(model_dir) extraced_name = members[0].filename extracted_file = os.path.join(model_dir, extraced_name) - return torch.load(extracted_file, map_location=map_location) + return torch.load(extracted_file, map_location=map_location, weights_only=True) def load_url(url, model_dir=None, @@ -114,7 +114,7 @@ def load_url(url, return _legacy_zip_load(cached_file, model_dir, map_location) try: - return torch.load(cached_file, map_location=map_location) + return torch.load(cached_file, map_location=map_location, weights_only=True) except RuntimeError as error: if digit_version(TORCH_VERSION) < digit_version('1.5.0'): warnings.warn( diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d731a42b76..8c176203ba 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -458,13 +458,13 @@ def test_with_runner(self, training_type): cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'), weights_only=True) self.assertIn('optimizer', ckpt) cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'), weights_only=True) self.assertNotIn('optimizer', ckpt) # Test save_param_scheduler=False @@ -479,13 +479,13 @@ def test_with_runner(self, training_type): ] runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'), weights_only=True) self.assertIn('param_schedulers', ckpt) cfg.default_hooks.checkpoint.save_param_scheduler = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'), weights_only=True) self.assertNotIn('param_schedulers', ckpt) self.clear_work_dir() @@ -533,7 +533,7 @@ def test_with_runner(self, training_type): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'), weights_only=True) self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], [9, 10, 11]) @@ -574,9 +574,9 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=True) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'), weights_only=True) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) @@ -603,11 +603,11 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=True) # if the current ckpt is the best, the interval will be ignored the # the ckpt will also be saved - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'), weights_only=True) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 6dad7ba4f0..b83739b7eb 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -230,7 +230,7 @@ def test_with_runner(self): self.assertTrue( isinstance(ema_hook.ema_model, ExponentialMovingAverage)) - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=True) self.assertTrue('ema_state_dict' in checkpoint) self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8) @@ -245,7 +245,7 @@ def test_with_runner(self): runner.test() # Test load checkpoint without ema_state_dict - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=True) checkpoint.pop('ema_state_dict') torch.save(checkpoint, osp.join(self.temp_dir.name, 'without_ema_state_dict.pth')) @@ -274,7 +274,7 @@ def test_with_runner(self): runner = self.build_runner(cfg) runner.train() state_dict = torch.load( - osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu', weights_only=True) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) @@ -287,12 +287,12 @@ def test_with_runner(self): runner = self.build_runner(cfg) runner.train() state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu', weights_only=True) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu') + osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu', weights_only=True) self.assertIn('ema_state_dict', state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index a13072dc6e..003ede7264 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -685,7 +685,7 @@ def _check_scheduler_state_dict(self, scheduler_copy = construct2() torch.save(scheduler.state_dict(), osp.join(self.temp_dir.name, 'tmp.pth')) - state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth')) + state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth'), weights_only=True) scheduler_copy.load_state_dict(state_dict) for key in scheduler.__dict__.keys(): if key != 'optimizer': diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e7668054bb..b7a427fec7 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -2272,7 +2272,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=True) self.assertEqual(ckpt['meta']['epoch'], 3) self.assertEqual(ckpt['meta']['iter'], 12) self.assertEqual(ckpt['meta']['experiment_name'], @@ -2444,7 +2444,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=True) self.assertEqual(ckpt['meta']['epoch'], 0) self.assertEqual(ckpt['meta']['iter'], 12) assert isinstance(ckpt['optimizer'], dict) @@ -2455,7 +2455,7 @@ def test_checkpoint(self): self.assertEqual(message_hub.get_info('iter'), 11) # 2.1.2 check class attribute _statistic_methods can be saved HistoryBuffer._statistics_methods.clear() - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=True) self.assertIn('min', HistoryBuffer._statistics_methods) # 2.2 test `load_checkpoint`