Skip to content

Commit ad4e207

Browse files
committed
add AutoencoderMixin
1 parent e27df8d commit ad4e207

18 files changed

+122
-391
lines changed

mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
from ...configuration_utils import ConfigMixin, register_to_config
2525
from ..modeling_outputs import AutoencoderKLOutput
2626
from ..modeling_utils import ModelMixin
27-
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
27+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
2828

2929

30-
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
30+
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
3131
r"""
3232
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
3333
KL loss for encoding images into latents and decoding latent representations into images.
@@ -112,9 +112,6 @@ def __init__(
112112
self.quant_conv = mint.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
113113
self.post_quant_conv = mint.nn.Conv2d(latent_channels, latent_channels, 1)
114114

115-
self.use_slicing = False
116-
self.use_tiling = False
117-
118115
self.register_to_config(block_out_channels=up_block_out_channels)
119116
self.register_to_config(force_upcast=False)
120117

mindone/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..modeling_utils import ModelMixin
3131
from ..normalization import RMSNorm, get_normalization
3232
from ..transformers.sana_transformer import GLUMBConv
33-
from .vae import DecoderOutput, EncoderOutput
33+
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
3434

3535

3636
class ResBlock(nn.Cell):
@@ -393,7 +393,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
393393
return hidden_states
394394

395395

396-
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
396+
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
397397
r"""
398398
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
399399
[SANA](https://huggingface.co/papers/2410.10629).
@@ -551,27 +551,6 @@ def enable_tiling(
551551
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
552552
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
553553

554-
def disable_tiling(self) -> None:
555-
r"""
556-
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
557-
decoding in one step.
558-
"""
559-
self.use_tiling = False
560-
561-
def enable_slicing(self) -> None:
562-
r"""
563-
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
564-
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
565-
"""
566-
self.use_slicing = True
567-
568-
def disable_slicing(self) -> None:
569-
r"""
570-
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
571-
decoding in one step.
572-
"""
573-
self.use_slicing = False
574-
575554
def _encode(self, x: ms.Tensor) -> ms.Tensor:
576555
batch_size, num_channels, height, width = x.shape
577556

mindone/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
2929
from ..modeling_outputs import AutoencoderKLOutput
3030
from ..modeling_utils import ModelMixin
31-
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
31+
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3232

3333

34-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
34+
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3535
r"""
3636
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
3737
@@ -135,35 +135,6 @@ def __init__(
135135
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
136136
self.tile_overlap_factor = 0.25
137137

138-
def enable_tiling(self, use_tiling: bool = True):
139-
r"""
140-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
141-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
142-
processing larger images.
143-
"""
144-
self.use_tiling = use_tiling
145-
146-
def disable_tiling(self):
147-
r"""
148-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
149-
decoding in one step.
150-
"""
151-
self.enable_tiling(False)
152-
153-
def enable_slicing(self):
154-
r"""
155-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
156-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
157-
"""
158-
self.use_slicing = True
159-
160-
def disable_slicing(self):
161-
r"""
162-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
163-
decoding in one step.
164-
"""
165-
self.use_slicing = False
166-
167138
@property
168139
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
169140
def attn_processors(self) -> Dict[str, AttentionProcessor]:

mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..modeling_utils import ModelMixin
3333
from ..resnet import ResnetBlock2D
3434
from ..upsampling import Upsample2D
35+
from .vae import AutoencoderMixin
3536

3637

3738
class AllegroTemporalConvLayer(nn.Cell):
@@ -685,7 +686,7 @@ def construct(self, sample: ms.Tensor) -> ms.Tensor:
685686
return sample
686687

687688

688-
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
689+
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
689690
r"""
690691
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
691692
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -808,35 +809,6 @@ def __init__(
808809
sample_size - self.tile_overlap_w,
809810
)
810811

811-
def enable_tiling(self) -> None:
812-
r"""
813-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
814-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
815-
processing larger images.
816-
"""
817-
self.use_tiling = True
818-
819-
def disable_tiling(self) -> None:
820-
r"""
821-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
822-
decoding in one step.
823-
"""
824-
self.use_tiling = False
825-
826-
def enable_slicing(self) -> None:
827-
r"""
828-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
829-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
830-
"""
831-
self.use_slicing = True
832-
833-
def disable_slicing(self) -> None:
834-
r"""
835-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
836-
decoding in one step.
837-
"""
838-
self.use_slicing = False
839-
840812
def _encode(self, x: ms.Tensor) -> ms.Tensor:
841813
# TODO(aryan)
842814
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):

mindone/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ..modeling_outputs import AutoencoderKLOutput
3333
from ..modeling_utils import ModelMixin
3434
from ..upsampling import CogVideoXUpsample3D
35-
from .vae import DecoderOutput, DiagonalGaussianDistribution
35+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
3636

3737
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3838

@@ -891,7 +891,7 @@ def construct(
891891
return hidden_states, new_conv_cache
892892

893893

894-
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
894+
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
895895
r"""
896896
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
897897
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -1061,27 +1061,6 @@ def enable_tiling(
10611061
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
10621062
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
10631063

1064-
def disable_tiling(self) -> None:
1065-
r"""
1066-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1067-
decoding in one step.
1068-
"""
1069-
self.use_tiling = False
1070-
1071-
def enable_slicing(self) -> None:
1072-
r"""
1073-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1074-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1075-
"""
1076-
self.use_slicing = True
1077-
1078-
def disable_slicing(self) -> None:
1079-
r"""
1080-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1081-
decoding in one step.
1082-
"""
1083-
self.use_slicing = False
1084-
10851064
def _encode(self, x: ms.Tensor) -> ms.Tensor:
10861065
batch_size, num_channels, num_frames, height, width = x.shape
10871066

mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..layers_compat import conv_transpose3d, unflatten
2626
from ..modeling_outputs import AutoencoderKLOutput
2727
from ..modeling_utils import ModelMixin
28-
from .vae import DecoderOutput, IdentityDistribution
28+
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
2929

3030
logger = get_logger(__name__)
3131

@@ -915,7 +915,7 @@ def construct(self, hidden_states: ms.tensor) -> ms.tensor:
915915
return hidden_states
916916

917917

918-
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
918+
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
919919
r"""
920920
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
921921
@@ -1072,28 +1072,7 @@ def enable_tiling(
10721072
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
10731073
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
10741074

1075-
def disable_tiling(self) -> None:
1076-
r"""
1077-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1078-
decoding in one step.
1079-
"""
1080-
self.use_tiling = False
1081-
1082-
def enable_slicing(self) -> None:
1083-
r"""
1084-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1085-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1086-
"""
1087-
self.use_slicing = True
1088-
1089-
def disable_slicing(self) -> None:
1090-
r"""
1091-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1092-
decoding in one step.
1093-
"""
1094-
self.use_slicing = False
1095-
1096-
def _encode(self, x: ms.tensor) -> ms.tensor:
1075+
def _encode(self, x: ms.Tensor) -> ms.Tensor:
10971076
x = self.encoder(x)
10981077
enc = self.quant_conv(x)
10991078
return enc

mindone/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..layers_compat import unflatten
3131
from ..modeling_outputs import AutoencoderKLOutput
3232
from ..modeling_utils import ModelMixin
33-
from .vae import DecoderOutput, DiagonalGaussianDistribution
33+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
3434

3535
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3636

@@ -595,7 +595,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
595595
return hidden_states
596596

597597

598-
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
598+
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
599599
r"""
600600
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
601601
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -736,27 +736,6 @@ def enable_tiling(
736736
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
737737
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
738738

739-
def disable_tiling(self) -> None:
740-
r"""
741-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
742-
decoding in one step.
743-
"""
744-
self.use_tiling = False
745-
746-
def enable_slicing(self) -> None:
747-
r"""
748-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
749-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
750-
"""
751-
self.use_slicing = True
752-
753-
def disable_slicing(self) -> None:
754-
r"""
755-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
756-
decoding in one step.
757-
"""
758-
self.use_slicing = False
759-
760739
def _encode(self, x: ms.Tensor) -> ms.Tensor:
761740
batch_size, num_channels, num_frames, height, width = x.shape
762741

@@ -777,7 +756,7 @@ def encode(
777756
Encode a batch of images into latents.
778757
779758
Args:
780-
x (`torch.Tensor`): Input batch of images.
759+
x (`ms.Tensor`): Input batch of images.
781760
return_dict (`bool`, *optional*, defaults to `True`):
782761
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
783762
@@ -823,7 +802,7 @@ def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput
823802
Decode a batch of images.
824803
825804
Args:
826-
z (`torch.Tensor`): Input batch of latent vectors.
805+
z (`ms.Tensor`): Input batch of latent vectors.
827806
return_dict (`bool`, *optional*, defaults to `True`):
828807
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
829808
@@ -871,10 +850,10 @@ def tiled_encode(self, x: ms.Tensor) -> AutoencoderKLOutput:
871850
r"""Encode a batch of images using a tiled encoder.
872851
873852
Args:
874-
x (`torch.Tensor`): Input batch of videos.
853+
x (`ms.Tensor`): Input batch of videos.
875854
876855
Returns:
877-
`torch.Tensor`:
856+
`ms.Tensor`:
878857
The latent representation of the encoded videos.
879858
"""
880859
batch_size, num_channels, num_frames, height, width = x.shape
@@ -922,7 +901,7 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[Decoder
922901
Decode a batch of images using a tiled decoder.
923902
924903
Args:
925-
z (`torch.Tensor`): Input batch of latent vectors.
904+
z (`ms.Tensor`): Input batch of latent vectors.
926905
return_dict (`bool`, *optional*, defaults to `True`):
927906
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
928907
@@ -1051,7 +1030,7 @@ def construct(
10511030
) -> Union[DecoderOutput, ms.Tensor]:
10521031
r"""
10531032
Args:
1054-
sample (`torch.Tensor`): Input sample.
1033+
sample (`ms.Tensor`): Input sample.
10551034
sample_posterior (`bool`, *optional*, defaults to `False`):
10561035
Whether to sample from the posterior.
10571036
return_dict (`bool`, *optional*, defaults to `True`):

0 commit comments

Comments
 (0)