diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f06f4832740c..9c77fa682d01 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -67,6 +67,23 @@ } +def _is_subclass(src_cls, dst_cls_str): + """ + Find if src_cls is a subclass of dst_cls whose name is dst_cls_str + """ + for cls in src_cls.__mro__: + if cls.__name__ == dst_cls_str: + return True + return False + + +def _get_single_file_loadable_mapping_class(cls): + for dst_cls_str in SINGLE_FILE_LOADABLE_CLASSES: + if _is_subclass(cls, dst_cls_str): + return dst_cls_str + return None + + def _get_mapping_function_kwargs(mapping_fn, **kwargs): parameters = inspect.signature(mapping_fn).parameters @@ -144,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ``` """ - class_name = cls.__name__ - if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + mapping_class_name = _get_single_file_loadable_mapping_class(cls) + # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + if mapping_class_name is None: raise ValueError( f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" ) @@ -190,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision=revision, ) - mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] if original_config: @@ -202,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if config_mapping_fn is None: raise ValueError( ( - f"`original_config` has been provided for {class_name} but no mapping function" + f"`original_config` has been provided for {mapping_class_name} but no mapping function" "was found to convert the original config to a Diffusers config in" "`diffusers.loaders.single_file_utils`" ) @@ -262,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) if not diffusers_format_checkpoint: raise SingleFileComponentError( - f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." ) ctx = init_empty_weights if is_accelerate_available() else nullcontext diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 733579b8c09c..31542947d333 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -65,7 +65,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]): Creates a dynamic module in the cache directory for modules. """ init_hf_modules() - dynamic_module_path = Path(HF_MODULES_CACHE) / name + dynamic_module_path = Path(HF_MODULES_CACHE).absolute() / name # If the parent module does not exist yet, recursively create it. if not dynamic_module_path.parent.exists(): create_dynamic_module(dynamic_module_path.parent)