Clarification on fused kernel behavior for FP8 vs. BF16 #1862
Unanswered
EricJKrebs asked this question in Q&A
I've been creating profiles to compare the performance of FP8 vs. BF16 using the PyTorch Transformer Engine package on GPUs that support FP8. The GEMMs seem pretty straightforward, but Transformer Engine's dot product attention uses fused flash attention kernels that I have a few questions about, since I don't know how to profile inside them.
The code I've been testing is a simple implementation of GPT-2 training using Transformer Engine with "mixed" FP8 (E4M3 forward, E5M2 backward). During the forward pass, the GPU time I see with FP8 is dominated by an FP8 flash attention kernel, with GEMMs second.
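For reference, a minimal sketch of this kind of mixed-FP8 setup, assuming Transformer Engine's `fp8_autocast` with a `DelayedScaling` recipe in `Format.HYBRID` (E4M3 forward, E5M2 backward). The layer sizes and tensor shapes below are illustrative, not taken from the post, and whether the fused attention itself runs in FP8 additionally depends on the Transformer Engine version and backend selection:

```python
# Sketch: a Transformer Engine layer run under "mixed" FP8 (Format.HYBRID),
# i.e. E4M3 tensors in the forward pass, E5M2 gradients in the backward pass.
# Dimensions and shapes are illustrative.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

hidden_size, ffn_hidden_size, num_heads = 768, 3072, 12  # GPT-2-small-like sizes
layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()

# HYBRID format: E4M3 for forward-pass tensors, E5M2 for backward-pass gradients.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID,
                            amax_history_len=16,
                            amax_compute_algo="max")

# TransformerLayer expects [sequence, batch, hidden] input by default.
x = torch.randn(2048, 2, hidden_size, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(x)
out.sum().backward()
```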
Running the same code in pure BF16, the dominant forward-pass time comes from GEMMs, followed by the float16 flash attention kernels.
The GEMM time is roughly halved with FP8 (which is expected), but the FP8 flash attention kernel takes longer than its BF16 counterpart. I've varied sequence length, number of attention heads, per-head dimension, and batch size, but I keep seeing similar results: flash attention time for FP8 is higher than for BF16.
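A minimal sketch of one way to get this kind of per-kernel GPU timing, using `torch.profiler` on the setup above (it reuses the hypothetical `layer`, `x`, and `fp8_recipe` from the previous sketch); switching `use_fp8` to `False` drives a BF16 run via `torch.autocast` for the other side of the comparison:

```python
# Sketch: aggregate device time per CUDA kernel with torch.profiler,
# reusing layer / x / fp8_recipe from the sketch above.
import torch
from torch.profiler import profile, ProfilerActivity
import transformer_engine.pytorch as te

def step(use_fp8: bool):
    if use_fp8:
        ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
    else:
        ctx = torch.autocast("cuda", dtype=torch.bfloat16)
    with ctx:
        out = layer(x)
    out.sum().backward()

# Warm up once so one-time initialization doesn't skew the profile.
step(use_fp8=True)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        step(use_fp8=True)   # set use_fp8=False for the BF16 profile
    torch.cuda.synchronize()

# The attention and GEMM kernels should appear near the top of this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```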
My questions regarding the flash attention kernels: why would the FP8 kernel be slower than the BF16 one, and how can I profile inside these fused kernels? Any advice on gaining more insight would be appreciated.
Replies: 2 comments
- Seems strange that FP8 would be slower than BF16. I'd expect the FP8 forward pass to be ~1.3x faster than BF16.