
Commit 1a04812

[bitsandbytes] improve replacement warnings for bnb (#11132)

* improve replacement warnings for bnb
* updates to docs.
1 parent 4b27c4a · commit 1a04812

3 files changed: +38 -6 lines changed

src/diffusers/quantizers/bitsandbytes/utils.py (+10 -6)

@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
         models by reducing the precision of the weights and activations, thus making models more efficient in terms
         of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
 
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,16 +285,18 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
-
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
-            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+            "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
         )
 
     return model
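The dequantize path mirrors the same pattern: after _dequantize_and_replace runs, the model is re-scanned for plain torch.nn.Linear modules, and the sharpened warning fires only when none are found. A sketch of that mirrored check (the toy model again is hypothetical):

import torch

# Hypothetical toy model standing in for a freshly dequantized one.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# Mirror of the check in the second hunk: at least one plain
# torch.nn.Linear should be present after dequantization.
has_been_replaced = any(
    isinstance(module, torch.nn.Linear) for _, module in model.named_modules()
)
print(has_been_replaced)  # True: the Sequential contains an nn.Linear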

tests/quantization/bnb/test_4bit.py (+14)

@@ -70,6 +70,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -371,6 +373,18 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
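The new test builds a model containing no linear layers at all, so the replacement pass has nothing to convert and the warning must appear in the captured output. For contrast, here is a sketch of the positive case (not part of the commit; assumes bitsandbytes is installed), where a model that does contain an nn.Linear is converted and no warning is expected:

import torch

from diffusers import BitsAndBytesConfig
from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear

# A model that does contain a Linear layer, unlike the one in the test.
model_with_linear = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# The Linear is swapped for a bnb 4-bit layer, so the "no linear modules
# were found" warning should not be logged here.
model_with_linear = replace_with_bnb_linear(
    model_with_linear, quantization_config=quantization_config
)
print(type(model_with_linear[0]))  # expected: bitsandbytes.nn.Linear4bit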

tests/quantization/bnb/test_mixed_int8.py (+14)

@@ -68,6 +68,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -317,6 +319,18 @@ def test_device_and_dtype_assignment(self):
         # Check that this does not throw an error
         _ = self.model_fp16.to(torch_device)
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
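Two small details are worth noting in both tests: the 8-bit file imports replace_with_bnb_linear from the diffusers.quantizers.bitsandbytes package rather than from its utils submodule (both paths resolve to the same function), and logger.setLevel(30) uses the numeric value of the standard logging.WARNING level, as this short snippet illustrates:

import logging

from diffusers.utils import logging as diffusers_logging

# 30 is the numeric value of logging.WARNING, so setLevel(30) in the tests
# is equivalent to the symbolic form used here.
assert logging.WARNING == 30

logger = diffusers_logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(logging.WARNING)  # same effect as logger.setLevel(30)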
