FSDP2 integration: torch.chunks(Params4bit) not returning Params4bit subclass #1424

Open
mreso opened this issue Nov 21, 2024 · 2 comments · May be fixed by #1612
Labels: Bug (Something isn't working) · Contributions Welcome (We welcome contributions to fix this issue!) · FSDP · High Priority (first issues that will be worked on)

Comments


mreso commented Nov 21, 2024

System Info

Hi, I'm trying to make FSDP2 work with a Llama model quantized with bitsandbytes, but it seems that bitsandbytes' tensor subclasses like Params4bit are not compatible with the way FSDP2 shards the model.
When creating the DTensors to shard the model, FSDP2 applies torch.chunk to the parameters, and the chunks come back as ordinary Tensors instead of the original subclass (like Params4bit), which leads to errors down the line.

Is this a known issue and are there plans to make bitsandbytes composable with FSDP2?
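
For context, my understanding (an assumption on my part, not verified against the Params4bit internals) is that this is the standard nn.Parameter subclass pitfall: nn.Parameter assigns torch._C._disabled_torch_function_impl to __torch_function__, so ops like torch.chunk never get a chance to re-wrap their outputs in the subclass. A minimal sketch without bitsandbytes, using a made-up MyParam subclass, shows the same behavior:

import torch
import torch.nn as nn

# Hypothetical do-nothing Parameter subclass, standing in for Params4bit.
class MyParam(nn.Parameter):
    pass

p = MyParam(torch.ones(4, 1))
chunks = torch.chunk(p, 2, dim=0)
# Prints <class 'torch.Tensor'>: the MyParam subclass is dropped,
# just like with Params4bit below.
print(type(chunks[0]))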

Reproduction

Created a simple repro:

import torch

import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit

blocksize = 64
compress_statistics = True
quant_type = "fp4"
quant_storage = torch.uint8

w = torch.ones(4).to("cuda")

# Quantize the weight to 4-bit (two values packed per uint8 byte).
w_4bit, quant_state = bnb.functional.quantize_4bit(
    w,
    blocksize=blocksize,
    compress_statistics=compress_statistics,
    quant_type=quant_type,
    quant_storage=quant_storage,
)

# Rebuild a Params4bit parameter from the prequantized storage.
b = Params4bit.from_prequantized(w_4bit, quant_state.as_dict(packed=True))
print(f"{b=}")

# FSDP2 shards parameters via torch.chunk; the subclass is lost here.
chunks = torch.chunk(b, 2, dim=0)
print(f"{chunks=}")

Output:

b=Parameter containing:
Parameter(Params4bit([[51],
            [51]], device='cuda:0', dtype=torch.uint8))
chunks=(tensor([[51]], device='cuda:0', dtype=torch.uint8), tensor([[51]], device='cuda:0', dtype=torch.uint8))

Expected behavior

Expecting the output of torch.chunk to be a tuple of Params4bit instances instead of plain Tensors.
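
One possible direction, in case it helps whoever picks this up: override __torch_function__ on the subclass and re-wrap the chunks. This is only a sketch of the general pattern on recent PyTorch (2.x), again with the made-up MyParam class; a real Params4bit fix would also have to carry quant_state over to each chunk:

import torch
import torch.nn as nn

class MyParam(nn.Parameter):
    # nn.Parameter disables __torch_function__, so we define our own and
    # re-wrap the outputs of the ops FSDP2 uses for sharding.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the op with subclass torch_function disabled to avoid recursion.
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **kwargs)
        if func in (torch.chunk, torch.Tensor.chunk):
            # Re-wrap each plain-Tensor chunk as the subclass. A real fix
            # would also propagate quant_state (sliced per chunk if needed).
            return tuple(t.as_subclass(cls) for t in out)
        return out

p = MyParam(torch.ones(4, 1))
print(type(torch.chunk(p, 2, dim=0)[0]))  # <class '__main__.MyParam'>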

@matthewdouglas added the Bug label on Nov 26, 2024
@Titus-von-Koeller (Collaborator) commented:

We’re treating this as high priority, but our current focus is on tasks with even greater impact. Contributions are highly appreciated, and we’re happy to assist with anything needed along the way.

Let’s work together to address this as soon as possible. Thanks for raising it and detailing the limitations you’ve encountered.

@Titus-von-Koeller pinned this issue on Dec 11, 2024
@Titus-von-Koeller changed the title from "torch.chunks(Params4bit) not returning Params4bit subclass" to "FSDP2 integration: torch.chunks(Params4bit) not returning Params4bit subclass" on Dec 11, 2024
@Titus-von-Koeller added the help wanted, High Priority, and FSDP labels on Dec 11, 2024
@Titus-von-Koeller added the Contributions Welcome label and removed the help wanted label on Apr 25, 2025
@ved1beta (Contributor) commented:

working on it : )

@ved1beta linked pull request #1612 on Apr 27, 2025 that will close this issue