I'm trying to use your implementation of memory-efficient attention (and the pure PyTorch FlashAttention implementation as well) to replace the eager attention implementation of Qwen3, as in this function. However, I noticed that the output differs from that of the vanilla eager implementation: for a single layer, the difference can be as large as 1e-2.
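For reference, this is roughly how I measure the per-layer discrepancy on a single attention call. It is a minimal sketch: I use `torch.nn.functional.scaled_dot_product_attention` purely as a stand-in for the memory-efficient / FlashAttention implementation under test, and the shapes and dtype are assumptions on my side, not your API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shapes loosely resembling one Qwen3 attention layer (GQA ignored for simplicity).
batch, heads, seq, head_dim = 1, 8, 128, 64
dtype = torch.bfloat16
q = torch.randn(batch, heads, seq, head_dim, dtype=dtype)
k = torch.randn(batch, heads, seq, head_dim, dtype=dtype)
v = torch.randn(batch, heads, seq, head_dim, dtype=dtype)

# Eager path: explicit causal mask + softmax(QK^T / sqrt(d)) V, with the
# softmax upcast to float32 as in the HF eager implementation.
scale = head_dim ** -0.5
causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores = (q @ k.transpose(-2, -1)) * scale
scores = scores.masked_fill(causal, float("-inf"))
attn = torch.softmax(scores, dim=-1, dtype=torch.float32).to(dtype)
eager_out = attn @ v

# Replacement path: SDPA here as a placeholder -- swap in the
# memory-efficient / FlashAttention implementation under test at this line.
replacement_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("max abs diff:", (eager_out - replacement_out).abs().max().item())
```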

The error accumulates across multiple layers and eventually leads to a very strange output. Do you have any insight into where the error comes from? When I debug with a very small input (everything fits in one chunk) and check the output at each step, the differences start at the softmax computation.
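To show where the divergence seems to start, here is another minimal sketch (my own toy code, not your implementation) comparing a one-shot softmax against the online, chunked softmax that memory-efficient / Flash-style attention relies on. The two are mathematically identical, so any gap in the printout is purely floating-point rounding introduced by the running-max rescaling, and it gets larger once the computation runs in bf16/fp16.

```python
import torch

torch.manual_seed(0)

# Attention scores for one query block: (num_queries, num_keys).
scores = torch.randn(4, 8, dtype=torch.float32)

# Reference: one-shot softmax over the full key dimension.
ref = torch.softmax(scores, dim=-1)

# Online softmax over key chunks: keep a running max and rescale the
# previously accumulated numerators/denominator whenever it changes.
running_max = torch.full((scores.shape[0], 1), float("-inf"))
running_sum = torch.zeros(scores.shape[0], 1)
numerators = []
for chunk in scores.split(4, dim=-1):
    chunk_max = chunk.amax(dim=-1, keepdim=True)
    new_max = torch.maximum(running_max, chunk_max)
    correction = torch.exp(running_max - new_max)  # rescale old accumulations
    running_sum = running_sum * correction + torch.exp(chunk - new_max).sum(dim=-1, keepdim=True)
    numerators = [n * correction for n in numerators]
    numerators.append(torch.exp(chunk - new_max))
    running_max = new_max

online = torch.cat(numerators, dim=-1) / running_sum

# Mathematically identical, but the extra exp/rescale steps add roundings:
# in float32 the gap is tiny (~1e-7); in lower precision it is much larger.
print("max abs diff:", (ref - online).abs().max().item())
```

My guess is that this kind of per-chunk rounding, repeated in every layer, is what compounds into the large end-to-end mismatch, but I'd appreciate confirmation or a pointer to anything else I might be missing.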