
Differences between memory efficient attention and vanilla eager attention #9


Description

@nbui-sc

I'm trying to use your memory-efficient attention implementation (and the pure PyTorch FlashAttention implementation as well) to replace Qwen3's eager attention implementation, as in this function:

https://github.com/huggingface/transformers/blob/0e1c2817455602d182bd8ebf5fba212e14fb187e/src/transformers/models/qwen3/modeling_qwen3.py#L135
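For reference, this is roughly what that eager path computes. It's a simplified standalone sketch I wrote for this issue, not the exact Transformers code (GQA key/value repetition and dropout are omitted):

```python
import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attention_mask=None, scaling=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # additive mask: 0 for visible positions, large negative for masked ones
        attn_weights = attn_weights + attention_mask
    # softmax in float32, then cast back to the input dtype, as the eager path does
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)
```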

But I noticed a difference between the output of the vanilla eager implementation and yours: for a single layer, the difference can be up to 1e-2.
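This is roughly how I'm measuring the mismatch, in case the comparison itself is the issue. `eager_attention` is the sketch above, and `memory_efficient_attention` is just a placeholder name for the chunked implementation under test:

```python
import torch

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 8, 128, 64
dtype = torch.bfloat16

q = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)
k = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)
v = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)

ref = eager_attention(q, k, v)
# placeholder name for the chunked / memory-efficient implementation being tested
out = memory_efficient_attention(q, k, v)

diff = (out.float() - ref.float()).abs()
print(f"max abs diff:  {diff.max().item():.3e}")
print(f"mean abs diff: {diff.mean().item():.3e}")
print("allclose:", torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))
```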


The error accumulates across layers and eventually leads to very strange model output. Do you have any insight into where the discrepancy comes from? When I debug with a very small input (everything fits in one chunk) and check the output at each step, the differences first appear in the softmax computation.
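For what it's worth, my mental model of the chunked path is the usual online-softmax accumulation: a running max and running denominator are kept per query row, and the partial output is rescaled whenever a new chunk raises the max, so the result is not computed the same way as a full-row softmax and some low-precision drift seems expected. A minimal sketch of that accumulation (my own illustration, not your code; no masking, accumulation in float32 just to keep it simple):

```python
import torch

def chunked_attention(query, key, value, chunk_size=32):
    # query/key/value: (batch, heads, seq_len, head_dim)
    scale = query.shape[-1] ** -0.5
    q = query.float()
    out = torch.zeros(query.shape, dtype=torch.float32)
    row_max = torch.full(query.shape[:-1] + (1,), float("-inf"), dtype=torch.float32)
    row_sum = torch.zeros(query.shape[:-1] + (1,), dtype=torch.float32)
    for start in range(0, key.shape[-2], chunk_size):
        k_chunk = key[..., start:start + chunk_size, :].float()
        v_chunk = value[..., start:start + chunk_size, :].float()
        scores = torch.matmul(q, k_chunk.transpose(-2, -1)) * scale
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)  # rescale previous accumulators
        probs = torch.exp(scores - new_max)
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        out = out * correction + torch.matmul(probs, v_chunk)
        row_max = new_max
    return (out / row_sum).to(query.dtype)
```

Does the rescaling/exponentiation in your implementation happen in the low-precision dtype, and could that explain why the mismatch shows up first at the softmax step even when everything fits in a single chunk?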
