Hi,
We have been observing an issue with masking when all tokens are masked out. In that case, the attention output is twice what native PyTorch's implementation produces, and the gradient of the value vector v is garbage.
To show why fully masked rows arise in the first place, consider the case where the template search returns no templates, but the computation graph needs to stay the same during training to avoid a number of issues, including memory fragmentation. A minimal sketch of the setup is below.
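
For illustration only: the shapes, the boolean-mask convention, and the use of `torch.nn.functional.scaled_dot_product_attention` as the stand-in for the fused path are my assumptions, not our exact setup.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, H, L, D = 1, 1, 4, 8
q = torch.randn(B, H, L, D, requires_grad=True)
k = torch.randn(B, H, L, D, requires_grad=True)
v = torch.randn(B, H, L, D, requires_grad=True)

# Every key position masked out, as when the template search returns nothing.
# Convention here: True = attend, False = masked.
mask = torch.zeros(B, H, L, L, dtype=torch.bool)

# Reference path: naive attention with an additive -inf bias on masked keys.
bias = torch.zeros(B, H, L, L).masked_fill(~mask, float("-inf"))
ref = torch.softmax(q @ k.transpose(-1, -2) / D**0.5 + bias, dim=-1) @ v

# Fused path (illustrative stand-in for the kernel under discussion).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print("reference:", ref)  # rows of NaN: softmax over all -inf is undefined
print("fused:    ", out)  # compare magnitudes against the reference

out.sum().backward()
print("v.grad:   ", v.grad)  # check whether the gradient w.r.t. v is well defined
```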
Is there a way to fix this?
Best,
Rui