[bitsandbytes] improve replacement warnings for bnb #11132

Open · wants to merge 3 commits into base: main
14 changes: 9 additions & 5 deletions src/diffusers/quantizers/bitsandbytes/utils.py
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
 
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,13 +285,15 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
 
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "For some reason the model has not been properly dequantized. You might see unexpected behavior."
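For context on the `utils.py` change: instead of trusting the `has_been_replaced` flag returned by the recursive helper, the wrapper now re-scans the final module tree. Below is a minimal sketch (not part of the PR) of that detection logic, assuming `torch` and `bitsandbytes` are installed; the toy models are made up for illustration.

```python
import torch
import bitsandbytes as bnb


def was_quantized(model: torch.nn.Module) -> bool:
    # Mirrors the check added in replace_with_bnb_linear: look for any
    # bitsandbytes linear layer anywhere in the (possibly nested) module tree.
    return any(
        isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
        for _, module in model.named_modules()
    )


# No nn.Linear layers at all, so nothing can be replaced -> the warning case.
conv_only = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
print(was_quantized(conv_only))  # False

# A model whose linear layer is already a bnb 4-bit layer -> no warning.
quantized = torch.nn.Sequential(bnb.nn.Linear4bit(16, 16))
print(was_quantized(quantized))  # True
```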
14 changes: 14 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -63,6 +63,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -364,6 +366,18 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
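A note on the new test: `logger.setLevel(30)` corresponds to `logging.WARNING`, so the module-level warning is not filtered out before `CaptureLogger` collects it. A standalone sketch of that capture pattern, assuming `CaptureLogger` and `get_logger` behave as in `diffusers.utils.testing_utils` and `diffusers.utils.logging`:

```python
import logging as py_logging

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Same logger the quantization utilities write their warnings to.
logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(py_logging.WARNING)  # WARNING == 30, the numeric level used in the test

with CaptureLogger(logger) as cap_logger:
    logger.warning("example warning")

# Everything the logger emitted inside the context is available on .out.
assert "example warning" in cap_logger.out
```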
14 changes: 14 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -68,6 +68,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ def test_device_and_dtype_assignment(self):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
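For completeness, the inverse case, which the new tests do not cover: when the model does contain `nn.Linear` layers, replacement succeeds and the warning should stay silent. A hedged sketch under the assumption that `bitsandbytes` is installed and that `replace_with_bnb_linear` accepts the call below; the explicit empty `modules_to_not_convert` list is a defensive choice for this toy example, not something the PR requires.

```python
import bitsandbytes as bnb
import torch

from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear

model_with_linear = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
config = BitsAndBytesConfig(load_in_8bit=True)

model_with_linear = replace_with_bnb_linear(
    model_with_linear, modules_to_not_convert=[], quantization_config=config
)

# The nn.Linear should now be a bnb 8-bit layer, so the named_modules() scan
# in the patched wrapper finds a match and no warning is logged.
assert any(
    isinstance(module, bnb.nn.Linear8bitLt) for _, module in model_with_linear.named_modules()
)
```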