12 changes: 6 additions & 6 deletions docs/20250422-new-kernel-deep-dive.md
@@ -25,13 +25,13 @@ To fully utilize GPU compute resources, we need to overlap CUDA Core operations
Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (with keys $K_0$, $K_1$ and values $V_0$, $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $\vec o_L$ and $\vec o_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows:

0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / \mathit{qk\_scale}`$.
2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / \mathit{qk\_scale}`$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $`\mathit{m\_new_0} = \max(m, mp_0)`$, and $`scale_0 = \exp(m - \mathit{m\_new_0})`$. Update $`m \gets \mathit{m\_new_0}`$.
4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - \mathit{m\_new_0})`$.
5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $`\mathit{m\_new_1} = \max(m, mp_1)`$, and $`scale_1 = \exp(m - \mathit{m\_new_1})`$. Update $`m \gets \mathit{m\_new_1}`$.
7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - \mathit{m\_new_1})`$.
8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$.
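
To make the dataflow concrete, here is a minimal NumPy sketch of one such iteration for a single query row, with the steps run sequentially (on the GPU, the [0] and [1] steps execute concurrently on two warpgroups). All names here (`two_block_update`, `qk_scale`, the L/R column split) are illustrative, and, like the excerpt above, the sketch tracks only the running max and unnormalized outputs; the running softmax denominator is handled elsewhere in the kernel.

```python
import numpy as np

def two_block_update(q, K0, K1, V0, V1, m, o_L, o_R, qk_scale):
    """One two-block iteration for a single query row `q`.

    m is the running max; o_L / o_R are the unnormalized output halves.
    Step numbers refer to the list above; [0]/[1] mark the warpgroup
    that would execute the step in the actual kernel.
    """
    half = V0.shape[1] // 2
    V0L, V0R = V0[:, :half], V0[:, half:]     # split V blocks into L/R halves
    V1L, V1R = V1[:, :half], V1[:, half:]

    p0 = (q @ K0.T) / qk_scale                # step 1 [0]
    p1 = (q @ K1.T) / qk_scale                # step 2 [1]

    m_new0 = max(m, p0.max())                 # step 3 [0]
    scale0 = np.exp(m - m_new0)               # rescale factor for old output
    m = m_new0
    p0 = np.exp(p0 - m_new0)                  # step 4 [0]
    o_L = o_L * scale0 + p0 @ V0L             # step 5 [0]

    m_new1 = max(m, p1.max())                 # step 6 [1]
    scale1 = np.exp(m - m_new1)
    m = m_new1
    p1 = np.exp(p1 - m_new1)                  # step 7 [1]
    o_R = o_R * (scale0 * scale1) + p1 @ V1R  # step 8 [1]
    p0 = p0 * scale1                          # step 9 [0]
    o_R = o_R + p0 @ V0R                      # step 10 [1]

    # Not shown in the excerpt (the list continues past step 10): o_L still
    # needs the symmetric treatment, i.e. rescaling by scale1 plus p1's
    # contribution. Included here so the sketch is mathematically complete.
    o_L = o_L * scale1 + p1 @ V1L

    return m, o_L, o_R
```

Note that $`scale_0 \cdot scale_1 = \exp(m_{\mathrm{old}} - \mathit{m\_new_1})`$ for the max $m_{\mathrm{old}}$ held at the start of the iteration, which is why step 8 can bring $\vec o_R$ up to date in a single rescale, while $\vec o_L$ is rescaled incrementally (step 5, then a final $scale_1$ update). Iterating this over all KV block pairs and dividing by the running softmax sum at the end recovers the standard attention output.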