Add missing MochiEncoder3D.gradient_checkpointing attribute #11146

Merged
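This PR fixes a missing attribute on MochiEncoder3D: its forward gates the checkpointed path on torch.is_grad_enabled() and self.gradient_checkpointing, but __init__ never initialized that flag, so any grad-enabled pass through the encoder raised AttributeError (inference under torch.no_grad() was unaffected because the and short-circuits). The fix initializes the flag in the encoder, passes conv_cache positionally into _gradient_checkpointing_func, and adds a model test file for AutoencoderKLMochi. Below is a minimal sketch of the failure mode and fix; TinyEncoder is a hypothetical stand-in for MochiEncoder3D's control flow, not diffusers code.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(8, 8)
        # The attribute this PR adds. Without it, the condition in forward()
        # raises AttributeError as soon as gradients are enabled.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = checkpoint(self.proj_in, hidden_states, use_reentrant=False)
        else:
            hidden_states = self.proj_in(hidden_states)
        return hidden_states


enc = TinyEncoder()
enc.gradient_checkpointing = True  # what enable_gradient_checkpointing() toggles
out = enc(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # activations are recomputed during backward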
src/diffusers/models/autoencoders/autoencoder_kl_mochi.py (16 changes: 9 additions & 7 deletions)

@@ -210,7 +210,7 @@ def forward(
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
                     resnet,
                     hidden_states,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -306,7 +306,7 @@ def forward(
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
-                    resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                    resnet, hidden_states, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -382,7 +382,7 @@ def forward(
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
                     resnet,
                     hidden_states,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -497,6 +497,8 @@ def __init__(
         self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
         self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
     ) -> torch.Tensor:
@@ -513,13 +515,13 @@ def forward(
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
-                self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+                self.block_in, hidden_states, conv_cache.get("block_in")
             )
 
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
-                    down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                    down_block, hidden_states, conv_cache.get(conv_cache_key)
                 )
         else:
             hidden_states, new_conv_cache["block_in"] = self.block_in(
Expand Down Expand Up @@ -623,13 +625,13 @@ def forward(
# 1. Mid
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
self.block_in, hidden_states, conv_cache.get("block_in")
)

for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
up_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
Expand Down
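The conv_cache keyword-to-positional change in the hunks above has a plausible motivation that the diff itself does not state: checkpointing wrappers commonly forward only positional arguments, and torch.utils.checkpoint.checkpoint itself rejects keyword arguments for the wrapped callable when use_reentrant=True. A minimal sketch of that constraint in plain PyTorch; ToyBlock is a hypothetical stand-in for a Mochi resnet block that returns (hidden_states, new_cache) like the calls above.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(4, 4, kernel_size=1)

    def forward(self, hidden_states, conv_cache=None):
        out = self.conv(hidden_states)
        return out, out.detach()  # (activations, new cache entry)


block = ToyBlock()
x = torch.randn(1, 4, 8, requires_grad=True)

# Positional cache argument: accepted by both checkpoint implementations.
hidden, cache = checkpoint(block, x, None, use_reentrant=True)
hidden.sum().backward()

# Keyword cache argument: the reentrant implementation rejects it with a
# ValueError about unexpected keyword arguments.
try:
    checkpoint(block, x, conv_cache=None, use_reentrant=True)
except ValueError as err:
    print(err)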
tests/models/autoencoders/test_models_autoencoder_mochi.py (new file: 111 additions)

# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from diffusers import AutoencoderKLMochi
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLMochi
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_mochi_config(self):
        return {
            "in_channels": 15,
            "out_channels": 3,
            "latent_channels": 4,
            "encoder_block_out_channels": (32, 32, 32, 32),
            "decoder_block_out_channels": (32, 32, 32, 32),
            "layers_per_block": (1, 1, 1, 1, 1),
            "act_fn": "silu",
            "scaling_factor": 1,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 7
        num_channels = 3
        sizes = (16, 16)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 7, 16, 16)

    @property
    def output_shape(self):
        return (3, 7, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_mochi_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "MochiDecoder3D",
            "MochiDownBlock3D",
            "MochiEncoder3D",
            "MochiMidBlock3D",
            "MochiUpBlock3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        """
        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
        TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
        """
        pass

    @unittest.skip("Unsupported test.")
    def test_model_parallelism(self):
        """
        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
        RuntimeError: values expected sparse tensor layout but got Strided
        """
        pass

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        """
        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
        RuntimeError: values expected sparse tensor layout but got Strided
        """
        pass

    @unittest.skip("Unsupported test.")
    def test_sharded_checkpoints_device_map(self):
        """
        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map -
        RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5!
        """