Enable ViT torch.compile + CUDA Graph #33
base: mlperf-inf-mm-q3vl-v6.0
@@ -438,6 +438,91 @@ class CompilationConfig:
    on selected platforms. Disabled by default until more models
    are supported/tested to work."""

    # Encoder (ViT) CUDA graph settings
    cudagraph_mm_encoder: bool = False
    """Whether to enable CUDA graph capture for multimodal encoders (ViT).
    When enabled, CUDA graphs are captured for the vision encoder to eliminate
    kernel launch overhead. Requires fixed input sizes via bucketing.
    Experimental feature - use with caution."""

    encoder_cudagraph_bucket_sizes: list[int] | None = None
    """Square grid side lengths for padded CUDA graph execution. Each size N
    creates a bucket grid (1, N, N). Inputs with max(H, W) <= N are padded to
    fit the bucket. Example: [32, 64, 94, 128, 188, 256, 312] captures grids
    (1, 32, 32), (1, 64, 64), etc. Used with encoder_cudagraph_padded_mode=True."""

    encoder_cudagraph_grid_configs: list[tuple[int, int, int]] | str | None = None
    """Grid configurations (T, H, W in patch units) for exact-match CUDA graph
    capture. Can be a list of tuples or preset "custom" (top 30 most common grids,
    58.9% exact match coverage). If None, uses "custom" as default."""

    encoder_cudagraph_padded_mode: bool = True
    """Whether to use padded execution for encoder CUDA graphs.
    When True, inputs smaller than a captured bucket are padded to fit.
    Padded: pixel_values, pos_embeds, rotary_embeds (with zeros).
    NOT padded: cu_seqlens, max_seqlen (set to actual values so flash
    attention only processes real tokens). Output is trimmed to actual size.
    When False, only exact grid matches use CUDA graphs."""

    encoder_cudagraph_max_grid_size: int = 256
    """Maximum grid dimension (H or W) for encoder CUDA graph capture.
    Grids with H > max or W > max are skipped to limit GPU memory usage.
    Memory scales roughly with H*W:
    - 128x128: ~0.8 GiB
    - 188x188: ~1.7 GiB
    - 256x256: ~3.2 GiB
    Set lower (e.g., 128, 188, 218) on memory-constrained systems.
    Default 256 captures all grids in CUSTOM_GRID_CONFIGS."""

    encoder_cudagraph_verbose: bool = False
    """Enable verbose logging for encoder CUDA graph execution.
    When True, logs each ViT input size and CUDA graph hit/miss/padded status.
    Useful for debugging and analyzing CUDA graph utilization.
    When False, only logs summary stats at the end of execution."""
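The memory figures in the encoder_cudagraph_max_grid_size docstring scale roughly linearly with H*W, so they can be extrapolated when picking a cap on memory-constrained systems. The snippet below is only a back-of-the-envelope sketch based on the numbers quoted above, not a measurement utility from this PR.

```python
# Rough sketch: extrapolate per-graph CUDA graph memory from the figures quoted
# in the encoder_cudagraph_max_grid_size docstring (128x128 ~ 0.8 GiB).
GIB_PER_PATCH = 0.8 / (128 * 128)  # assumes roughly linear scaling in H*W


def estimate_graph_gib(h: int, w: int) -> float:
    """Very rough per-graph memory estimate in GiB for an (H, W) patch grid."""
    return h * w * GIB_PER_PATCH


for side in (128, 188, 218, 256):
    print(f"{side}x{side}: ~{estimate_graph_gib(side, side):.1f} GiB")
# Roughly reproduces the quoted figures: 0.8, 1.7, (~2.3), 3.2 GiB.
```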
    encoder_cudagraph_one_by_one: bool = True
    """Enable one-by-one image processing for multi-image batches.
    When True (default), multi-image batches are processed individually to
    maximize CUDA graph hit rate.
    When False, multi-image batches are processed together in eager mode,
    which may be faster when CUDA graph overhead (sync, memory) outweighs
    the kernel launch savings.
    Set to False if you observe throughput regression with encoder CUDA graphs."""

[Review comment] Is it possible to launch multiple CUDA graphs for multi-image batches? I know we may need multiple sets of input buffers and need to handle synchronization per stream, but maybe we could find a common batch size that recurs heavily enough to justify this cost?

[Author reply] Launch multiple CUDA graphs for multi-image batches - we already do this; each CUDA graph handles one image in the batch. Find a common batch size that recurs - I'm not sure the benchmark reuses particular batch sizes; I need to check.
    encoder_cudagraph_batch_sizes: list[int] | None = None
    """Batch sizes for grouped batched CUDA graph capture.
    When set (e.g., [4]), captures graphs for processing multiple images
    together. Images are grouped by similar grid sizes and padded to the
    largest grid in each group. Single graph replay for the whole group.
    Example: [4] captures batch_size=4 graphs only (1-3 images use eager).
    Default None uses legacy one-by-one mode (batch_size=1 per image)."""

    encoder_cudagraph_piecewise: bool = False
    """Enable piecewise CUDA graph mode for encoder (ViT).
    When True, torch.compile splits the encoder graph at attention ops, so:
    - Non-attention ops (norm, MLP, patch_embed, merger) are captured in CUDA graphs
    - Attention ops run in eager mode with original batch structure
    This allows batching multiple images together while still benefiting from
    CUDA graphs for the non-attention parts. More efficient than one-by-one
    processing when batch sizes vary.
    Requires compile_mm_encoder=True. Mutually exclusive with cudagraph_mm_encoder."""

    encoder_cudagraph_capture_sizes: list[int] | None = None
    """CUDA graph capture sizes (token counts) for encoder piecewise mode.
    These are the total token counts at which CUDA graphs are captured.
    For Qwen3-VL with spatial_merge_size=2:
    - (1, 32, 32) grid → 1024 patches → 256 output tokens
    - (1, 64, 64) grid → 4096 patches → 1024 output tokens
    - (1, 94, 94) grid → 8836 patches → 2209 output tokens
    Example: [256, 512, 1024, 2048, 4096, 8192, 16384]
    If None, encoder piecewise mode will use compile_ranges only (no cudagraph)."""

    encoder_spatial_merge_size: int = 2
    """Spatial merge size for vision encoder (e.g., 2 for Qwen3-VL).
    This converts encoder_cudagraph_capture_sizes (output tokens) to input patches.
    Input patches = output tokens * spatial_merge_size².
    Default is 2, which is common for Qwen-VL family models."""

    # Inductor capture
    compile_sizes: list[int | str] | None = None
    """Sizes to compile for inductor. In addition
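The grid-to-token arithmetic quoted in the encoder_cudagraph_capture_sizes and encoder_spatial_merge_size docstrings can be checked with a few lines of Python. This is only an illustrative sketch (the helper names are made up, not code from this PR); it reproduces the Qwen3-VL numbers above and shows how a capture size might be picked from the example list.

```python
# Illustrative sketch of the patch/token arithmetic described above.
import bisect


def grid_to_tokens(t: int, h: int, w: int, spatial_merge_size: int = 2) -> tuple[int, int]:
    """Return (input_patches, output_tokens) for a (T, H, W) grid in patch units."""
    patches = t * h * w
    # Input patches = output tokens * spatial_merge_size**2, so invert:
    tokens = patches // (spatial_merge_size**2)
    return patches, tokens


def pick_capture_size(tokens: int, capture_sizes: list[int]) -> int | None:
    """Smallest captured token count that can hold `tokens`.
    Assumes (like the decoder path) that inputs are padded up to the nearest
    captured size; None means no captured size fits (fall back to eager)."""
    sizes = sorted(capture_sizes)
    idx = bisect.bisect_left(sizes, tokens)
    return sizes[idx] if idx < len(sizes) else None


if __name__ == "__main__":
    capture_sizes = [256, 512, 1024, 2048, 4096, 8192, 16384]
    for grid in [(1, 32, 32), (1, 64, 64), (1, 94, 94)]:
        patches, tokens = grid_to_tokens(*grid)
        print(grid, "->", patches, "patches,", tokens, "tokens,",
              "capture size", pick_capture_size(tokens, capture_sizes))
    # (1, 32, 32) -> 1024 patches, 256 tokens, capture size 256
    # (1, 64, 64) -> 4096 patches, 1024 tokens, capture size 1024
    # (1, 94, 94) -> 8836 patches, 2209 tokens, capture size 4096
```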
@@ -622,6 +707,15 @@ class CompilationConfig:
        "vllm::sparse_attn_indexer",
    ]

    # Encoder (ViT) attention ops; used for piecewise cudagraphs on encoders.
    # These ops depend on batch structure (cu_seqlens), so they must be
    # excluded from cudagraph capture to allow batching multiple images.
    _encoder_attention_ops: ClassVar[list[str]] = [
        "vllm::flash_attn_maxseqlen_wrapper",
        "vllm::fa4_flash_attn_maxseqlen_wrapper",
        "vllm::flashinfer_wrapper",
    ]

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
@@ -1023,6 +1117,15 @@ def splitting_ops_contain_attention(self) -> bool:
            op in self.splitting_ops for op in self._attention_ops
        )

    def get_encoder_splitting_ops(self) -> list[str]:
        """Get splitting ops for encoder (ViT) compilation.

        For piecewise cudagraph on encoders, we split at attention ops
        so that non-attention ops (norm, MLP) can be captured in cudagraphs
        while attention runs in eager mode with batched images.
        """
        return list(self._encoder_attention_ops)

    def is_attention_compiled_piecewise(self) -> bool:
        if not self.splitting_ops_contain_attention():
            return False
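As a usage illustration for the fields added above (a hedged sketch, not part of this diff; the import path and the way CompilationConfig is constructed here are assumptions), the two encoder CUDA graph modes would be configured roughly as follows. The docstrings above state that encoder_cudagraph_piecewise requires compile_mm_encoder=True and is mutually exclusive with cudagraph_mm_encoder.

```python
# Hypothetical usage sketch; the field names come from this diff, but the import
# path and construction style are assumptions, not taken from the PR.
from vllm.config import CompilationConfig  # import path assumed

# Whole-encoder CUDA graphs with padded buckets (one graph replay per image):
bucketed = CompilationConfig(
    cudagraph_mm_encoder=True,
    encoder_cudagraph_padded_mode=True,
    encoder_cudagraph_bucket_sizes=[32, 64, 94, 128, 188, 256],
    encoder_cudagraph_max_grid_size=256,
)

# Piecewise mode: attention stays eager, everything else is captured:
piecewise = CompilationConfig(
    compile_mm_encoder=True,               # required by encoder_cudagraph_piecewise
    encoder_cudagraph_piecewise=True,      # mutually exclusive with cudagraph_mm_encoder
    encoder_cudagraph_capture_sizes=[256, 512, 1024, 2048, 4096, 8192, 16384],
    encoder_spatial_merge_size=2,
)

# The ops the encoder graph is split at (from the hunks above):
print(piecewise.get_encoder_splitting_ops())
# ['vllm::flash_attn_maxseqlen_wrapper', 'vllm::fa4_flash_attn_maxseqlen_wrapper',
#  'vllm::flashinfer_wrapper']
```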
[Review comment on encoder_cudagraph_padded_mode] Is there a scenario where we don't need to pad? If not, I suggest removing this flag.

[Author reply] This flag allows us to catch more grid sizes (i.e., token buckets) where exact match does not work. So far, enabling padding in addition to exact match has given better performance than exact match only.

[Review comment] Yeah, so if that's the case I guess we can just use padding by default.
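To make the padded execution discussed in this thread concrete: per the encoder_cudagraph_bucket_sizes and encoder_cudagraph_padded_mode docstrings, an input is padded up to the smallest bucket whose side is at least max(H, W), cu_seqlens/max_seqlen stay at the actual values, and the output is trimmed back to the real token count. The sketch below is assumed illustrative logic, not the PR's implementation.

```python
# Assumed illustration of padded-bucket selection; not code from this PR.
import torch


def pick_bucket(h: int, w: int, bucket_sizes: list[int]) -> int | None:
    """Smallest bucket N with max(H, W) <= N; None means no bucket fits (run eager)."""
    candidates = [n for n in sorted(bucket_sizes) if max(h, w) <= n]
    return candidates[0] if candidates else None


def pad_for_bucket(pixel_values: torch.Tensor, h: int, w: int, bucket: int) -> torch.Tensor:
    """Zero-pad (H*W, C) patch inputs up to the (bucket, bucket) grid's patch count.
    In the real path cu_seqlens / max_seqlen are NOT padded, so attention only sees
    the first H*W rows, and the graph output is trimmed back to the actual token count."""
    padded = torch.zeros(bucket * bucket, pixel_values.shape[-1],
                         dtype=pixel_values.dtype, device=pixel_values.device)
    padded[: h * w] = pixel_values
    return padded


# Example: a (1, 90, 80) grid falls into the 94 bucket from the docstring example list.
print(pick_bucket(90, 80, [32, 64, 94, 128, 188, 256, 312]))  # -> 94
```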