Skip to content

Commit 7dc52ea

Browse files
authored
[Quantization] dtype fix for GGUF + fix BnB tests (#11159)
* update * update * update * update
1 parent 739d6ec commit 7dc52ea

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -282,6 +282,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
282282
if quantization_config is not None:
283283
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
284284
hf_quantizer.validate_environment()
285+
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
285286

286287
else:
287288
hf_quantizer = None

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -90,13 +90,16 @@ class Base8bitTests(unittest.TestCase):
9090

9191
def get_dummy_inputs(self):
9292
prompt_embeds = load_pt(
93-
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
93+
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
94+
map_location="cpu",
9495
)
9596
pooled_prompt_embeds = load_pt(
96-
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
97+
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
98+
map_location="cpu",
9799
)
98100
latent_model_input = load_pt(
99-
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
101+
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
102+
map_location="cpu",
100103
)
101104

102105
input_dict_for_transformer = {

tests/quantization/gguf/test_gguf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def test_gguf_linear_layers(self):
5757
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
5858
assert module.weight.dtype == torch.uint8
5959
if module.bias is not None:
60-
assert module.bias.dtype == torch.float32
60+
assert module.bias.dtype == self.torch_dtype
6161

6262
def test_gguf_memory_usage(self):
6363
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

0 commit comments

Comments (0)