
FlexAttention? #1685

Open
johnnynunez opened this issue Feb 10, 2025 · 8 comments

@johnnynunez

Is it compatible with FlexAttention from PyTorch 2.6.0?

@jainapurva
Contributor

@drisspg

@drisspg
Contributor

drisspg commented Feb 10, 2025

Can you add some more context here, @johnnynunez?

@johnnynunez
Author

johnnynunez commented Feb 10, 2025

> Can you add some more context here, @johnnynunez?

I want to quantize the LeRobot pi0 model, which uses FlexAttention. @drisspg
https://huggingface.co/blog/pi0

Context, from the FlexAttention blog post (https://pytorch.org/blog/flexattention/):

> In the future, we plan on extending this support to allow for quantized versions of attention or things like RadixAttention as well.

@drisspg
Contributor

drisspg commented Feb 10, 2025

Currently all of our quantization APIs target linear layers and are orthogonal to FlexAttention, so yes, FlexAttention should work. FlexAttention doesn't currently support low-precision inputs; that is planned, but there's no ETA yet.
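
For reference, here is a minimal sketch of what "orthogonal" means in practice, assuming torchao's `quantize_` / `int8_weight_only` API and a toy attention module (not the actual pi0 code): quantization rewrites only the `nn.Linear` projections, while the `flex_attention` call itself keeps running in bf16.

```python
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention
from torchao.quantization import quantize_, int8_weight_only


class FlexSelfAttention(nn.Module):
    """Toy self-attention block: linear projections + flex_attention."""

    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)   # rewritten by quantize_
        self.proj = nn.Linear(dim, dim, bias=False)      # rewritten by quantize_

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # flex_attention expects (batch, heads, seq, head_dim)
        q, k, v = (t.view(b, s, self.heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        out = flex_attention(q, k, v)                    # attention stays in bf16
        out = out.transpose(1, 2).reshape(b, s, d)
        return self.proj(out)


model = FlexSelfAttention().to(device="cuda", dtype=torch.bfloat16)
quantize_(model, int8_weight_only())   # only touches the nn.Linear weights

x = torch.randn(2, 128, 512, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = torch.compile(model)(x)        # compile to get the fused flex kernel
print(y.shape)  # torch.Size([2, 128, 512])
```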

@johnnynunez
Author

> Currently all of our quantization APIs target linear layers and are orthogonal to FlexAttention, so yes, FlexAttention should work. FlexAttention doesn't currently support low-precision inputs; that is planned, but there's no ETA yet.

thanks! I'm going to try

@drisspg
Contributor

drisspg commented Feb 10, 2025

Let me know if anything comes up!

@moinnadeem

> FlexAttention doesn't currently support low-precision inputs; that is planned, but there's no ETA yet.

@drisspg this is a good point -- what will happen with a low-precision input? Will it get upcast to bf16 for the actual matmul? If so, are you basically seeing VRAM savings but no time savings?

@drisspg
Contributor

drisspg commented Feb 18, 2025

I have an example of doing this, and @danielvegamyhre is starting to investigate and ultimately make this a well-supported path.

For an fp8 mm, the matmul itself runs in low precision and can in theory utilize the fp8 tensor cores on H100, accumulating into higher precision.
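
A hedged sketch of that fp8-matmul idea using `torch._scaled_mm` (a private PyTorch op, so the exact signature may shift between releases; the simple per-tensor scaling recipe below is an assumption, not necessarily what torchao/FlexAttention will ship): the matmul operands are fp8 e4m3, and the result comes back in bf16.

```python
import torch

# Illustrative shapes; _scaled_mm wants dimensions that are multiples of 16.
a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)

# Per-tensor scales so the values fit the fp8 e4m3 range.
fp8 = torch.float8_e4m3fn
scale_a = a.abs().max() / torch.finfo(fp8).max
scale_b = b.abs().max() / torch.finfo(fp8).max

a_fp8 = (a / scale_a).to(fp8)
# The second operand must be column-major for the fp8 kernel.
b_fp8 = (b / scale_b).to(fp8).t().contiguous().t()

out = torch._scaled_mm(
    a_fp8, b_fp8,
    scale_a=scale_a.float(),
    scale_b=scale_b.float(),
    out_dtype=torch.bfloat16,  # accumulate/return in higher precision
)
print(out.shape)  # torch.Size([1024, 2048])
```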
