diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 9304c34b4e01..3c7e8654d223 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -155,6 +155,7 @@
         "AutoencoderKLWan",
         "AutoencoderOobleck",
         "AutoencoderTiny",
+        "AutoModel",
         "CacheMixin",
         "CogVideoXTransformer3DModel",
         "CogView3PlusTransformer2DModel",
@@ -197,6 +198,7 @@
         "T2IAdapter",
         "T5FilmDecoder",
         "Transformer2DModel",
+        "TransformerTemporalModel",
         "UNet1DModel",
         "UNet2DConditionModel",
         "UNet2DModel",
@@ -731,6 +733,7 @@
         AutoencoderKLWan,
         AutoencoderOobleck,
         AutoencoderTiny,
+        AutoModel,
         CacheMixin,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
@@ -772,6 +775,7 @@
         T2IAdapter,
         T5FilmDecoder,
         Transformer2DModel,
+        TransformerTemporalModel,
         UNet1DModel,
         UNet2DConditionModel,
         UNet2DModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index f7d70f1d9826..99a2f871c837 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -41,6 +41,7 @@
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
+    _import_structure["auto_model"] = ["AutoModel"]
    _import_structure["cache_utils"] = ["CacheMixin"]
     _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
     _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
@@ -103,6 +104,7 @@
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,
             AutoencoderDC,
diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
new file mode 100644
index 000000000000..1b742463aa2e
--- /dev/null
+++ b/src/diffusers/models/auto_model.py
@@ -0,0 +1,169 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import os
+from typing import Optional, Union
+
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ..configuration_utils import ConfigMixin
+
+
+class AutoModel(ConfigMixin):
+    config_name = "config.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_or_path)` method."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
+        r"""
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
+
+        Parameters:
+            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device. Defaults to `None`, meaning that the model will be loaded on CPU.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights to if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+
+        <Tip>
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
+
+        </Tip>
+
+        Example:
+
+        ```py
+        from diffusers import AutoModel
+
+        unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + "subfolder": subfolder, + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + orig_class_name = config["_class_name"] + + library = importlib.import_module("diffusers") + + model_cls = getattr(library, orig_class_name, None) + if model_cls is None: + raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") + + kwargs = {**load_config_kwargs, **kwargs} + return model_cls.from_pretrained(pretrained_model_or_path, **kwargs) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6edbd737e32c..dd9117ddca18 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -280,6 +280,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CacheMixin(metaclass=DummyObject): _backends = ["torch"] @@ -895,6 +910,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TransformerTemporalModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 847677884a35..6155ac2e39fd 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -45,6 +45,7 @@ AttnProcessorNPU, XFormersAttnProcessor, ) +from diffusers.models.auto_model import AutoModel from diffusers.training_utils import EMAModel from diffusers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -1577,6 +1578,49 @@ def run_forward(model): self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + def test_auto_model(self, expected_max_diff=5e-5): + if self.forward_requires_fresh_args: + model = self.model_class(**self.init_dict) + else: + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + model = model.eval() + model = model.to(torch_device) + + if hasattr(model, "set_default_attn_processor"): + model.set_default_attn_processor() + + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: + model.save_pretrained(tmpdirname, safe_serialization=False) + + auto_model = AutoModel.from_pretrained(tmpdirname) + if 
+                auto_model.set_default_attn_processor()
+
+            auto_model = auto_model.eval()
+            auto_model = auto_model.to(torch_device)
+
+        with torch.no_grad():
+            if self.forward_requires_fresh_args:
+                output_original = model(**self.inputs_dict(0))
+                output_auto = auto_model(**self.inputs_dict(0))
+            else:
+                output_original = model(**inputs_dict)
+                output_auto = auto_model(**inputs_dict)
+
+        if isinstance(output_original, dict):
+            output_original = output_original.to_tuple()[0]
+        if isinstance(output_auto, dict):
+            output_auto = output_auto.to_tuple()[0]
+
+        max_diff = (output_original - output_auto).abs().max().item()
+        self.assertLessEqual(
+            max_diff,
+            expected_max_diff,
+            f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
+        )
+
 
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
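
For reviewers, a minimal usage sketch of what this patch enables end to end. It is not part of the diff; the repo id and `subfolder` are taken from the docstring example above, and `torch_dtype` is simply one of the kwargs that `AutoModel.from_pretrained` forwards untouched to the resolved class.

```py
import torch

from diffusers import AutoModel

# AutoModel.from_pretrained loads `unet/config.json`, reads its `_class_name`
# entry (here "UNet2DConditionModel"), resolves that name on the top-level
# `diffusers` module via importlib/getattr, and delegates loading to the
# resolved class's own `from_pretrained`.
unet = AutoModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
)

# The object you get back is the concrete model, not an AutoModel instance.
print(type(unet).__name__)  # UNet2DConditionModel
```

One design note: unlike `transformers.AutoModel`, which dispatches through an explicit model-type registry, this implementation dispatches purely on the config's `_class_name`, so any model class exported from the top-level `diffusers` namespace works without registration; anything else raises the `ValueError` added in `auto_model.py`.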