Only random noise is generated with Flux + LoRA with optimum-quanto >= 0.2.5 #343

Open
nelapetrzelkova opened this issue Oct 30, 2024 · 2 comments


@nelapetrzelkova

nelapetrzelkova commented Oct 30, 2024

Hello,
I am facing an issue when generating images with FLUX.1[dev] plus a LoRA that I trained with SimpleTuner. I need to load LoRAs dynamically, so I want to quantize FLUX first and then load the LoRA into the already-quantized model. With optimum-quanto 0.2.4 and lower, I got the following error: KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'. After upgrading to 0.2.5 or 0.2.6, no error is thrown, but the results look like this:
[image: the generated output is pure random noise]

My code:

import torch
from diffusers import DiffusionPipeline
from IPython.display import display  # notebook helper used to show the images
from optimum.quanto import freeze, qfloat8, quantize

# Pick the best available device once and reuse it for the pipeline and the generator.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model_id = 'black-forest-labs/FLUX.1-dev'
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Quantize the transformer weights to float8 first, then load the LoRA on top of the quantized model.
quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)
pipeline.to(device)
lora_path = '<path_to_lora>'  # placeholder: path to the SimpleTuner LoRA
pipeline.load_lora_weights(lora_path)

prompts = {"candy": "Candy bar surrounded by playful, abstract shapes resembling candy sprinkles and whimsical clouds of cream. The atmosphere is vibrant and joyful, filled with bright colors that evoke childhood memories of sweetness and fun. This imagery invites viewers to imagine the delight of savoring a piece of chocolate that brings happiness to any moment."}

seed = 19640904
for prompt_key, prompt_value in prompts.items():
    print(prompt_key, prompt_value)
    images = pipeline(
        prompt=prompt_value,
        num_inference_steps=10,
        num_images_per_prompt=1,
        generator=torch.Generator(device=device).manual_seed(seed),
        width=1024,
        height=1024,
    ).images
    for image in images:
        display(image)

Is there a way to solve this? One workaround would be to load the LoRA into the model before quantization, save the quantized merged model and work with that, but then I lose the benefit of working with the LoRA alone, which is much faster and less memory-intensive.
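For context, here is a rough sketch of how the two approaches differ for a single linear layer, assuming the standard LoRA formulation (rank $r$, scaling $\alpha$, low-rank factors $B$ and $A$; SimpleTuner's exact scaling may differ) and writing $Q(\cdot)$ for the qfloat8 weight quantization:

```math
\begin{aligned}
y_{\text{dynamic}} &\approx Q(W)\,x + \frac{\alpha}{r}\,B A\,x \\
y_{\text{merged}}  &= Q\!\left(W + \frac{\alpha}{r}\,B A\right) x
\end{aligned}
```

In the dynamic setup only the base weight $W$ goes through the quantized matmul, and the LoRA term is computed separately, so LoRAs can be swapped freely. In the merged workaround the LoRA delta is baked in before quantization, so every LoRA change requires re-quantizing the transformer.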

Thanks!

@tyyff

tyyff commented Nov 7, 2024


I encountered a similar issue. When using optimum-quanto==0.2.6 to quantize FLUX.1-schnell, the output also turned into random noise. After investigating, I found that the issue was caused by MarlinF8QBytesTensor.

To fix it, you can modify optimum/quanto/tensor/weights/qbytes.py.

Simply change the line:

and torch.cuda.get_device_capability(data.device)[0] >= 8
to:
and torch.cuda.get_device_capability(data.device)[0] >= 20
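This resolved the problem for me. Since no current GPU reports a compute capability major version of 20 or higher, the condition can never be true, so this change effectively disables the Marlin FP8 path and the weights fall back to the regular (non-Marlin) float8 tensor.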

    if (
        qtype == qtypes["qfloat8_e4m3fn"]
        and activation_qtype is None
        and scale.dtype in [torch.float16, torch.bfloat16]
        and len(size) == 2
        and (data.device.type == "cuda" and torch.version.cuda)
        and axis == 0
        and torch.cuda.get_device_capability(data.device)[0] >= 8
    ):
        out_features, in_features = size
        if (
            in_features >= 64
            and out_features >= 64
            and (
                (in_features % 64 == 0 and out_features % 128 == 0)
                or (in_features % 128 == 0 and out_features % 64 == 0)
            )
        ):
            return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad)
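A minimal sketch of why the patch works: the hypothetical helper `uses_marlin_fp8` below (not part of optimum-quanto) restates only the shape and compute-capability checks from the snippet above, assuming the qtype, scale dtype, axis and CUDA checks all pass. The example shapes assume a FLUX hidden size of 3072 and an MLP width of 12288. With the original threshold of 8, Ampere (8.x), Ada (8.9) and Hopper (9.0) GPUs send these layers through Marlin; with a threshold of 20 nothing matches.

```python
def uses_marlin_fp8(in_features: int, out_features: int, cc_major: int,
                    min_cc_major: int = 8) -> bool:
    """Hypothetical restatement of the Marlin FP8 dispatch checks quoted above.

    Only the GPU capability and weight-shape conditions are modelled; the
    qtype, scale dtype, axis and CUDA-device checks are assumed to pass.
    """
    # Capability check: the original code uses 8, the patch raises it to 20.
    if cc_major < min_cc_major:
        return False
    # Shape checks: both dims >= 64 and aligned to 64/128 in either order.
    return (
        in_features >= 64
        and out_features >= 64
        and (
            (in_features % 64 == 0 and out_features % 128 == 0)
            or (in_features % 128 == 0 and out_features % 64 == 0)
        )
    )


# Illustrative (in_features, out_features) pairs; FLUX sizes are assumptions.
shapes = [(3072, 3072), (3072, 12288), (12288, 3072), (100, 3072)]

for cc in (8, 9):
    original = [uses_marlin_fp8(i, o, cc) for i, o in shapes]
    patched = [uses_marlin_fp8(i, o, cc, min_cc_major=20) for i, o in shapes]
    print(f"cc {cc}.x  original: {original}  patched: {patched}")
```

Under these assumptions, essentially every large linear layer in the FLUX transformer is aligned to 64/128, which would explain why the whole image degrades rather than a few layers.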

But I don't know why MarlinF8QBytesTensor doesn't work. @dacorvo

@dacorvo
Collaborator

dacorvo commented Nov 10, 2024

@tyyff thank you for investigating this. See also #332. There might be a general issue with Marlin kernels when the size of the tensors involved in the matmul increases (could be an overflow, could be some overlaps in the intermediate result buffers, I really don't know). I will disable the FP8 kernel for now.
