diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index a9771b368a86..e150281e81ce 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
         models by reducing the precision of the weights and activations, thus making models more efficient in terms
         of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
 
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,20 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
-
+    # bnb linear layers subclass torch.nn.Linear, so instead check that no quantized linears remain.
+    has_been_replaced = not any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behavior. Please check your model."
         )
 
     return model
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 29a3e212c48d..6ae46f352c19 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -63,6 +63,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -364,6 +366,18 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)  # 30 == logging.WARNING
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index a5e38f931e09..1049bfecbaab 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -68,6 +68,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ def test_device_and_dtype_assignment(self):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)  # 30 == logging.WARNING
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
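
For reviewers who want to exercise the new warning path outside the test suite, a minimal sketch is below. It assumes bitsandbytes is installed and reuses the conv-only toy model from the tests above; the top-level `BitsAndBytesConfig` import is the public diffusers API.

```python
import torch

from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear

# A model with no torch.nn.Linear layers: nothing can be converted to
# bnb.nn.Linear4bit / bnb.nn.Linear8bitLt, so the recomputed
# has_been_replaced flag stays False and the warning is logged.
model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
config = BitsAndBytesConfig(load_in_4bit=True)

model = replace_with_bnb_linear(model, quantization_config=config)
# Expected log: "You are loading your model in 8bit or 4bit but no linear
# modules were found in your model."
```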