diff --git a/docs/operators/ALiBiMask.md b/docs/operators/ALiBiMask.md index afdc3a3..33eca0f 100644 --- a/docs/operators/ALiBiMask.md +++ b/docs/operators/ALiBiMask.md @@ -22,7 +22,7 @@ The following code shows the calculation process of mask. alibi_mask = torch.full((seqlen_q, seqlen_kv), float('inf'), dtype=data_type) for i in range(seqlen_q-1, -1, -1): for j in range(seqlen_kv): - mask = j - seqlen_kv + 1 + (seqlen_q - 1 - i) + mask = j - i + (seqlen_q - seqlen_kv) if mask <= 0: alibi_mask[i][j] = mask alibi_mask = alibi_mask.unsqueeze(0).expand(num_heads, -1, -1)