Attention fwd/bwd issues with all zero masks #204

@RuiWang1998

Description

Hi,

We have been observing issues with masking when all tokens are masked out. In this case, the attention output is twice that of native PyTorch's implementation, and the gradient of the v (value) vector is simply garbage.
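A minimal sketch of the comparison, assuming the mask is applied as a large negative additive bias; the `fused_attention` call at the end is a hypothetical stand-in for the fused kernel's actual entry point:

```python
import torch

def naive_attention(q, k, v, mask_bias):
    # Reference: softmax(QK^T / sqrt(d) + mask_bias) @ V
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5 + mask_bias
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16, requires_grad=True)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 16, requires_grad=True)
v = torch.randn(1, 4, 8, 16, requires_grad=True)

# "All tokens masked out": every key position receives a large negative
# additive bias. Since softmax is shift-invariant, the reference still
# produces finite outputs and finite gradients for v.
mask_bias = torch.full((1, 1, 8, 8), -1e9)

ref = naive_attention(q, k, v, mask_bias)
ref.sum().backward()
print("ref output mean:", ref.abs().mean().item())
print("ref v.grad mean:", v.grad.abs().mean().item())

# Reported behaviour of the fused kernel on the same inputs
# (hypothetical call, for illustration only):
# out = fused_attention(q, k, v, mask_bias)   # ~2x the reference output
# out.sum().backward()                        # v.grad contains garbage
```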

To illustrate why masking out all tokens is needed: consider the case where template search returns no templates, but the computation graph needs to stay the same during training to avoid a number of issues, including memory fragmentation.

Is there a way to fix this?

Best,
Rui
