diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f06047be61b9..c79dce86224f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -412,6 +412,22 @@ def call_module( i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] + # Check if we should use piecewise backend for this compilation + # For encoder with encoder_cudagraph_piecewise=False, skip piecewise + # backend entirely to avoid shape tracking issues. The encoder will + # use torch.compile directly and EncoderCudaGraphManager handles + # full cudagraph capture separately. + encoder_skip_piecewise = self.vllm_backend.is_encoder and not getattr( + self.compilation_config, "encoder_cudagraph_piecewise", False + ) + + if encoder_skip_piecewise: + # For encoder without piecewise mode, just use the compiled + # submodule directly. EncoderCudaGraphManager will capture + # the full graph later. + self.module.__dict__[target] = submod + return output + # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -424,10 +440,13 @@ def call_module( self.vllm_backend, ) - if ( + # Check if we should use piecewise cudagraphs for this compilation + use_piecewise_cudagraph = ( self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() and not self.compilation_config.use_inductor_graph_partition - ): + ) + + if use_piecewise_cudagraph: # We're using Dynamo-based piecewise splitting, so we wrap # the whole subgraph with a static graph wrapper. from .cuda_graph import CUDAGraphOptions @@ -555,6 +574,13 @@ def __init__( # in future we need PostGradPassManager.uuid() to be executed # only at compile time. self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config) + + # Disable cache for encoder compilation to avoid assertion errors + # with simple graphs (e.g., Conv3d) that don't produce AOT artifacts. + # This skips the save in InductorStandaloneAdaptor.compile(). + if self.is_encoder: + self.inductor_config["force_disable_caches"] = True + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -716,6 +742,10 @@ def __call__( if self.compilation_config.use_inductor_graph_partition: # Let Inductor decide partitioning; avoid FX-level pre-splitting. 
fx_split_ops: list[str] = [] + elif self.is_encoder: + # For encoder compilation, use encoder-specific splitting ops + # to enable piecewise cudagraph (attention in eager, rest in graph) + fx_split_ops = self.compilation_config.get_encoder_splitting_ops() else: fx_split_ops = self.compilation_config.splitting_ops or [] diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 29d6f89990cd..e01c88e45aa4 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -72,8 +72,35 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) - self.compile_sizes = self.compilation_config.compile_sizes - log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + # Use encoder-specific capture sizes for encoder compilation + self.compile_sizes: list[Any] | None = None + if self.is_encoder_compilation: + encoder_capture_sizes = ( + self.compilation_config.encoder_cudagraph_capture_sizes + ) + if encoder_capture_sizes is not None: + # Convert from output tokens to input patches + # encoder_cudagraph_capture_sizes is specified in output tokens + # but runtime_shape (from sym_shape_indices) is in input patches + merge_size_sq = self.compilation_config.encoder_spatial_merge_size**2 + self.compile_sizes = [ + size * merge_size_sq for size in encoder_capture_sizes + ] + logger.debug_once( + "PiecewiseBackend: converted encoder capture sizes from " + "output tokens %s to input patches %s (merge_size²=%d)", + tuple(encoder_capture_sizes), + tuple(self.compile_sizes), + merge_size_sq, + ) + else: + self.compile_sizes = None + else: + self.compile_sizes = self.compilation_config.compile_sizes + log_string = ( + f"PiecewiseBackend: compile_sizes: {self.compile_sizes} " + f"(is_encoder={self.is_encoder_compilation})" + ) logger.debug_once(log_string) self.sym_shape_indices = sym_shape_indices @@ -143,15 +170,13 @@ def _maybe_compile_for_range_entry( range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) + is_exact_size = range_entry.compile_range.is_single_size() + # args are real arguments # fakify for range, real args for concrete size. # For concrete size, we clear the shape env in # compiler_manager.compile() so no need to fakify. - args_list = ( - self._fakify_args(args) - if not range_entry.compile_range.is_single_size() - else list(args) - ) + args_list = self._fakify_args(args) if not is_exact_size else list(args) range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args_list, @@ -169,23 +194,53 @@ def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: # If not found, we search for the range entry # that contains the runtime shape. 
if self.compile_sizes is None: + logger.debug( + "PIECEWISE: compile_sizes is None, shape=%d, is_encoder=%s", + runtime_shape, + self.is_encoder_compilation, + ) return None if runtime_shape in self.compile_sizes: + # Exact match with capture size - will use cudagraph + logger.debug( + "PIECEWISE: exact match shape=%d in compile_sizes, is_encoder=%s", + runtime_shape, + self.is_encoder_compilation, + ) return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] else: + # No exact match - fall back to compile_ranges (no cudagraph) for range in self.compile_ranges: if runtime_shape in range: + logger.debug( + "PIECEWISE: shape=%d not in compile_sizes, " + "using compile_range=%s (NO CUDAGRAPH), is_encoder=%s", + runtime_shape, + range, + self.is_encoder_compilation, + ) return self.range_entries[range] + # Shape not in any range - will cause assertion error + logger.warning( + "PIECEWISE: shape=%d not in compile_sizes=%s or " + "compile_ranges=%s, is_encoder=%s", + runtime_shape, + self.compile_sizes, + self.compile_ranges, + self.is_encoder_compilation, + ) return None def __call__(self, *args: Any) -> Any: runtime_shape = args[self.sym_shape_indices[0]] + range_entry = self._find_range_for_shape(runtime_shape) assert range_entry is not None, ( f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}" ) - self._maybe_compile_for_range_entry(range_entry, args) - return range_entry.runnable(*args) + self._maybe_compile_for_range_entry(range_entry, args) # type: ignore[arg-type] + + return range_entry.runnable(*args) # type: ignore[union-attr] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 035aa24e33c7..5c015cb60823 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -438,6 +438,58 @@ class CompilationConfig: on selected platforms. Disabled by default until more models are supported/tested to work.""" + # Encoder (ViT) CUDA graph settings + cudagraph_mm_encoder: bool = False + """Whether to enable CUDA graph capture for multimodal encoders (ViT). + When enabled, CUDA graphs are captured for the vision encoder to eliminate + kernel launch overhead. Requires fixed input sizes via bucketing. + Experimental feature - use with caution.""" + + encoder_cudagraph_verbose: bool = False + """Enable verbose logging for encoder CUDA graph execution. + When True, logs each ViT input size and CUDA graph hit/miss status. + Useful for debugging and analyzing CUDA graph utilization.""" + + encoder_cudagraph_token_budgets: list[int] | None = None + """List of total output token budget levels for budget batch CUDA graphs. + E.g., [2048, 4096, 8192]. For each budget, one graph is captured with + max_images_per_batch image slots. At runtime, images are sorted + smallest-first and greedily packed; the smallest fitting budget graph is + selected. Works with FA2 and FA4 attention backends only. + Requires encoder_cudagraph_max_images_per_batch to also be set.""" + + encoder_cudagraph_max_images_per_batch: int | None = None + """Maximum number of images per budget batch. The captured CUDA graph + has fixed cu_seqlens of size max_images_per_batch + 1. Empty slots use + zero-length sequences (no-op in flash attention). Used together with + encoder_cudagraph_token_budgets.""" + + encoder_cudagraph_piecewise: bool = False + """Enable piecewise CUDA graph mode for encoder (ViT). 
+ When True, torch.compile splits the encoder graph at attention ops, so: + - Non-attention ops (norm, MLP, patch_embed, merger) are captured in CUDA graphs + - Attention ops run in eager mode with original batch structure + This allows batching multiple images together while still benefiting from + CUDA graphs for the non-attention parts. More efficient than one-by-one + processing when batch sizes vary. + Requires compile_mm_encoder=True. Mutually exclusive with cudagraph_mm_encoder.""" + + encoder_cudagraph_capture_sizes: list[int] | None = None + """CUDA graph capture sizes (token counts) for encoder piecewise mode. + These are the total token counts at which CUDA graphs are captured. + For Qwen3-VL with spatial_merge_size=2: + - (1, 32, 32) grid → 1024 patches → 256 output tokens + - (1, 64, 64) grid → 4096 patches → 1024 output tokens + - (1, 94, 94) grid → 8836 patches → 2209 output tokens + Example: [256, 512, 1024, 2048, 4096, 8192, 16384] + If None, encoder piecewise mode will use compile_ranges only (no cudagraph).""" + + encoder_spatial_merge_size: int = 2 + """Spatial merge size for vision encoder (e.g., 2 for Qwen3-VL). + This converts encoder_cudagraph_capture_sizes (output tokens) to input patches. + Input patches = output tokens * spatial_merge_size². + Default is 2, which is common for Qwen-VL family models.""" + # Inductor capture compile_sizes: list[int | str] | None = None """Sizes to compile for inductor. In addition @@ -622,6 +674,15 @@ class CompilationConfig: "vllm::sparse_attn_indexer", ] + # Encoder (ViT) attention ops; used for piecewise cudagraphs on encoders + # These ops depend on batch structure (cu_seqlens), so they must be + # excluded from cudagraph capture to allow batching multiple images. + _encoder_attention_ops: ClassVar[list[str]] = [ + "vllm::flash_attn_maxseqlen_wrapper", + "vllm::fa4_flash_attn_maxseqlen_wrapper", + "vllm::flashinfer_wrapper", + ] + def compute_hash(self) -> str: """ Provide a hash that uniquely identifies all the configs @@ -1023,6 +1084,15 @@ def splitting_ops_contain_attention(self) -> bool: op in self.splitting_ops for op in self._attention_ops ) + def get_encoder_splitting_ops(self) -> list[str]: + """Get splitting ops for encoder (ViT) compilation. + + For piecewise cudagraph on encoders, we split at attention ops + so that non-attention ops (norm, MLP) can be captured in cudagraphs + while attention runs in eager mode with batched images. 
+ """ + return list(self._encoder_attention_ops) + def is_attention_compiled_piecewise(self) -> bool: if not self.splitting_ops_contain_attention(): return False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0181cb1f086e..c99f8009c565 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1249,6 +1249,7 @@ def _set_compile_ranges(self): and x > 1 ): computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( computed_compile_ranges_split_points ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 4f8e694d75cb..7495d14102ce 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -686,9 +686,11 @@ def forward( if isinstance(grid_thw, list): grid_thw_list = grid_thw grid_thw = np.array(grid_thw, dtype=np.int32) + elif isinstance(grid_thw, np.ndarray): + grid_thw_list = grid_thw.tolist() else: grid_thw_list = grid_thw.tolist() - grid_thw = grid_thw.numpy() + grid_thw = grid_thw.cpu().numpy() # compute position embedding rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 707e0ccfd3c5..ef8c56c15bd5 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -144,10 +144,10 @@ def forward( # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) - q = q_by_head.view(q.shape) + q = q_by_head.flatten(-2) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) - k = k_by_head.view(k.shape) + k = k_by_head.flatten(-2) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1f1ee2f56219..2fb3be4cf601 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -52,6 +52,7 @@ from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group, parallel_state +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.conv import Conv3dLayer @@ -65,6 +66,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.evs import ( compute_mrope_for_media, @@ -92,6 +94,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils.collection_utils import is_list_of from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.worker.gpu.mm.encoder_cudagraph import EMBEDDING_WARMUP_GRIDS from .interfaces import ( MultiModalEmbeddings, @@ -135,7 +138,14 @@ # This avoids creating a new graph for each unique batch size at runtime BATCH_BUCKETS = [8, 16, 32, 64] +# Set of pre-warmed grids for O(1) lookup in embedding cache +_EMBEDDING_WARMUP_GRIDS_SET: set[tuple[int, int, int]] = set(EMBEDDING_WARMUP_GRIDS) + +@support_torch_compile( 
+ dynamic_arg_dims={"x": 0}, + enable_if=should_torch_compile_mm_vit, +) class Qwen3_VisionPatchEmbed(nn.Module): def __init__( self, @@ -207,6 +217,18 @@ def forward(self, x: torch.Tensor): return mlp_output +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb_cos": 0, + "rotary_pos_emb_sin": 0, + "max_seqlen": 0, + "sequence_lengths": 0, # Batch dimension is dynamic + }, + mark_unbacked_dims={"max_seqlen": 0}, + enable_if=should_torch_compile_mm_vit, +) class Qwen3_VisionBlock(nn.Module): def __init__( self, @@ -266,6 +288,10 @@ def forward( return x +@support_torch_compile( + dynamic_arg_dims={"x": 0}, + enable_if=should_torch_compile_mm_vit, +) class Qwen3_VisionPatchMerger(nn.Module): def __init__( self, @@ -300,6 +326,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.linear_fc1", disable_tp=use_data_parallel, + return_bias=False, ) self.act_fn = nn.GELU() self.linear_fc2 = RowParallelLinear( @@ -309,6 +336,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.linear_fc2", disable_tp=use_data_parallel, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -317,9 +345,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: x = self.norm(x).view(-1, self.hidden_size) - x_parallel, _ = self.linear_fc1(x) + x_parallel = self.linear_fc1(x) x_parallel = self.act_fn(x_parallel) - out, _ = self.linear_fc2(x_parallel) + out = self.linear_fc2(x_parallel) return out @@ -360,12 +388,15 @@ def __init__( 1 + len(self.deepstack_visual_indexes) ) - self.patch_embed = Qwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) + from vllm.compilation.backends import set_model_tag + + with set_model_tag("Qwen3_VisionPatchEmbed", is_encoder=True): + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) @@ -378,31 +409,37 @@ def __init__( rope_parameters={"partial_rotary_factor": 0.5}, ) - self.merger = Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - norm_layer=norm_layer, - spatial_merge_size=self.spatial_merge_size, - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=f"{prefix}.merger", - ) + with set_model_tag("Qwen3_VisionPatchMerger", is_encoder=True): + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.merger", + ) - self.deepstack_merger_list = nn.ModuleList( - [ - Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - spatial_merge_size=self.spatial_merge_size, - use_postshuffle_norm=True, - norm_layer=norm_layer, - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - ) - for layer_idx in range(len(self.deepstack_visual_indexes)) - ] - ) + with set_model_tag("Qwen3_VisionPatchMerger_postshuffle_norm", is_encoder=True): + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + 
context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) + + # Per-grid embedding cache for eager mode optimization + # Key: (t, h, w), Value: dict with pos_embeds, rotary_cos, rotary_sin + self._embedding_cache: dict[tuple[int, int, int], dict[str, torch.Tensor]] = {} attn_backend_override = ( multimodal_config.mm_encoder_attn_backend if multimodal_config else None @@ -424,28 +461,31 @@ def __init__( f"Qwen3-VL does not support {self.attn_backend} backend now." ) - workspace_buffer = ( - None - if self.attn_backend != AttentionBackendEnum.FLASHINFER - else torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=self.device) - ) - - self.blocks = nn.ModuleList( - [ - Qwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=f"{prefix}.blocks.{layer_idx}", - workspace_buffer=workspace_buffer, + with set_model_tag("Qwen3_VisionBlock", is_encoder=True): + workspace_buffer = ( + None + if self.attn_backend != AttentionBackendEnum.FLASHINFER + else torch.zeros( + 128 * 1024 * 1024, dtype=torch.uint8, device=self.device ) - for layer_idx in range(vision_config.depth) - ] - ) + ) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.blocks.{layer_idx}", + workspace_buffer=workspace_buffer, + ) + for layer_idx in range(vision_config.depth) + ] + ) @property def dtype(self) -> torch.dtype: @@ -617,6 +657,76 @@ def compute_flashinfer_cu_seqlens( ) return np.concatenate([cu_seqlens_qk, cu_seqlens_v, cu_seqlens_o]) + def _get_cached_embeddings( + self, grid_thw_list: list[list[int]] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get position and rotary embeddings with per-grid caching. + + This method caches embeddings only for grids in EMBEDDING_WARMUP_GRIDS + to avoid unbounded memory growth. Grids not in the warmup set are + computed on-the-fly without caching. + + Args: + grid_thw_list: List of [T, H, W] for each image + + Returns: + Tuple of (pos_embeds, rotary_cos, rotary_sin) + """ + pos_embeds_list: list[torch.Tensor] = [] + rotary_cos_list: list[torch.Tensor] = [] + rotary_sin_list: list[torch.Tensor] = [] + + for grid in grid_thw_list: + t, h, w = grid + grid_key = (t, h, w) + + if grid_key in self._embedding_cache: + # Cache hit - use cached embeddings + cached = self._embedding_cache[grid_key] + pos_embeds_list.append(cached["pos_embeds"]) + rotary_cos_list.append(cached["rotary_cos"]) + rotary_sin_list.append(cached["rotary_sin"]) + else: + # Cache miss - compute embeddings + single_grid = [[t, h, w]] + pos_embed = self.fast_pos_embed_interpolate(single_grid) + rotary_cos, rotary_sin = self.rot_pos_emb(single_grid) + + # Only cache if grid is in pre-warmed set to prevent OOM. + # Caching at runtime causes unbounded memory growth. 
+ if grid_key in _EMBEDDING_WARMUP_GRIDS_SET: + self._embedding_cache[grid_key] = { + "pos_embeds": pos_embed, + "rotary_cos": rotary_cos, + "rotary_sin": rotary_sin, + } + + pos_embeds_list.append(pos_embed) + rotary_cos_list.append(rotary_cos) + rotary_sin_list.append(rotary_sin) + + # Concatenate all embeddings + pos_embeds = torch.cat(pos_embeds_list, dim=0) + rotary_pos_emb_cos = torch.cat(rotary_cos_list, dim=0) + rotary_pos_emb_sin = torch.cat(rotary_sin_list, dim=0) + + return pos_embeds, rotary_pos_emb_cos, rotary_pos_emb_sin + + def get_embedding_cache_memory(self) -> int: + """ + Compute the total GPU memory consumption of the embedding cache. + + Returns: + Total memory in bytes used by all cached embeddings. + """ + total_bytes = 0 + for grid, cached in self._embedding_cache.items(): + for key, tensor in cached.items(): + if isinstance(tensor, torch.Tensor): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + def forward( self, x: torch.Tensor, @@ -628,13 +738,17 @@ def forward( if isinstance(grid_thw, list): grid_thw_list = grid_thw grid_thw = np.array(grid_thw, dtype=np.int32) + elif isinstance(grid_thw, np.ndarray): + grid_thw_list = grid_thw.tolist() else: grid_thw_list = grid_thw.tolist() - grid_thw = grid_thw.numpy() + grid_thw = grid_thw.cpu().numpy() - pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + # Get embeddings with caching for eager mode optimization + pos_embeds, rotary_pos_emb_cos, rotary_pos_emb_sin = ( + self._get_cached_embeddings(grid_thw_list) + ) hidden_states = hidden_states + pos_embeds - rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( axis=0, dtype=np.int32 @@ -683,6 +797,204 @@ def forward( ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] return hidden_states + def forward_cudagraph( + self, + x: torch.Tensor, + pos_embeds: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + sequence_lengths: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass optimized for CUDA graph capture/replay. + + This method accepts pre-computed position embeddings, rotary embeddings, + and cumulative sequence lengths to avoid CPU operations during CUDA graph + replay. All tensor arguments must be on the correct device. 
+ + Args: + x: Input pixel values [num_patches, patch_channels] + pos_embeds: Pre-computed position embeddings [num_patches, hidden_size] + rotary_pos_emb_cos: Pre-computed rotary cosine embeddings + rotary_pos_emb_sin: Pre-computed rotary sine embeddings + cu_seqlens: Pre-computed cumulative sequence lengths (on GPU) + max_seqlen: Pre-computed max sequence length (scalar tensor on GPU) + sequence_lengths: Pre-computed sequence lengths (for FlashInfer CuDNN) + + Returns: + Vision encoder output tensor + """ + # Patch embedding (GPU operation) + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + # Add pre-computed position embeddings + hidden_states = hidden_states + pos_embeds + + hidden_states = hidden_states.unsqueeze(1) + + # Run through transformer blocks with pre-computed values + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat([hidden_states] + deepstack_feature_lists, dim=1) + return hidden_states + + def precompute_for_cudagraph( + self, + grid_thw: list[list[int]], + ) -> dict[str, torch.Tensor]: + """ + Pre-compute all grid-dependent tensors for CUDA graph capture. + + This method computes position embeddings, rotary embeddings, and + cumulative sequence lengths that are fixed for a given grid configuration. + These can be cached and reused during CUDA graph replay. + + Args: + grid_thw: List of [T, H, W] for each image + + Returns: + Dict containing pre-computed tensors: + - pos_embeds: Position embeddings + - rotary_pos_emb_cos: Rotary cosine embeddings + - rotary_pos_emb_sin: Rotary sine embeddings + - cu_seqlens: Cumulative sequence lengths (on GPU) + - max_seqlen: Maximum sequence length (scalar tensor on GPU) + """ + # Compute position embeddings + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + + # Compute rotary embeddings + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw) + + # Compute cumulative sequence lengths + grid_thw_np = np.array(grid_thw, dtype=np.int32) + cu_seqlens = np.repeat( + grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0] + ).cumsum(axis=0, dtype=np.int32) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + if self.attn_backend == AttentionBackendEnum.FLASHINFER: + sequence_lengths = self.add_padding_to_fi_seqlens( + sequence_lengths, len(sequence_lengths), 0 + ) + cu_seqlens = self.compute_flashinfer_cu_seqlens( + cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin + ) + cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True) + sequence_lengths = torch.from_numpy(sequence_lengths).to( + self.device, non_blocking=True + ) + + # Compute max sequence length as CPU scalar tensor + # Using CPU tensor is important for CUDA graph capture: .item() on CPU + # tensor doesn't trigger GPU sync, so it won't invalidate capture. 
+ max_seqlen_gpu = ( + torch.tensor(128 * 1024, device=self.device) + # setting to 128k to avoid cudnn recompilation + # TODO: use the real max_seqlen once cudnn compilation is optimized + if self.attn_backend == AttentionBackendEnum.FLASHINFER + else self.compute_attn_mask_seqlen(cu_seqlens) + ) + max_seqlen = max_seqlen_gpu.cpu() # Move to CPU to avoid GPU sync on .item() + + return { + "pos_embeds": pos_embeds, + "rotary_pos_emb_cos": rotary_pos_emb_cos, + "rotary_pos_emb_sin": rotary_pos_emb_sin, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "sequence_lengths": sequence_lengths, + } + + def forward_piecewise( + self, + x: torch.Tensor, + pos_embeds: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor, + sequence_lengths: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass optimized for piecewise CUDA graph mode with batched images. + + This method accepts pre-computed position embeddings, rotary embeddings, + and cumulative sequence lengths. Unlike forward_cudagraph which processes + one image at a time, this method handles batched images with padding for + piecewise cudagraph optimization. + + The key difference from the regular forward() is that all grid-dependent + computations (position embeddings, rotary embeddings, cu_seqlens) are + pre-computed outside the compiled graph, allowing padding to be applied + to match cudagraph capture sizes. + + Args: + x: Input pixel values [num_patches, patch_channels] + pos_embeds: Pre-computed position embeddings [num_patches, hidden_size] + rotary_pos_emb_cos: Pre-computed rotary cosine embeddings + rotary_pos_emb_sin: Pre-computed rotary sine embeddings + cu_seqlens: Pre-computed cumulative sequence lengths (on GPU) + max_seqlen: Pre-computed max sequence length (scalar tensor, can be CPU) + sequence_lengths: Pre-computed sequence lengths (for FlashInfer CuDNN) + + Returns: + Vision encoder output tensor [num_output_tokens, hidden_size] + """ + # Patch embedding (GPU operation) + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + # Ensure max_seqlen is on GPU for attention kernels + if max_seqlen.device.type == "cpu": + max_seqlen = max_seqlen.to(self.device, non_blocking=True) + + # Add pre-computed position embeddings + hidden_states = hidden_states + pos_embeds + + hidden_states = hidden_states.unsqueeze(1) + + # Run through transformer blocks with pre-computed values + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat([hidden_states] + deepstack_feature_lists, dim=1) + return hidden_states + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -1359,6 +1671,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): quant_config = vllm_config.quant_config multimodal_config = 
vllm_config.model_config.multimodal_config + self.vllm_config = vllm_config self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" @@ -1373,7 +1686,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.visual = None else: self.visual = Qwen3_VisionTransformer( - config.vision_config, + vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, multimodal_config=multimodal_config, @@ -1510,12 +1823,16 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" - ) - else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values, + grid_thw.tolist(), + rope_type="rope_3d", + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size @@ -1534,13 +1851,16 @@ def _process_video_input( pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype ) - if self.use_data_parallel: - grid_thw_list = grid_thw.tolist() - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" - ) - else: - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw.tolist(), + rope_type="rope_3d", + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
merge_size = self.visual.spatial_merge_size diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index 3186804488e5..b37ae9c307f7 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -416,6 +416,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.vllm_config = vllm_config self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index ea1a10f6ac9b..dd26c3548c78 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -17,15 +17,17 @@ from vllm._ipex_ops import ipex_ops reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash - flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func + flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment] get_scheduler_metadata = ipex_ops.get_scheduler_metadata elif current_platform.is_rocm(): try: - from flash_attn import flash_attn_varlen_func # noqa: F401 + from flash_attn import ( # type: ignore[no-redef] + flash_attn_varlen_func, # noqa: F401 + ) except ImportError: - def flash_attn_varlen_func(*args, **kwargs): + def flash_attn_varlen_func(*args, **kwargs): # type: ignore[misc] raise ImportError( "ROCm platform requires upstream flash-attn " "to be installed. Please install flash-attn first." diff --git a/vllm/v1/attention/backends/mla/aiter_triton_mla.py b/vllm/v1/attention/backends/mla/aiter_triton_mla.py index b164bb7b2ecd..5b6ecb65c243 100644 --- a/vllm/v1/attention/backends/mla/aiter_triton_mla.py +++ b/vllm/v1/attention/backends/mla/aiter_triton_mla.py @@ -49,7 +49,7 @@ def __init__( def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs ): - result = self.flash_attn_varlen_func( + result = self.flash_attn_varlen_func( # type: ignore[call-arg] q, k, v, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 46ca97cac670..3abf8ad309d3 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -230,7 +230,7 @@ def __init__( def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs ): - output = self.flash_attn_varlen_func( + output = self.flash_attn_varlen_func( # type: ignore[call-arg] q=q, k=k, v=v, diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py index 84b1438fb1b0..edc118c9add8 100644 --- a/vllm/v1/attention/ops/vit_attn_wrappers.py +++ b/vllm/v1/attention/ops/vit_attn_wrappers.py @@ -29,7 +29,7 @@ def flash_attn_maxseqlen_wrapper( fa_version: int | None, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Also accepts int at runtime ) -> torch.Tensor: kwargs = {} if is_rocm_aiter: @@ -45,7 +45,14 @@ def flash_attn_maxseqlen_wrapper( cu_seqlens = torch.arange( 0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device ) - max_seqlen = q_len if max_seqlen is None else max_seqlen.item() + # Handle max_seqlen: can be None, int, or tensor + # For CUDA graph capture, use CPU tensor so 
.item() doesn't trigger GPU sync + if max_seqlen is None: + max_seqlen = q_len + elif isinstance(max_seqlen, int): + pass # already an int + else: + max_seqlen = max_seqlen.item() # CPU tensor .item() is safe during capture q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( @@ -117,7 +124,7 @@ def fa4_flash_attn_maxseqlen_wrapper( batch_size: int, scale: float | None = None, cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Also accepts int at runtime ) -> torch.Tensor: """FA4 (flash_attn.cute) wrapper for ViT attention. @@ -132,7 +139,14 @@ def fa4_flash_attn_maxseqlen_wrapper( cu_seqlens = torch.arange( 0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device ) - max_seqlen_int = q_len if max_seqlen is None else max_seqlen.item() + # Handle max_seqlen: can be None, int, or tensor + # For CUDA graph capture, use CPU tensor so .item() doesn't trigger GPU sync + if max_seqlen is None: + max_seqlen_int = q_len + elif isinstance(max_seqlen, int): + max_seqlen_int = max_seqlen + else: + max_seqlen_int = max_seqlen.item() # CPU tensor .item() is safe during capture q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = fa4_flash_attn_varlen_func( @@ -306,6 +320,7 @@ def flashinfer_wrapper( batch_offsets_k=batch_offsets_qk, batch_offsets_v=batch_offsets_v, batch_offsets_o=batch_offsets_o, + is_cuda_graph_compatible=True, ) if is_reshaped: diff --git a/vllm/v1/worker/gpu/mm/encoder_cudagraph.py b/vllm/v1/worker/gpu/mm/encoder_cudagraph.py new file mode 100644 index 000000000000..5a60888b88d4 --- /dev/null +++ b/vllm/v1/worker/gpu/mm/encoder_cudagraph.py @@ -0,0 +1,1014 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +CUDA Graph Manager for Multimodal Encoders (ViT). + +This module provides CUDA graph capture and replay functionality for vision +encoders to eliminate kernel launch overhead and improve GPU utilization. + +Primary execution mode - Budget Batching: +- Captures CUDA graphs for multiple token budget levels (e.g., [2048, 4096, + 8192, 13824]), each with a fixed max_images_per_batch. +- At runtime, images are sorted smallest-first and greedily packed into + budget-sized batches. The smallest fitting budget graph is selected. +- cu_seqlens is padded to max_images_per_batch + 1 by repeating the last + value, creating zero-length sequences for empty slots (no-op in FA2/FA4). +- Works with any number of images (1 or many) and any grid sizes. + +Key design principles: +1. Capture graphs based on token budgets, not grid sizes +2. Reuse one graph for any batch where total tokens fit the budget +3. Fall back to eager mode when no suitable graph is available +4. Track statistics for monitoring and optimization +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Grid configurations for CUDA graph capture (T, H, W in patch units) +# +# Top 100 most common grids for embedding cache pre-warming. 
+# Pre-warming these grids at startup avoids cold-start embedding computation +# at runtime, eliminating ~20 small kernel launches per grid on first encounter. +# Based on MLPerf VLM dataset analysis (~71% coverage with top 100 grids). +EMBEDDING_WARMUP_GRIDS: list[tuple[int, int, int]] = [ + # Top 50 grids (sorted by frequency) + (1, 62, 62), + (1, 32, 32), + (1, 50, 50), + (1, 38, 38), + (1, 76, 76), + (1, 94, 94), + (1, 64, 64), + (1, 124, 124), + (1, 68, 68), + (1, 100, 100), + (1, 16, 16), + (1, 24, 24), + (1, 46, 46), + (1, 44, 44), + (1, 42, 42), + (1, 40, 40), + (1, 56, 56), + (1, 128, 128), + (1, 18, 18), + (1, 28, 28), + (1, 34, 34), + (1, 80, 80), + (1, 30, 30), + (1, 38, 50), + (1, 22, 22), + (1, 112, 112), + (1, 36, 36), + (1, 34, 50), + (1, 188, 188), + (1, 14, 20), + (1, 90, 90), + (1, 44, 42), + (1, 16, 18), + (1, 54, 54), + (1, 48, 48), + (1, 40, 42), + (1, 60, 60), + (1, 88, 88), + (1, 26, 26), + (1, 156, 156), + (1, 94, 62), + (1, 30, 38), + (1, 24, 38), + (1, 20, 20), + (1, 24, 16), + (1, 18, 16), + (1, 120, 120), + (1, 60, 80), + (1, 52, 52), + (1, 66, 66), + # Next 50 grids + (1, 20, 14), + (1, 24, 32), + (1, 160, 160), + (1, 28, 38), + (1, 30, 40), + (1, 38, 42), + (1, 58, 58), + (1, 20, 32), + (1, 50, 38), + (1, 48, 64), + (1, 78, 78), + (1, 24, 20), + (1, 42, 62), + (1, 62, 94), + (1, 36, 42), + (1, 32, 20), + (1, 150, 150), + (1, 50, 42), + (1, 50, 76), + (1, 72, 72), + (1, 32, 24), + (1, 46, 42), + (1, 92, 94), + (1, 82, 82), + (1, 32, 38), + (1, 90, 94), + (1, 14, 22), + (1, 76, 100), + (1, 94, 92), + (1, 24, 18), + (1, 54, 42), + (1, 38, 32), + (1, 18, 24), + (1, 28, 32), + (1, 30, 42), + (1, 56, 76), + (1, 62, 42), + (1, 28, 50), + (1, 32, 42), + (1, 36, 50), + (1, 38, 24), + (1, 108, 82), + (1, 16, 20), + (1, 26, 38), + (1, 38, 36), + (1, 34, 42), + (1, 76, 50), + (1, 38, 56), + (1, 48, 42), + (1, 30, 32), +] + + +class EncoderCudaGraphManager: + """ + Manages CUDA graphs for multimodal encoders (e.g., ViT in VLMs). + + The manager captures CUDA graphs for specific grid configurations + (T, H, W in patch units) and replays them during inference when + input dimensions exactly match. + + Design: + - Captures graphs for predefined grid configurations + - Only replays when input exactly matches a captured configuration + - Falls back to eager mode for non-matching inputs + - Tracks statistics for monitoring + + Limitations: + - Requires exact dimension match for graph replay + - Variable-size images may not benefit from CUDA graphs + """ + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + dtype: torch.dtype, + graph_pool: Any | None = None, + verbose: bool = False, + ): + self.vllm_config = vllm_config + self.device = device + self.dtype = dtype + self.verbose = verbose + + # CUDA graph storage - keyed by (batch_size, t, h, w) tuple + self.graphs: dict[tuple[int, int, int, int], torch.cuda.CUDAGraph] = {} + # Use private pools by default to avoid segfaults with rapid back-to-back + # graph replays during one-by-one multi-image processing. 
+ # Set VLLM_ENCODER_SHARED_POOL=1 to use shared pool (saves memory but + # may cause issues with rapid replays) + import os + + if os.environ.get("VLLM_ENCODER_SHARED_POOL", "0") == "1": + self.pool = ( + graph_pool if graph_pool is not None else torch.cuda.graph_pool_handle() + ) + logger.info("Encoder CUDA graphs: using shared pool") + else: + self.pool = None # Each graph uses private memory (default) + + # Pre-allocated input/output buffers per graph config + # Key: (batch_size, t, h, w), Value: {"pixel_values": tensor, "grid_thw": list} + self.input_buffers: dict[tuple[int, int, int, int], dict[str, Any]] = {} + self.output_buffers: dict[tuple[int, int, int, int], torch.Tensor] = {} + + # Input buffers for embeddings (padded mode with runtime computation) + # Key: (batch_size, t, h, w), Value: dict with pos_embeds, rotary, cu_seqlens + self.embedding_buffers: dict[ + tuple[int, int, int, int], dict[str, torch.Tensor] + ] = {} + + # Vision encoder reference for runtime embedding computation (set at capture) + self.vision_encoder = None + + # Track if graphs have been captured + self.captured = False + + # Statistics + self.cache_hits = 0 + self.eager_fallbacks = 0 + + # CUDA event for lightweight synchronization + # Instead of torch.cuda.synchronize() which waits for ALL GPU work, + # we use an event to track only the last replay completion. + # This allows better overlap between encoder and other GPU work. + self.replay_done_event: torch.cuda.Event | None = None + + # Single-GPU optimization: when TP=1, PP=1, DP=1, we can capture graphs + # on the current stream instead of a separate stream. This eliminates + # the need for stream synchronization before replay. + parallel_config = vllm_config.parallel_config + self.is_single_gpu = ( + parallel_config.tensor_parallel_size == 1 + and parallel_config.pipeline_parallel_size == 1 + and parallel_config.data_parallel_size == 1 + ) + if self.is_single_gpu: + logger.info( + "Encoder CUDA graphs: single-GPU mode enabled " + "(TP=1, PP=1, DP=1), using optimized sync scheme" + ) + + # Per-grid embedding cache for batched contiguous mode + # Key: (t, h, w), Value: dict with pos_embeds, rotary_cos, rotary_sin + # This avoids recomputing embeddings at runtime - just look up and concat + self.grid_embedding_cache: dict[ + tuple[int, int, int], dict[str, torch.Tensor] + ] = {} + + # Budget batching config + # Maps token_budget -> graph_key for budget batch CUDA graphs + self.budget_graph_keys: dict[int, tuple[int, int, int, int]] = {} + self.token_budgets: list[int] = [] + self.max_images_per_batch: int = 0 + self._read_budget_config() + + def _read_budget_config(self) -> None: + """Read budget batching configuration from compilation config.""" + compilation_config = self.vllm_config.compilation_config + if compilation_config is None: + return + + token_budgets = getattr( + compilation_config, "encoder_cudagraph_token_budgets", None + ) + max_images = getattr( + compilation_config, "encoder_cudagraph_max_images_per_batch", None + ) + + if token_budgets is None and max_images is None: + return + + if (token_budgets is None) != (max_images is None): + logger.warning( + "encoder_cudagraph_token_budgets and " + "encoder_cudagraph_max_images_per_batch must both be set. " + "Budget batching disabled." + ) + return + + if token_budgets is None or max_images is None: + return + + if max_images <= 0: + logger.warning( + "encoder_cudagraph_max_images_per_batch must be positive. " + "Budget batching disabled." 
+ ) + return + + bad_budgets = [b for b in token_budgets if b % max_images != 0] + if bad_budgets: + logger.warning( + "encoder_cudagraph_token_budgets values %s are not divisible " + "by max_images_per_batch=%d. Budget batching disabled.", + bad_budgets, + max_images, + ) + return + + self.token_budgets = sorted(token_budgets) + self.max_images_per_batch = max_images + + logger.info( + "Budget batching configured: token_budgets=%s, max_images_per_batch=%d", + self.token_budgets, + self.max_images_per_batch, + ) + + def _compute_output_tokens( + self, + grid_thw: tuple[int, int, int], + spatial_merge_size: int, + ) -> int: + """Compute number of output tokens for a grid configuration.""" + t, h, w = grid_thw + # After spatial merge: tokens = T * (H/merge) * (W/merge) + return t * (h // spatial_merge_size) * (w // spatial_merge_size) + + def _prepare_dummy_inputs_for_grid( + self, + grid_config: tuple[int, int, int], + vision_encoder: nn.Module, + batch_size: int = 1, + ) -> dict[str, Any]: + """ + Prepare dummy inputs for CUDA graph capture with a specific grid config. + + Args: + grid_config: Tuple of (T, H, W) in patch units + vision_encoder: The vision encoder module + batch_size: Number of images in the batch (all same grid) + + Returns: + Dict with pixel_values, grid_thw, and metadata + """ + t, h, w = grid_config + + # Get vision encoder properties + patch_size = vision_encoder.patch_size + temporal_patch_size = vision_encoder.temporal_patch_size + spatial_merge_size = vision_encoder.spatial_merge_size + in_channels = 3 # RGB + + # Calculate patch input channels + patch_input_channels = ( + temporal_patch_size * patch_size * patch_size * in_channels + ) + + # Calculate number of pixel patches per image + # h, w are in patch units, so num_patches = t * h * w + num_pixel_patches_per_image = t * h * w + total_pixel_patches = num_pixel_patches_per_image * batch_size + + # Create dummy pixel values for batch (zeros are fine for warmup/capture) + pixel_values = torch.zeros( + total_pixel_patches, + patch_input_channels, + dtype=self.dtype, + device=self.device, + ) + + # Grid (temporal, height, width) for each image in batch + grid_thw = [[t, h, w]] * batch_size + + # Calculate output tokens per image and total + output_tokens_per_image = self._compute_output_tokens( + grid_config, spatial_merge_size + ) + total_output_tokens = output_tokens_per_image * batch_size + + return { + "pixel_values": pixel_values, + "grid_thw": grid_thw, + "num_output_tokens": total_output_tokens, + "num_output_tokens_per_image": output_tokens_per_image, + "num_pixel_patches": total_pixel_patches, + "num_pixel_patches_per_image": num_pixel_patches_per_image, + "patch_input_channels": patch_input_channels, + "batch_size": batch_size, + } + + def capture_graph_for_grid( + self, + grid_config: tuple[int, int, int], + vision_encoder: nn.Module, + batch_size: int = 1, + ) -> None: + """ + Capture a CUDA graph for the given grid configuration and batch size. + + This method pre-computes and caches all grid-dependent tensors + (position embeddings, rotary embeddings, cu_seqlens) to eliminate + CPU operations during CUDA graph replay. 
+ + Args: + grid_config: Tuple of (T, H, W) in patch units + vision_encoder: The vision encoder module + batch_size: Number of images with same grid (default 1) + """ + t, h, w = grid_config + graph_key = (batch_size, t, h, w) + logger.debug( + "Capturing encoder CUDA graph for key %s (batch_size=%d, grid=%s)", + graph_key, + batch_size, + grid_config, + ) + + # Prepare dummy inputs for batch + dummy_inputs = self._prepare_dummy_inputs_for_grid( + grid_config, vision_encoder, batch_size + ) + pixel_values = dummy_inputs["pixel_values"] + grid_thw = dummy_inputs["grid_thw"] + + # Store input buffer reference with new key format + self.input_buffers[graph_key] = { + "pixel_values": pixel_values.clone(), + "grid_thw": grid_thw, + } + + # Store vision encoder reference for runtime embedding computation + self.vision_encoder = vision_encoder + + # Check if vision encoder supports optimized CUDA graph forward + has_cudagraph_forward = hasattr( + vision_encoder, "forward_cudagraph" + ) and hasattr(vision_encoder, "precompute_for_cudagraph") + + if has_cudagraph_forward: + cached = vision_encoder.precompute_for_cudagraph(grid_thw) + + # Cache per-grid embeddings for batched contiguous mode + # This avoids recomputing embeddings at runtime - just lookup and concat + grid_key = (t, h, w) + if grid_key not in self.grid_embedding_cache: + # Compute embeddings for a single image of this grid size + single_cached = vision_encoder.precompute_for_cudagraph([[t, h, w]]) + self.grid_embedding_cache[grid_key] = { + "pos_embeds": single_cached["pos_embeds"], + "rotary_pos_emb_cos": single_cached["rotary_pos_emb_cos"], + "rotary_pos_emb_sin": single_cached["rotary_pos_emb_sin"], + } + logger.debug( + "Cached per-grid embeddings for grid %s: pos_embeds=%s", + grid_key, + single_cached["pos_embeds"].shape, + ) + + # Create INPUT BUFFERS for embeddings (padded mode runtime computation) + # These buffers can be updated at runtime before graph replay + # Note: max_seqlen is a CPU scalar tensor to avoid GPU sync on .item() + self.embedding_buffers[graph_key] = { + "pos_embeds": cached["pos_embeds"].clone(), + "rotary_pos_emb_cos": cached["rotary_pos_emb_cos"].clone(), + "rotary_pos_emb_sin": cached["rotary_pos_emb_sin"].clone(), + "cu_seqlens": cached["cu_seqlens"].clone(), + "max_seqlen": cached["max_seqlen"].clone(), + "sequence_lengths": cached["sequence_lengths"].clone(), + } + embed_buffers = self.embedding_buffers[graph_key] + + # Warmup run with embedding buffers + # Use set_forward_context to provide vllm_config for torch.compile + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + ): + warmup_output = vision_encoder.forward_cudagraph( + pixel_values, + pos_embeds=embed_buffers["pos_embeds"], + rotary_pos_emb_cos=embed_buffers["rotary_pos_emb_cos"], + rotary_pos_emb_sin=embed_buffers["rotary_pos_emb_sin"], + cu_seqlens=embed_buffers["cu_seqlens"], + max_seqlen=embed_buffers["max_seqlen"], + sequence_lengths=embed_buffers["sequence_lengths"], + ) + self.output_buffers[graph_key] = torch.empty_like(warmup_output) + + # Capture the graph with embedding BUFFERS (not constants) + # This allows updating embeddings at runtime for padded mode + graph = torch.cuda.CUDAGraph() + input_buffer = self.input_buffers[graph_key]["pixel_values"] + + with ( + set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + ), + torch.cuda.graph(graph, self.pool), + ): + output = vision_encoder.forward_cudagraph( + input_buffer, + pos_embeds=embed_buffers["pos_embeds"], + 
rotary_pos_emb_cos=embed_buffers["rotary_pos_emb_cos"], + rotary_pos_emb_sin=embed_buffers["rotary_pos_emb_sin"], + cu_seqlens=embed_buffers["cu_seqlens"], + max_seqlen=embed_buffers["max_seqlen"], + sequence_lengths=embed_buffers["sequence_lengths"], + ) + self.output_buffers[graph_key].copy_(output) + else: + # Fallback to original forward (will have CPU gaps) + logger.warning( + "Vision encoder does not support forward_cudagraph, " + "using standard forward (will have CPU gaps)" + ) + + # Warmup run (required before capture) + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + ): + warmup_output = vision_encoder(pixel_values, grid_thw=grid_thw) + self.output_buffers[graph_key] = torch.empty_like(warmup_output) + + # Capture the graph + graph = torch.cuda.CUDAGraph() + input_buffer = self.input_buffers[graph_key]["pixel_values"] + + with ( + set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + ), + torch.cuda.graph(graph, self.pool), + ): + output = vision_encoder(input_buffer, grid_thw=grid_thw) + self.output_buffers[graph_key].copy_(output) + + self.graphs[graph_key] = graph + cached_suffix = " (with cached tensors)" if has_cudagraph_forward else "" + logger.debug( + "Captured encoder CUDA graph for key %s -> %d output tokens%s", + graph_key, + dummy_inputs["num_output_tokens"], + cached_suffix, + ) + + def capture_budget_graphs(self, vision_encoder: nn.Module) -> None: + """ + Capture CUDA graphs for budget batching mode. + + For each configured token_budget, captures a graph with + max_images_per_batch image slots. The graph uses a synthetic grid + that produces the right tensor shapes. At runtime, embedding buffers + are overwritten with actual per-image values from grid_embedding_cache. + + Args: + vision_encoder: The vision encoder module + """ + if not self.token_budgets or self.max_images_per_batch <= 0: + return + + merge = getattr(vision_encoder, "spatial_merge_size", 2) + + for token_budget in self.token_budgets: + per_image_output = token_budget // self.max_images_per_batch + if per_image_output <= 0: + logger.warning( + "token_budget=%d too small for max_images=%d, skipping", + token_budget, + self.max_images_per_batch, + ) + continue + + # Synthetic grid: (1, merge, per_image_output * merge) + # Output tokens per image: + # 1 * (merge/merge) * (per_image_output*merge/merge) + # = per_image_output + # Total output = max_images * per_image_output = token_budget + grid_config = (1, merge, per_image_output * merge) + + try: + if self.is_single_gpu: + self.capture_graph_for_grid( + grid_config, + vision_encoder, + batch_size=self.max_images_per_batch, + ) + else: + with graph_capture(device=self.device): + self.capture_graph_for_grid( + grid_config, + vision_encoder, + batch_size=self.max_images_per_batch, + ) + + graph_key = ( + self.max_images_per_batch, + 1, + merge, + per_image_output * merge, + ) + self.budget_graph_keys[token_budget] = graph_key + logger.info( + "Captured budget graph: token_budget=%d, " + "max_images=%d, graph_key=%s", + token_budget, + self.max_images_per_batch, + graph_key, + ) + except Exception as e: + logger.warning( + "Failed to capture budget graph for token_budget=%d: %s", + token_budget, + e, + ) + + def find_budget_graph( + self, + total_output_tokens: int, + ) -> tuple[int, int, int, int] | None: + """ + Find the smallest budget graph that fits the given total output tokens. 
+ + Args: + total_output_tokens: Total output tokens for the packed batch + + Returns: + Graph key (batch_size, t, h, w) or None if no budget fits + """ + best_key = None + best_budget = float("inf") + + for budget, graph_key in self.budget_graph_keys.items(): + if budget >= total_output_tokens and budget < best_budget: + best_budget = budget + best_key = graph_key + + return best_key + + @torch.inference_mode() + def capture( + self, + vision_encoder: nn.Module, + embed_multimodal_fn: Callable, + ) -> None: + """ + Capture CUDA graphs for all configured grid and batch size combinations. + + Args: + vision_encoder: The vision encoder module (e.g., Qwen3_VisionTransformer) + embed_multimodal_fn: The model's embed_multimodal method (unused) + """ + if self.captured: + logger.warning("Encoder CUDA graphs already captured, skipping") + return + + # Pre-warm embedding cache for common grids + self._prewarm_embedding_cache(vision_encoder) + + # Capture budget batch graphs + if self.token_budgets and self.max_images_per_batch > 0: + self.capture_budget_graphs(vision_encoder) + + self.captured = True + + def _prewarm_embedding_cache(self, vision_encoder: nn.Module) -> None: + """ + Pre-warm the embedding cache for common grid configurations. + + This avoids cold-start embedding computation at runtime by pre-computing + embeddings for the top 100 most common grids. Each grid that would + otherwise trigger ~20 small kernel launches on first encounter will + instead hit the cache. + + Args: + vision_encoder: The vision encoder module with precompute_for_cudagraph + """ + if not hasattr(vision_encoder, "precompute_for_cudagraph"): + logger.debug( + "Vision encoder lacks precompute_for_cudagraph, skipping warmup" + ) + return + + # Filter out grids that are already cached (from graph capture) + grids_to_warm = [ + g for g in EMBEDDING_WARMUP_GRIDS if g not in self.grid_embedding_cache + ] + + if not grids_to_warm: + logger.debug("All warmup grids already cached") + return + + if self.verbose: + logger.info( + "Pre-warming embedding cache for %d grids (%d already cached)", + len(grids_to_warm), + len(EMBEDDING_WARMUP_GRIDS) - len(grids_to_warm), + ) + + for grid in grids_to_warm: + t, h, w = grid + try: + cached = vision_encoder.precompute_for_cudagraph([[t, h, w]]) + self.grid_embedding_cache[grid] = { + "pos_embeds": cached["pos_embeds"], + "rotary_pos_emb_cos": cached["rotary_pos_emb_cos"], + "rotary_pos_emb_sin": cached["rotary_pos_emb_sin"], + } + except Exception as e: + logger.debug("Failed to pre-warm grid %s: %s", grid, e) + + # Calculate and log embedding cache memory consumption + if self.verbose: + cache_memory_bytes = self._compute_embedding_cache_memory() + logger.info( + "Embedding cache warmed: %d grids total, memory: %.2f MiB", + len(self.grid_embedding_cache), + cache_memory_bytes / (1024 * 1024), + ) + + def _compute_embedding_cache_memory(self) -> int: + """ + Compute the total GPU memory consumption of the embedding cache. + + Returns: + Total memory in bytes used by all cached embeddings. 
+ """ + total_bytes = 0 + for grid, cached in self.grid_embedding_cache.items(): + for key, tensor in cached.items(): + if isinstance(tensor, torch.Tensor): + total_bytes += tensor.numel() * tensor.element_size() + return total_bytes + + def run_batched_contiguous( + self, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], + graph_key: tuple[int, int, int, int], + spatial_merge_size: int = 2, + ) -> torch.Tensor | None: + """ + Run batched CUDA graph with contiguous packing and end padding. + + This method packs images contiguously in the buffer (no interleaved padding), + computes actual cu_seqlens at runtime, and pads only at the end. This ensures + flash attention reads correct data for each sequence. + + Memory layout: + Buffer: [img0][img1][img2][img3][PADDING at end] + cu_seqlens: [0, size0, size0+size1, ..., total_actual] + + Flash attention uses cu_seqlens to process only actual tokens; padding at + the end is outside all sequence boundaries and is ignored. + + Args: + pixel_values: Contiguously packed pixel values (no padding between images) + grid_thw_list: List of [T, H, W] for each image (can be different grids) + graph_key: The bucket graph key (batch_size, t, h, w) to use + spatial_merge_size: Spatial merge size (default 2) + + Returns: + Full output tensor from the bucket, or None if failed. + Caller should use cu_seqlens to extract per-image outputs. + """ + if graph_key not in self.graphs: + logger.debug("No graph for key %s", graph_key) + return None + + batch_size = graph_key[0] + num_actual_images = len(grid_thw_list) + is_budget_graph = graph_key in self.budget_graph_keys.values() + + if num_actual_images > batch_size: + logger.warning( + "grid_thw_list length (%d) exceeds graph batch_size (%d)", + num_actual_images, + batch_size, + ) + return None + + if num_actual_images != batch_size and not is_budget_graph: + logger.warning( + "grid_thw_list length (%d) doesn't match graph batch_size (%d)" + " and not a budget graph", + num_actual_images, + batch_size, + ) + return None + + # Check if vision encoder is available for embedding computation + if self.vision_encoder is None or not hasattr( + self.vision_encoder, "precompute_for_cudagraph" + ): + logger.debug("Vision encoder not available for batched contiguous mode") + return None + + # Check if we have embedding buffers for this bucket + if graph_key not in self.embedding_buffers: + logger.debug("No embedding buffers for bucket %s", graph_key) + return None + + # Get the input buffer for this bucket + input_buffer = self.input_buffers[graph_key]["pixel_values"] + actual_input_patches = pixel_values.shape[0] + bucket_input_patches = input_buffer.shape[0] + + if actual_input_patches > bucket_input_patches: + logger.warning( + "Input patches (%d) exceed bucket capacity (%d).", + actual_input_patches, + bucket_input_patches, + ) + self.eager_fallbacks += 1 + return None + + # Verify device and dtype match + if pixel_values.device != input_buffer.device: + logger.warning( + "Device mismatch: expected %s, got %s.", + input_buffer.device, + pixel_values.device, + ) + self.eager_fallbacks += 1 + return None + + if pixel_values.dtype != input_buffer.dtype: + logger.warning( + "Dtype mismatch: expected %s, got %s.", + input_buffer.dtype, + pixel_values.dtype, + ) + self.eager_fallbacks += 1 + return None + + # Ensure contiguous memory layout + if not pixel_values.is_contiguous(): + pixel_values = pixel_values.contiguous() + + # Count actual images processed (for accurate hit rate) + self.cache_hits += 
num_actual_images + + # Wait for any previous graph replay to complete + if not self.is_single_gpu and self.replay_done_event is not None: + self.replay_done_event.synchronize() + + # Get embedding buffers for the bucket + embed_buffers = self.embedding_buffers[graph_key] + + # Zero the buffers first (for clean padding at end) + input_buffer.zero_() + embed_buffers["pos_embeds"].zero_() + embed_buffers["rotary_pos_emb_cos"].zero_() + embed_buffers["rotary_pos_emb_sin"].zero_() + + # Copy actual pixel values to the beginning of the buffer (contiguous) + input_buffer[:actual_input_patches].copy_(pixel_values, non_blocking=True) + + # Look up cached embeddings for each grid and pack contiguously + # This avoids expensive per-image precompute_for_cudagraph calls + pos_embeds_list = [] + rotary_cos_list = [] + rotary_sin_list = [] + sequence_lengths = [] + cache_miss_grids: list[tuple[int, int, int]] = [] + + for grid in grid_thw_list: + t, h, w = grid + grid_key = (t, h, w) + # Each temporal frame is a separate attention sequence in patch space. + # This matches the eager path: np.repeat(h*w, t) per image. + for _ in range(t): + sequence_lengths.append(h * w) + + # Try to use cached embeddings (populated during graph capture) + if grid_key in self.grid_embedding_cache: + cached = self.grid_embedding_cache[grid_key] + pos_embeds_list.append(cached["pos_embeds"]) + rotary_cos_list.append(cached["rotary_pos_emb_cos"]) + rotary_sin_list.append(cached["rotary_pos_emb_sin"]) + else: + # Cache miss - compute on-the-fly but don't cache + # (avoids unbounded GPU memory growth at runtime) + cache_miss_grids.append(grid_key) + if self.vision_encoder is not None: + actual_embeds = self.vision_encoder.precompute_for_cudagraph([grid]) + pos_embeds_list.append(actual_embeds["pos_embeds"]) + rotary_cos_list.append(actual_embeds["rotary_pos_emb_cos"]) + rotary_sin_list.append(actual_embeds["rotary_pos_emb_sin"]) + else: + logger.warning("Grid %s not cached and no vision encoder", grid_key) + return None + + if cache_miss_grids and self.verbose: + logger.info( + "Embedding cache miss for grids: %s (computed on-the-fly)", + cache_miss_grids, + ) + + # Concatenate cached embeddings (just tensor concat, no computation) + packed_pos_embeds = torch.cat(pos_embeds_list, dim=0) + packed_rotary_cos = torch.cat(rotary_cos_list, dim=0) + packed_rotary_sin = torch.cat(rotary_sin_list, dim=0) + + # Copy packed embeddings to buffer (padding remains zero at end) + actual_embed_len = packed_pos_embeds.shape[0] + embed_buffers["pos_embeds"][:actual_embed_len].copy_( + packed_pos_embeds, non_blocking=True + ) + embed_buffers["rotary_pos_emb_cos"][:actual_embed_len].copy_( + packed_rotary_cos, non_blocking=True + ) + embed_buffers["rotary_pos_emb_sin"][:actual_embed_len].copy_( + packed_rotary_sin, non_blocking=True + ) + + # Build cu_seqlens from actual cumulative sizes + # cu_seqlens = [0, size0, size0+size1, ..., total_actual] + cu_seqlens_list = [0] + for length in sequence_lengths: + cu_seqlens_list.append(cu_seqlens_list[-1] + length) + + # For budget graphs: pad cu_seqlens to batch_size + 1 by repeating + # the last value. This creates zero-length sequences for empty slots + # that flash attention skips (no-op). + # Note: num_sequences = sum(t_i) for all images. For images (t=1), + # this equals num_images <= batch_size. For videos (t>1), it could + # exceed batch_size — fall back to eager in that case. 
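+        # Illustrative example: batch_size=4 and two single-frame images of
+        # 1024 and 2304 patches give sequence_lengths=[1024, 2304, 0, 0] and
+        # cu_seqlens=[0, 1024, 3328, 3328, 3328] after the padding below.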
+ if is_budget_graph and len(sequence_lengths) > batch_size: + logger.debug( + "Too many sequences (%d) for budget graph batch_size (%d), " + "falling back to eager", + len(sequence_lengths), + batch_size, + ) + return None + if is_budget_graph and len(cu_seqlens_list) < batch_size + 1: + last_val = cu_seqlens_list[-1] + while len(cu_seqlens_list) < batch_size + 1: + cu_seqlens_list.append(last_val) + + # For budget graphs: pad sequence_lengths with zeros for empty slots + if is_budget_graph and len(sequence_lengths) < batch_size: + sequence_lengths = list(sequence_lengths) + [0] * ( + batch_size - len(sequence_lengths) + ) + + cu_seqlens_tensor = torch.tensor( + cu_seqlens_list, dtype=torch.int32, device=self.device + ) + max_seqlen = ( + max(s for s in sequence_lengths if s > 0) if sequence_lengths else 0 + ) + max_seqlen_tensor = torch.tensor(max_seqlen, dtype=torch.int32, device="cpu") + sequence_lengths_tensor = torch.tensor( + sequence_lengths, dtype=torch.int32, device=self.device + ) + + # Copy full cu_seqlens and sequence_lengths to buffers + # For budget graphs, sizes match exactly (padded to batch_size + 1). + # For non-budget graphs, copy only the actual part. + cu_seqlens_buf = embed_buffers["cu_seqlens"] + seq_len_buf = embed_buffers["sequence_lengths"] + if is_budget_graph: + cu_seqlens_buf.copy_(cu_seqlens_tensor, non_blocking=True) + seq_len_buf.copy_(sequence_lengths_tensor, non_blocking=True) + else: + cu_seqlens_buf[: len(cu_seqlens_list)].copy_( + cu_seqlens_tensor, non_blocking=True + ) + seq_len_buf[:batch_size].copy_(sequence_lengths_tensor, non_blocking=True) + embed_buffers["max_seqlen"].copy_(max_seqlen_tensor, non_blocking=True) + + if self.verbose: + logger.info( + "run_batched_contiguous(): graph_key=%s, grids=%s, " + "actual_patches=%d, bucket_patches=%d, cu_seqlens=%s", + graph_key, + grid_thw_list, + actual_input_patches, + bucket_input_patches, + cu_seqlens_list, + ) + + if self.is_single_gpu: + self.graphs[graph_key].replay() + return self.output_buffers[graph_key] + else: + torch.cuda.current_stream().synchronize() + self.graphs[graph_key].replay() + if self.replay_done_event is None: + self.replay_done_event = torch.cuda.Event() + self.replay_done_event.record() + self.replay_done_event.synchronize() + return self.output_buffers[graph_key].clone() + + def get_stats(self, verbose: bool = True) -> dict[str, Any]: + """Get and optionally log cache statistics. + + Args: + verbose: If True, log stats to INFO level. If False, only return stats dict. 
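+
+        Returns:
+            A stats dict, for example (illustrative numbers):
+            {"cache_hits": 37, "eager_fallbacks": 3, "hit_rate": 0.925,
+             "num_graphs": 4, "captured_configs": [...]}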
+ """ + total = self.cache_hits + self.eager_fallbacks + hit_rate = self.cache_hits / total if total > 0 else 0.0 + stats = { + "cache_hits": self.cache_hits, + "eager_fallbacks": self.eager_fallbacks, + "hit_rate": hit_rate, + "num_graphs": len(self.graphs), + "captured_configs": sorted(self.graphs.keys()), + } + if verbose: + logger.info( + "Encoder CUDA graph stats: hits=%d, eager=%d, " + "hit_rate=%.1f%%, num_graphs=%d", + self.cache_hits, + self.eager_fallbacks, + hit_rate * 100, + len(self.graphs), + ) + return stats diff --git a/vllm/v1/worker/gpu/mm/encoder_runner.py b/vllm/v1/worker/gpu/mm/encoder_runner.py index f9a0b50f34b4..db090eb164b1 100644 --- a/vllm/v1/worker/gpu/mm/encoder_runner.py +++ b/vllm/v1/worker/gpu/mm/encoder_runner.py @@ -1,14 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np import torch +from vllm.logger import init_logger from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager + +logger = init_logger(__name__) + class EncoderRunner: def __init__( @@ -17,11 +28,13 @@ def __init__( hidden_size: int, dtype: torch.dtype, device: torch.device, + vllm_config: VllmConfig | None = None, ): self.max_num_tokens = max_num_tokens self.hidden_size = hidden_size self.dtype = dtype self.device = device + self.vllm_config = vllm_config self.inputs_embeds = torch.zeros( max_num_tokens, @@ -34,6 +47,68 @@ def __init__( self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool) + # Encoder CUDA graph manager (optional) + self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None + self.encoder_cudagraph_budget_mode: bool = False + self._encoder_call_count: int = 0 + self._init_encoder_cudagraph_manager() + + def _init_encoder_cudagraph_manager(self) -> None: + """Initialize encoder CUDA graph manager if enabled in config.""" + if self.vllm_config is None: + return + + compilation_config = self.vllm_config.compilation_config + if compilation_config is None: + return + + if not getattr(compilation_config, "cudagraph_mm_encoder", False): + return + + # Import here to avoid circular imports + from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager + + self.encoder_cudagraph_manager = EncoderCudaGraphManager( + vllm_config=self.vllm_config, + device=self.device, + dtype=self.dtype, + ) + + # Check if budget batching is configured + self.encoder_cudagraph_budget_mode = bool( + self.encoder_cudagraph_manager.token_budgets + and self.encoder_cudagraph_manager.max_images_per_batch > 0 + ) + + logger.info( + "Encoder CUDA graph manager initialized: budget_mode=%s", + self.encoder_cudagraph_budget_mode, + ) + + def capture_encoder_cudagraphs( + self, + model: SupportsMultiModal, + ) -> None: + """ + Capture CUDA graphs for the encoder. + + Should be called during model warmup after the model is loaded. 
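+
+        As wired up in this change, warmup_for_prefill in
+        vllm/v1/worker/gpu/model_runner.py invokes:
+            self.encoder_runner.capture_encoder_cudagraphs(self.model)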
+ """ + if self.encoder_cudagraph_manager is None: + return + + if not hasattr(model, "visual") or model.visual is None: + logger.warning( + "Model does not have a visual encoder, " + "skipping encoder CUDA graph capture" + ) + return + + self.encoder_cudagraph_manager.capture( + vision_encoder=model.visual, + embed_multimodal_fn=model.embed_multimodal, + ) + def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]): self.req_id_to_mm_features[req_id] = mm_features @@ -59,6 +134,32 @@ def prepare_mm_inputs( mm_kwargs.append(mm_feature.data) return mm_hashes, mm_kwargs + def _get_grid_thw_from_kwargs( + self, + mm_kwargs_group: dict, + modality: str, + ) -> list[list[int]] | None: + """ + Extract grid_thw from mm_kwargs_group. + + Returns None if grid_thw is not available. + """ + if modality not in ("image", "video"): + return None + + # Try to get grid_thw from the kwargs + grid_thw = mm_kwargs_group.get("image_grid_thw") + if grid_thw is None: + grid_thw = mm_kwargs_group.get("video_grid_thw") + if grid_thw is None: + return None + + # Convert to list if tensor + if hasattr(grid_thw, "tolist"): + grid_thw = grid_thw.tolist() + + return grid_thw + @torch.inference_mode() def execute_mm_encoder( self, @@ -75,7 +176,20 @@ def execute_mm_encoder( device=self.device, pin_memory=False, ): - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + # Try to use CUDA graph if available + cudagraph_result = None + if self.encoder_cudagraph_manager is not None: + cudagraph_result = self._execute_with_cudagraph( + model, mm_kwargs_group, modality, num_items + ) + + if cudagraph_result is not None: + # CUDA graph was used successfully + curr_group_outputs = cudagraph_result + else: + # Fall back to eager mode + curr_group_outputs = list(model.embed_multimodal(**mm_kwargs_group)) + sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, @@ -85,8 +199,188 @@ def execute_mm_encoder( # Cache the encoder outputs by mm_hash for mm_hash, output in zip(mm_hashes, encoder_outputs): self.encoder_cache[mm_hash] = output + + # Log encoder CUDA graph stats + self._encoder_call_count += 1 + if self.encoder_cudagraph_manager is not None: + self.encoder_cudagraph_manager.get_stats() + return encoder_outputs + def _execute_with_cudagraph( + self, + model: SupportsMultiModal, + mm_kwargs_group: dict, + modality: str, + num_items: int, + ) -> list[torch.Tensor] | None: + """ + Execute the encoder using budget batch CUDA graphs. + + Packs images (sorted smallest-first) into budget-sized batches + and replays the smallest fitting CUDA graph. Falls back to eager + if no budget graph fits. + + Args: + model: The multimodal model + mm_kwargs_group: Batched multimodal kwargs + modality: The modality type ("image" or "video") + num_items: Number of items in the batch + + Returns: + List of encoder outputs if CUDA graph was used, None otherwise + """ + if self.encoder_cudagraph_manager is None: + return None + + if not self.encoder_cudagraph_budget_mode: + return None + + # Extract grid_thw from kwargs + grid_thw = self._get_grid_thw_from_kwargs(mm_kwargs_group, modality) + if grid_thw is None: + return None + + # Extract pixel_values + if modality == "image": + pixel_values = mm_kwargs_group.get("pixel_values") + else: # video + pixel_values = mm_kwargs_group.get("pixel_values_videos") + + if pixel_values is None: + logger.debug("No pixel_values found in kwargs. 
Using eager mode.") + return None + + # Ensure pixel_values is on the correct device + pixel_values = pixel_values.to(device=self.device, dtype=self.dtype) + + # Get spatial merge size for token calculations + visual = getattr(model, "visual", None) + spatial_merge_size = getattr(visual, "spatial_merge_size", 2) + + return self._execute_budget_batch(pixel_values, grid_thw, spatial_merge_size) + + def _execute_budget_batch( + self, + pixel_values: torch.Tensor, + grid_thw: list[list[int]], + spatial_merge_size: int, + ) -> list[torch.Tensor] | None: + """ + Execute images using budget batch CUDA graphs. + + Sorts images by output token count (smallest first), greedily packs + them into budget-sized batches, and replays the appropriate CUDA graph. + + Args: + pixel_values: Concatenated pixel values for all images + grid_thw: List of [T, H, W] for each image + spatial_merge_size: Spatial merge size (e.g., 2) + + Returns: + List of per-image output tensors in original order, or None + """ + manager = self.encoder_cudagraph_manager + if manager is None or not manager.budget_graph_keys: + return None + + max_budget = max(manager.budget_graph_keys.keys()) + max_images = manager.max_images_per_batch + + # Compute per-image info: (output_tokens, input_patches, original_idx) + image_info: list[tuple[int, int, int]] = [] + for i, (t, h, w) in enumerate(grid_thw): + out_tokens = t * (h // spatial_merge_size) * (w // spatial_merge_size) + in_patches = t * h * w + image_info.append((out_tokens, in_patches, i)) + + # Sort by output tokens ascending (small first) + sorted_images = sorted(image_info, key=lambda x: x[0]) + + # Compute pixel_values offsets for each original image + patch_offsets = [0] + for t, h, w in grid_thw: + patch_offsets.append(patch_offsets[-1] + t * h * w) + + # Greedy packing into budget batches + batches: list[list[tuple[int, int, int]]] = [] + current_batch: list[tuple[int, int, int]] = [] + current_tokens = 0 + + for out_tokens, in_patches, orig_idx in sorted_images: + if ( + current_tokens + out_tokens <= max_budget + and len(current_batch) < max_images + ): + current_batch.append((out_tokens, in_patches, orig_idx)) + current_tokens += out_tokens + else: + if current_batch: + batches.append(current_batch) + current_batch = [(out_tokens, in_patches, orig_idx)] + current_tokens = out_tokens + + if current_batch: + batches.append(current_batch) + + # Execute each packed batch + outputs: list[torch.Tensor | None] = [None] * len(grid_thw) + + for batch in batches: + total_out_tokens = sum(out_tok for out_tok, _, _ in batch) + + # Find smallest budget graph that fits + graph_key = manager.find_budget_graph(total_out_tokens) + if graph_key is None: + # No budget fits - fall back entirely + logger.debug( + "No budget graph for %d tokens, falling back to eager", + total_out_tokens, + ) + return None + + # Concatenate pixel values in sorted order + pv_slices = [] + batch_grids = [] + for _, _, orig_idx in batch: + start = patch_offsets[orig_idx] + end = patch_offsets[orig_idx + 1] + pv_slices.append(pixel_values[start:end]) + batch_grids.append(grid_thw[orig_idx]) + + packed_pv = torch.cat(pv_slices, dim=0) + + # Run the budget graph + output = manager.run_batched_contiguous( + packed_pv, batch_grids, graph_key, spatial_merge_size + ) + if output is None: + logger.debug( + "Budget graph replay failed for key %s, falling back to eager", + graph_key, + ) + return None + + # Split output by per-image output token counts + offset = 0 + for out_tokens, _, orig_idx in batch: + outputs[orig_idx] 
= output[offset : offset + out_tokens].clone() + offset += out_tokens + + if manager.verbose: + logger.info( + "ViT BUDGET BATCH: %d images, %d tokens, graph_key=%s", + len(batch), + total_out_tokens, + graph_key, + ) + + # Check all images were processed + if any(o is None for o in outputs): + return None + + return outputs # type: ignore[return-value] + def gather_mm_embeddings( self, req_ids: list[str], diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index a55519f0fa36..427109bd5c8b 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -107,6 +107,7 @@ def __init__( hidden_size=self.inputs_embeds_size, dtype=self.dtype, device=self.device, + vllm_config=self.vllm_config, ) self.uses_mrope = self.model_config.uses_mrope if self.uses_mrope: @@ -425,6 +426,10 @@ def warmup_for_prefill(self) -> None: self._dummy_run(self.max_num_tokens, skip_attn=False) torch.cuda.synchronize() + # Capture encoder CUDA graphs if enabled + if self.supports_mm_inputs: + self.encoder_runner.capture_encoder_cudagraphs(self.model) + def finish_requests(self, scheduler_output: SchedulerOutput) -> None: if scheduler_output.preempted_req_ids is not None: for req_id in scheduler_output.preempted_req_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 23d5bac75d00..1b76c7836727 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -155,6 +155,7 @@ from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin +from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -425,6 +426,15 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} + # Encoder CUDA graph manager for ViT + self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None + self.encoder_cudagraph_verbose: bool = False + self.encoder_cudagraph_budget_mode: bool = False + # Pre-allocated buffers for piecewise padded mode (lazily initialized) + # Key: capture_size (output tokens), Value: dict of buffers + self._piecewise_buffers: dict[int, dict[str, torch.Tensor]] = {} + self._init_encoder_cudagraph_manager() + self.use_aux_hidden_state_outputs = False # Set up speculative decoding. 
# NOTE(Jiayi): currently we put the entire draft model on @@ -680,6 +690,65 @@ def __init__( self.kv_connector_output: KVConnectorOutput | None = None self.layerwise_nvtx_hooks_registered = False + def _init_encoder_cudagraph_manager(self) -> None: + """Initialize encoder CUDA graph manager if enabled in config.""" + if self.compilation_config is None: + return + + # Always check verbose logging first (applies to all modes) + self.encoder_cudagraph_verbose = getattr( + self.compilation_config, + "encoder_cudagraph_verbose", + False, + ) + + # Check if piecewise encoder cudagraph mode is enabled + # In piecewise mode, torch.compile handles graph splitting at attention ops, + # so we don't need the full EncoderCudaGraphManager + encoder_cudagraph_piecewise = getattr( + self.compilation_config, "encoder_cudagraph_piecewise", False + ) + if encoder_cudagraph_piecewise: + compile_mm_encoder = getattr( + self.compilation_config, "compile_mm_encoder", False + ) + if not compile_mm_encoder: + logger.warning( + "encoder_cudagraph_piecewise=True requires " + "compile_mm_encoder=True. Piecewise encoder cudagraph " + "will not be effective." + ) + else: + logger.info( + "Piecewise encoder CUDA graph mode enabled. " + "torch.compile will handle graph splitting at attention ops." + ) + return + + if not getattr(self.compilation_config, "cudagraph_mm_encoder", False): + return + + encoder_graph_pool = torch.cuda.graph_pool_handle() + + self.encoder_cudagraph_manager = EncoderCudaGraphManager( + vllm_config=self.vllm_config, + device=self.device, + dtype=self.dtype, + graph_pool=encoder_graph_pool, + verbose=self.encoder_cudagraph_verbose, + ) + + # Check if budget batching is configured + self.encoder_cudagraph_budget_mode = bool( + self.encoder_cudagraph_manager.token_budgets + and self.encoder_cudagraph_manager.max_images_per_batch > 0 + ) + + logger.info( + "Encoder CUDA graph manager initialized: budget_mode=%s", + self.encoder_cudagraph_budget_mode, + ) + def update_max_model_len(self, max_model_len: int) -> None: self.max_model_len = max_model_len if self.speculative_config: @@ -713,7 +782,7 @@ def init_fp8_kv_scales(self) -> None: attn_layers = self.compilation_config.static_forward_context for name, module in attn_layers.items(): - if isinstance(module, (Attention, MLAAttention)): + if isinstance(module, Attention | MLAAttention): # TODO: Generally, scale is 1.0 if user uses on-the-fly fp8 # kvcache quant. However, to get better accuracy, compression # frameworks like llm-compressors allow users to tune the @@ -2302,20 +2371,48 @@ def _execute_mm_encoder( curr_group_outputs_lst.extend(micro_batch_outputs) curr_group_outputs = curr_group_outputs_lst + elif self.encoder_cudagraph_budget_mode: + # Budget batch mode: replaces grouped batch, one-by-one, + # exact match, and padded modes + budget_result = self._execute_budget_batch( + model, mm_kwargs_group, modality, num_items + ) + if budget_result is not None: + curr_group_outputs = budget_result + else: + # Fall back to eager + if self.encoder_cudagraph_verbose: + logger.info( + "ViT BUDGET BATCH fallback to eager: " + "modality=%s, num_items=%d", + modality, + num_items, + ) + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) else: - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. 
A list or tuple (length: num_items) of tensors, - # each of shape (feature_size, hidden_size) in case the feature - # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + # No budget mode: try piecewise -> eager + piecewise_result = None + piecewise_enabled = self.compilation_config is not None and getattr( + self.compilation_config, + "encoder_cudagraph_piecewise", + False, + ) + + if piecewise_enabled: + piecewise_result = self._execute_encoder_piecewise_padded( + model, mm_kwargs_group, modality + ) + + if piecewise_result is not None: + curr_group_outputs = piecewise_result + else: + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, ) + assert curr_group_outputs is not None # sanity_check ensures this encoder_outputs.extend(curr_group_outputs) # Cache the encoder outputs by mm_hash @@ -2324,8 +2421,759 @@ def _execute_mm_encoder( logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + # Log encoder CUDA graph stats periodically (verbose only) + if self.encoder_cudagraph_manager is not None: + self.encoder_cudagraph_manager.get_stats( + verbose=self.encoder_cudagraph_verbose + ) + return encoder_outputs + def _execute_budget_batch( + self, + model: "SupportsMultiModal", + mm_kwargs_group: dict, + modality: str, + num_items: int, + ) -> list[torch.Tensor] | None: + """ + Execute the encoder using budget batch CUDA graphs. + + Sorts images by output token count (smallest first), greedily packs + them into budget-sized batches, and replays the smallest fitting + CUDA graph. Falls back to None (eager) if any batch can't find a + fitting budget graph. 
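+
+        Example (illustrative sizes): with budget graphs for {4096, 8192}
+        output tokens and max_images_per_batch=2, images producing
+        [1200, 900, 700, 2500] tokens are sorted to [700, 900, 1200, 2500]
+        and packed as [700, 900] -> 4096 graph and [1200, 2500] -> 4096 graph.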
+ + Args: + model: The multimodal model + mm_kwargs_group: Batched multimodal kwargs + modality: The modality type ("image" or "video") + num_items: Number of items in the batch + + Returns: + List of encoder outputs if CUDA graph was used, None otherwise + """ + manager = self.encoder_cudagraph_manager + if manager is None or not manager.budget_graph_keys: + return None + + if modality not in ("image", "video"): + return None + + # Extract grid_thw + grid_thw = mm_kwargs_group.get("image_grid_thw") + if grid_thw is None: + grid_thw = mm_kwargs_group.get("video_grid_thw") + if grid_thw is None: + return None + + # Convert to list if tensor + if hasattr(grid_thw, "tolist"): + grid_thw = grid_thw.tolist() + + # Extract pixel_values + if modality == "image": + pixel_values = mm_kwargs_group.get("pixel_values") + else: # video + pixel_values = mm_kwargs_group.get("pixel_values_videos") + + if pixel_values is None: + return None + + # Ensure pixel_values is on the correct device + pixel_values = pixel_values.to( + device=self.device, dtype=self.dtype + ).contiguous() + + # Get spatial merge size + visual = getattr(model, "visual", None) + spatial_merge_size = getattr(visual, "spatial_merge_size", 2) + + max_budget = max(manager.budget_graph_keys.keys()) + max_images = manager.max_images_per_batch + + # Compute per-image info: (output_tokens, input_patches, orig_idx) + image_info: list[tuple[int, int, int]] = [] + for i, (t, h, w) in enumerate(grid_thw): + out_tokens = t * (h // spatial_merge_size) * (w // spatial_merge_size) + in_patches = t * h * w + image_info.append((out_tokens, in_patches, i)) + + # Sort by output tokens ascending (small first) + sorted_images = sorted(image_info, key=lambda x: x[0]) + + # Compute pixel_values offsets for each original image + patch_offsets = [0] + for t, h, w in grid_thw: + patch_offsets.append(patch_offsets[-1] + t * h * w) + + # Greedy packing into budget batches + batches: list[list[tuple[int, int, int]]] = [] + current_batch: list[tuple[int, int, int]] = [] + current_tokens = 0 + + for out_tokens, in_patches, orig_idx in sorted_images: + if ( + current_tokens + out_tokens <= max_budget + and len(current_batch) < max_images + ): + current_batch.append((out_tokens, in_patches, orig_idx)) + current_tokens += out_tokens + else: + if current_batch: + batches.append(current_batch) + current_batch = [(out_tokens, in_patches, orig_idx)] + current_tokens = out_tokens + + if current_batch: + batches.append(current_batch) + + # Execute each packed batch + outputs: list[torch.Tensor | None] = [None] * len(grid_thw) + + for batch in batches: + total_out_tokens = sum(out_tok for out_tok, _, _ in batch) + + # Find smallest budget graph that fits + graph_key = manager.find_budget_graph(total_out_tokens) + if graph_key is None: + logger.debug( + "No budget graph for %d tokens, falling back to eager", + total_out_tokens, + ) + return None + + # Concatenate pixel values in sorted order + pv_slices = [] + batch_grids = [] + for _, _, orig_idx in batch: + start = patch_offsets[orig_idx] + end = patch_offsets[orig_idx + 1] + pv_slices.append(pixel_values[start:end]) + batch_grids.append(grid_thw[orig_idx]) + + packed_pv = torch.cat(pv_slices, dim=0) + + # Run the budget graph + output = manager.run_batched_contiguous( + packed_pv, batch_grids, graph_key, spatial_merge_size + ) + if output is None: + logger.debug( + "Budget graph replay failed for key %s, falling back to eager", + graph_key, + ) + return None + + # Split output by per-image output token counts + offset = 0 + 
for out_tokens, _, orig_idx in batch: + outputs[orig_idx] = output[offset : offset + out_tokens].clone() + offset += out_tokens + + if self.encoder_cudagraph_verbose: + bs, gt, gh, gw = graph_key + budget_tokens = ( + bs * gt * (gh // spatial_merge_size) * (gw // spatial_merge_size) + ) + logger.info( + "ViT BUDGET BATCH: %d images, %d tokens, " + "budget=%d, waste=%.1f%%, graph_key=%s", + len(batch), + total_out_tokens, + budget_tokens, + (budget_tokens - total_out_tokens) / budget_tokens * 100, + graph_key, + ) + + # Check all images were processed + if any(o is None for o in outputs): + return None + + return outputs # type: ignore[return-value] + + def _find_nearest_encoder_capture_size(self, num_tokens: int) -> int | None: + """Find the smallest capture size >= num_tokens for piecewise mode. + + Args: + num_tokens: The actual number of output tokens + + Returns: + The nearest capture size, or None if no suitable size found + """ + if self.compilation_config is None: + return None + + capture_sizes = getattr( + self.compilation_config, "encoder_cudagraph_capture_sizes", None + ) + if capture_sizes is None or len(capture_sizes) == 0: + return None + + # Find smallest size >= num_tokens + for size in sorted(capture_sizes): + if size >= num_tokens: + return size + + # num_tokens exceeds all capture sizes + return None + + # Class-level counters for piecewise padded mode statistics + _piecewise_stats: dict = {} + + @classmethod + def _init_piecewise_stats(cls): + if not cls._piecewise_stats: + cls._piecewise_stats = { + "calls": 0, + "executions": 0, + "total_actual_tokens": 0, + "total_padded_tokens": 0, + "capture_size_hits": {}, # capture_size -> count + "fallback_reasons": {}, # reason -> count + } + + def _record_piecewise_fallback(self, reason: str): + self._init_piecewise_stats() + self._piecewise_stats["calls"] += 1 + self._piecewise_stats["fallback_reasons"][reason] = ( + self._piecewise_stats["fallback_reasons"].get(reason, 0) + 1 + ) + if self.encoder_cudagraph_verbose: + logger.info("ViT PIECEWISE fallback: %s", reason) + + @classmethod + def _record_piecewise_execution(cls, actual_tokens: int, capture_size: int): + cls._init_piecewise_stats() + cls._piecewise_stats["calls"] += 1 + cls._piecewise_stats["executions"] += 1 + cls._piecewise_stats["total_actual_tokens"] += actual_tokens + cls._piecewise_stats["total_padded_tokens"] += capture_size + cls._piecewise_stats["capture_size_hits"][capture_size] = ( + cls._piecewise_stats["capture_size_hits"].get(capture_size, 0) + 1 + ) + + @classmethod + def get_piecewise_stats_summary(cls) -> str: + cls._init_piecewise_stats() + stats = cls._piecewise_stats + if stats["calls"] == 0: + return "Piecewise padded: no calls" + + total_actual = stats["total_actual_tokens"] + total_padded = stats["total_padded_tokens"] + waste_pct = ( + (total_padded - total_actual) / total_padded * 100 + if total_padded > 0 + else 0 + ) + + lines = [ + "Piecewise padded stats:", + f" Calls: {stats['calls']}, Executions: {stats['executions']}", + f" Total actual tokens: {total_actual}", + f" Total padded tokens: {total_padded}", + f" Padding waste: {waste_pct:.1f}%", + f" Capture size hits: {stats['capture_size_hits']}", + f" Fallback reasons: {stats['fallback_reasons']}", + ] + return "\n".join(lines) + + def _execute_encoder_piecewise_padded( + self, + model: "SupportsMultiModal", + mm_kwargs_group: dict, + modality: str, + ) -> list[torch.Tensor] | None: + """Execute encoder with padding for piecewise cudagraph mode. 
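+
+        Actual output tokens are rounded up to the nearest configured capture
+        size; e.g. (illustrative sizes) 3000 output tokens with capture sizes
+        {2048, 4096} pad up to 4096, i.e. 1096 extra output tokens and, with
+        spatial_merge_size=2, 4384 extra input patches.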
+ + Pre-computes embeddings outside the compiled graph, pads all tensors + to the nearest capture size, then calls forward_piecewise. This allows + cudagraph capture at fixed sizes while handling variable batch sizes. + + The key insight is that position embeddings depend on grid dimensions + and must be computed OUTSIDE the compiled graph. By pre-computing them + and padding, we can achieve cudagraph hits for the compiled regions. + + Args: + model: The multimodal model + mm_kwargs_group: Batched multimodal kwargs + modality: The modality type ("image" or "video") + + Returns: + List of encoder outputs if padding was applied, None otherwise + """ + if self.encoder_cudagraph_verbose: + logger.info( + "ViT PIECEWISE: _execute_encoder_piecewise_padded called, modality=%s", + modality, + ) + + # Only support image/video modalities + if modality not in ("image", "video"): + self._record_piecewise_fallback(f"unsupported_modality:{modality}") + return None + + # Extract grid_thw and pixel_values + grid_thw = mm_kwargs_group.get("image_grid_thw") + pixel_key = "pixel_values" + if grid_thw is None: + grid_thw = mm_kwargs_group.get("video_grid_thw") + pixel_key = "pixel_values_videos" + if grid_thw is None: + self._record_piecewise_fallback("no_grid_thw") + return None + + pixel_values = mm_kwargs_group.get(pixel_key) + if pixel_values is None: + self._record_piecewise_fallback("no_pixel_values") + return None + + # Convert to list if tensor + if hasattr(grid_thw, "tolist"): + grid_thw_list = grid_thw.tolist() + else: + grid_thw_list = list(grid_thw) + + # Get visual encoder and check for forward_piecewise support + visual = getattr(model, "visual", None) + if visual is None: + self._record_piecewise_fallback("no_visual_encoder") + return None + + # Check if forward_piecewise is available + if not hasattr(visual, "forward_piecewise"): + self._record_piecewise_fallback("no_forward_piecewise_method") + return None + + spatial_merge_size = getattr(visual, "spatial_merge_size", 2) + + # Calculate actual tokens + actual_num_patches = sum(t * h * w for t, h, w in grid_thw_list) + actual_output_tokens = actual_num_patches // (spatial_merge_size**2) + + # Find nearest capture size + capture_size = self._find_nearest_encoder_capture_size(actual_output_tokens) + if capture_size is None: + self._record_piecewise_fallback( + f"no_capture_size_for_{actual_output_tokens}_tokens" + ) + return None + + # Calculate padding needed + padding_output_tokens = capture_size - actual_output_tokens + padding_patches = padding_output_tokens * (spatial_merge_size**2) + padded_num_patches = capture_size * (spatial_merge_size**2) + + # Pre-compute embeddings for real images (OUTSIDE compiled graph) + # This is the key to making piecewise padding work + precomputed = visual.precompute_for_cudagraph(grid_thw_list) + pos_embeds = precomputed["pos_embeds"] + rotary_pos_emb_cos = precomputed["rotary_pos_emb_cos"] + rotary_pos_emb_sin = precomputed["rotary_pos_emb_sin"] + cu_seqlens = precomputed["cu_seqlens"] + max_seqlen = precomputed["max_seqlen"] + sequence_lengths = precomputed["sequence_lengths"] + + num_input_patches = pixel_values.shape[0] + + # Get or create pre-allocated buffers for this capture_size + # This avoids allocation and zeros kernels on every call + buffers = self._piecewise_buffers.get(capture_size) + if buffers is None: + # Lazily allocate buffers on first use for this capture_size + # Using torch.zeros ensures padding region is valid (not garbage) + # The zeros kernel only runs once per capture_size, not per 
call + buffers = { + "pixel_values": torch.zeros( + (padded_num_patches, pixel_values.shape[1]), + dtype=visual.dtype, + device=pixel_values.device, + ), + "pos_embeds": torch.zeros( + (padded_num_patches, pos_embeds.shape[1]), + dtype=pos_embeds.dtype, + device=pos_embeds.device, + ), + "rotary_cos": torch.zeros( + (padded_num_patches, rotary_pos_emb_cos.shape[1]), + dtype=rotary_pos_emb_cos.dtype, + device=rotary_pos_emb_cos.device, + ), + "rotary_sin": torch.zeros( + (padded_num_patches, rotary_pos_emb_sin.shape[1]), + dtype=rotary_pos_emb_sin.dtype, + device=rotary_pos_emb_sin.device, + ), + # Pre-allocate cu_seqlens with max possible entries + # (assuming max ~1000 images per batch is more than enough) + "cu_seqlens": torch.zeros( + (1001,), dtype=cu_seqlens.dtype, device=cu_seqlens.device + ), + "sequence_lengths": torch.zeros( + (1000,), + dtype=sequence_lengths.dtype, + device=sequence_lengths.device, + ), + } + self._piecewise_buffers[capture_size] = buffers + if self.encoder_cudagraph_verbose: + logger.info( + "ViT PIECEWISE: Allocated buffers for capture_size=%d (patches=%d)", + capture_size, + padded_num_patches, + ) + + # Copy data into pre-allocated buffers (no allocation, no zeros kernel) + padded_pixel_values = buffers["pixel_values"] + padded_pixel_values[:num_input_patches].copy_(pixel_values.type(visual.dtype)) + + padded_pos_embeds = buffers["pos_embeds"] + padded_pos_embeds[:num_input_patches].copy_(pos_embeds) + + padded_rotary_cos = buffers["rotary_cos"] + padded_rotary_cos[:num_input_patches].copy_(rotary_pos_emb_cos) + + padded_rotary_sin = buffers["rotary_sin"] + padded_rotary_sin[:num_input_patches].copy_(rotary_pos_emb_sin) + + # Update cu_seqlens to include padding as a separate sequence + num_seqs = cu_seqlens.shape[0] + padded_cu_seqlens = buffers["cu_seqlens"] + padded_cu_seqlens[:num_seqs].copy_(cu_seqlens) + if padding_patches > 0: + # Add padding sequence boundary + padded_cu_seqlens[num_seqs] = cu_seqlens[-1] + padding_patches + num_seqs += 1 + + num_seq_lens = sequence_lengths.shape[0] + padded_sequence_lengths = buffers["sequence_lengths"] + padded_sequence_lengths[:num_seq_lens].copy_(sequence_lengths) + if padding_patches > 0: + padded_sequence_lengths[num_seq_lens] = padding_patches + num_seq_lens += 1 + + # Slice to actual size needed + padded_cu_seqlens = padded_cu_seqlens[:num_seqs] + padded_sequence_lengths = padded_sequence_lengths[:num_seq_lens] + + # Update max_seqlen if padding sequence is larger + if padding_patches > max_seqlen.item(): + max_seqlen = torch.tensor( + padding_patches, dtype=max_seqlen.dtype, device=max_seqlen.device + ) + + # Call forward_piecewise directly with pre-computed and padded tensors + # Enable CUDA graph capture/replay by setting the proper forward context + batch_desc = BatchDescriptor(num_tokens=padded_num_patches) + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=batch_desc, + ): + encoder_output = visual.forward_piecewise( + x=padded_pixel_values, + pos_embeds=padded_pos_embeds, + rotary_pos_emb_cos=padded_rotary_cos, + rotary_pos_emb_sin=padded_rotary_sin, + cu_seqlens=padded_cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=padded_sequence_lengths, + ) + + # Split output by actual token counts for each image (exclude padding) + merge_size_sq = spatial_merge_size**2 + sizes = [t * h * w // merge_size_sq for t, h, w in grid_thw_list] + real_outputs = list(encoder_output[:actual_output_tokens].split(sizes)) + + # Record statistics 
+ self._record_piecewise_execution(actual_output_tokens, capture_size) + + if self.encoder_cudagraph_verbose: + waste_pct = padding_output_tokens / capture_size * 100 + stats = self._piecewise_stats + total_waste_pct = ( + (stats["total_padded_tokens"] - stats["total_actual_tokens"]) + / stats["total_padded_tokens"] + * 100 + if stats["total_padded_tokens"] > 0 + else 0 + ) + logger.info( + "ViT PIECEWISE PADDED: actual=%d, capture_size=%d, " + "padding=%d (%.1f%%), num_images=%d | " + "cumulative: executions=%d, total_actual=%d, total_padded=%d, " + "waste=%.1f%%", + actual_output_tokens, + capture_size, + padding_output_tokens, + waste_pct, + len(grid_thw_list), + stats["executions"], + stats["total_actual_tokens"], + stats["total_padded_tokens"], + total_waste_pct, + ) + + return real_outputs + + def warmup_encoder_piecewise(self) -> None: + """Warmup and capture encoder piecewise compilation. + + This mimics LM's two-phase approach: + 1. Warmup phase: Compile ranges with fake tensors (is_exact_size=False) + 2. Capture phase: Compile all exact capture_sizes upfront (is_exact_size=True) + + This ensures no compilation happens during execution. + """ + if not getattr(self.compilation_config, "encoder_cudagraph_piecewise", False): + return + + capture_sizes = getattr( + self.compilation_config, "encoder_cudagraph_capture_sizes", None + ) + if capture_sizes is None or len(capture_sizes) == 0: + return + + # Get visual encoder + model = self.model + visual = getattr(model, "visual", None) + if visual is None or not hasattr(visual, "forward_piecewise"): + return + + # Assert for mypy - visual is not None after the check above + assert visual is not None + + spatial_merge_size = getattr(visual, "spatial_merge_size", 2) + merge_size_sq = spatial_merge_size**2 + + # Convert capture_sizes to patches + capture_sizes_patches = sorted( + [size * merge_size_sq for size in capture_sizes], + reverse=True, # Largest first like LM + ) + + # Helper to create dummy inputs for a given num_patches + def create_dummy_inputs(num_patches: int): + assert visual is not None # for mypy + patch_embed = getattr(visual, "patch_embed", None) + if patch_embed is not None: + temporal_patch_size = getattr(patch_embed, "temporal_patch_size", 2) + patch_size = getattr(patch_embed, "patch_size", 14) + proj = getattr(patch_embed, "proj", None) + if proj is not None: + raw_in_channels = getattr(proj, "in_channels", 3) + else: + raw_in_channels = 3 + input_channels = ( + raw_in_channels * temporal_patch_size * patch_size * patch_size + ) + else: + input_channels = 3 * 2 * 14 * 14 + + pixel_values = torch.zeros( + (num_patches, input_channels), + dtype=visual.dtype, + device=self.device, + ) + + hidden_size = getattr( + visual, "hidden_size", getattr(visual, "embed_dim", 1152) + ) + + pos_embeds = torch.zeros( + (num_patches, hidden_size), + dtype=visual.dtype, + device=self.device, + ) + + rotary_dim = hidden_size // getattr(visual, "num_heads", 16) // 2 + rotary_cos = torch.zeros( + (num_patches, rotary_dim), + dtype=visual.dtype, + device=self.device, + ) + rotary_sin = torch.zeros( + (num_patches, rotary_dim), + dtype=visual.dtype, + device=self.device, + ) + + cu_seqlens = torch.tensor( + [0, num_patches], dtype=torch.int32, device=self.device + ) + max_seqlen = torch.tensor(num_patches, device=self.device) + sequence_lengths = torch.tensor( + [num_patches], dtype=torch.int32, device=self.device + ) + + return ( + pixel_values, + pos_embeds, + rotary_cos, + rotary_sin, + cu_seqlens, + max_seqlen, + sequence_lengths, + ) + 
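+        # Illustrative sizing: with spatial_merge_size=2 and
+        # encoder_cudagraph_capture_sizes=[1024, 4096] (output tokens),
+        # capture_sizes_patches is [16384, 4096] and the Phase 1 warmup
+        # below runs at (4096 + 1) * 4 = 16388 patches, deliberately not
+        # an exact capture size so range compilation is exercised.
+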
+ def run_forward(num_patches: int): + assert visual is not None # for mypy + ( + pixel_values, + pos_embeds, + rotary_cos, + rotary_sin, + cu_seqlens, + max_seqlen, + sequence_lengths, + ) = create_dummy_inputs(num_patches) + + with set_forward_context(None, self.vllm_config): + _ = visual.forward_piecewise( + x=pixel_values, + pos_embeds=pos_embeds, + rotary_pos_emb_cos=rotary_cos, + rotary_pos_emb_sin=rotary_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + + # ============================================================ + # Phase 1: Warmup compile_ranges with fake tensors + # ============================================================ + # Use a size that's NOT in capture_sizes to trigger range compilation + # with fake tensors (is_exact_size=False) + max_capture_size = max(capture_sizes) # In output tokens + warmup_size_tokens = max_capture_size + 1 + warmup_size = warmup_size_tokens * merge_size_sq + + # Make sure warmup_size is not an exact capture_size + capture_sizes_patches_set = set(capture_sizes_patches) + if warmup_size in capture_sizes_patches_set: + warmup_size_tokens = max_capture_size + 2 + warmup_size = warmup_size_tokens * merge_size_sq + + run_forward(warmup_size) + torch.cuda.empty_cache() + + def _capture_encoder_piecewise_cudagraphs(self) -> None: + """Capture encoder piecewise CUDA graphs for all capture sizes. + + Called during capture_model() when cudagraph capturing is enabled. + This triggers CUDAGraphWrapper to capture graphs for each size. + """ + capture_sizes = getattr( + self.compilation_config, "encoder_cudagraph_capture_sizes", None + ) + if capture_sizes is None or len(capture_sizes) == 0: + return + + model = self.model + visual = getattr(model, "visual", None) + if visual is None or not hasattr(visual, "forward_piecewise"): + return + + spatial_merge_size = getattr(visual, "spatial_merge_size", 2) + merge_size_sq = spatial_merge_size**2 + + # Convert capture_sizes to patches, largest first + capture_sizes_patches = sorted( + [size * merge_size_sq for size in capture_sizes], reverse=True + ) + + logger.info( + "Capturing encoder piecewise CUDA graphs for %d sizes", + len(capture_sizes_patches), + ) + + for num_patches in capture_sizes_patches: + # Create dummy inputs + patch_embed = getattr(visual, "patch_embed", None) + if patch_embed is not None: + temporal_patch_size = getattr(patch_embed, "temporal_patch_size", 2) + patch_size = getattr(patch_embed, "patch_size", 14) + proj = getattr(patch_embed, "proj", None) + raw_in_channels = getattr(proj, "in_channels", 3) if proj else 3 + input_channels = ( + raw_in_channels * temporal_patch_size * patch_size * patch_size + ) + else: + input_channels = 3 * 2 * 14 * 14 + + pixel_values = torch.zeros( + (num_patches, input_channels), + dtype=visual.dtype, + device=self.device, + ) + + hidden_size = getattr( + visual, "hidden_size", getattr(visual, "embed_dim", 1152) + ) + pos_embeds = torch.zeros( + (num_patches, hidden_size), + dtype=visual.dtype, + device=self.device, + ) + + rotary_dim = hidden_size // getattr(visual, "num_heads", 16) // 2 + rotary_cos = torch.zeros( + (num_patches, rotary_dim), + dtype=visual.dtype, + device=self.device, + ) + rotary_sin = torch.zeros( + (num_patches, rotary_dim), + dtype=visual.dtype, + device=self.device, + ) + + cu_seqlens = torch.tensor( + [0, num_patches], dtype=torch.int32, device=self.device + ) + max_seqlen = torch.tensor(num_patches, device=self.device) + sequence_lengths = torch.tensor( + [num_patches], 
dtype=torch.int32, device=self.device + ) + + # Two-pass capture like LM: + # Pass 1: NONE mode - triggers torch.compile without CUDA graph capture + with set_forward_context(None, self.vllm_config): + _ = visual.forward_piecewise( + x=pixel_values, + pos_embeds=pos_embeds, + rotary_pos_emb_cos=rotary_cos, + rotary_pos_emb_sin=rotary_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + + # Pass 2: PIECEWISE mode - triggers CUDAGraphWrapper capture + # (compilation already done in pass 1) + batch_desc = BatchDescriptor(num_tokens=num_patches) + with set_forward_context( + None, + self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=batch_desc, + ): + _ = visual.forward_piecewise( + x=pixel_values, + pos_embeds=pos_embeds, + rotary_pos_emb_cos=rotary_cos, + rotary_pos_emb_sin=rotary_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + + torch.cuda.empty_cache() + + logger.info("Encoder piecewise CUDA graph capture complete") + def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", @@ -2431,7 +3279,7 @@ def _gather_mm_embeddings( def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. - if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + if isinstance(self.model, CUDAGraphWrapper | UBatchWrapper): return self.model.unwrap() return self.model @@ -3994,7 +4842,7 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None layer_ids = hf_config.eagle_aux_hidden_state_layer_ids - if layer_ids and isinstance(layer_ids, (list, tuple)): + if layer_ids and isinstance(layer_ids, list | tuple): return tuple(layer_ids) return None @@ -4795,8 +5643,39 @@ def freeze_gc(): # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
set_cudagraph_capturing_enabled(True) + + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + start_total_memory = torch.cuda.mem_get_info()[1] + logger.info( + "Starting CUDA graph capture: %.2f GiB used, %.2f GiB free", + (start_total_memory - start_free_gpu_memory) / 1024**3, + start_free_gpu_memory / 1024**3, + ) + + # Capture encoder CUDA graphs first (if enabled) + # Encoder uses a dedicated graph pool separate from decoder, + # captured outside the decoder's graph_capture context for clean isolation + if self.encoder_cudagraph_manager is not None: + with freeze_gc(): + self._capture_encoder_cudagraphs() + after_encoder_free = torch.cuda.mem_get_info()[0] + encoder_mem = start_free_gpu_memory - after_encoder_free + logger.info( + "Encoder CUDA graphs captured: %.2f GiB used by encoder graphs, " + "%.2f GiB free", + encoder_mem / 1024**3, + after_encoder_free / 1024**3, + ) + + # Capture encoder piecewise CUDA graphs (if enabled) + if getattr(self.compilation_config, "encoder_cudagraph_piecewise", False): + with freeze_gc(): + self._capture_encoder_piecewise_cudagraphs() + + # Capture decoder/LM CUDA graphs in their own context with global pool with freeze_gc(), graph_capture(device=self.device): - start_free_gpu_memory = torch.cuda.mem_get_info()[0] + before_decoder_free = torch.cuda.mem_get_info()[0] + cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None @@ -4845,6 +5724,20 @@ def freeze_gc(): torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] + decoder_mem = before_decoder_free - end_free_gpu_memory + logger.info( + "Decoder CUDA graphs captured: %.2f GiB used by decoder graphs, " + "%.2f GiB free", + decoder_mem / 1024**3, + end_free_gpu_memory / 1024**3, + ) + + total_cudagraph_mem = start_free_gpu_memory - end_free_gpu_memory + logger.info( + "CUDA graph capture complete: total %.2f GiB for all graphs, %.2f GiB free", + total_cudagraph_mem / 1024**3, + end_free_gpu_memory / 1024**3, + ) # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. @@ -4869,6 +5762,24 @@ def freeze_gc(): ) return cuda_graph_size + def _capture_encoder_cudagraphs(self) -> None: + """Capture CUDA graphs for the vision encoder.""" + if self.encoder_cudagraph_manager is None: + return + + model = self.model + if not hasattr(model, "visual") or model.visual is None: + logger.warning( + "Model does not have a visual encoder, " + "skipping encoder CUDA graph capture" + ) + return + + self.encoder_cudagraph_manager.capture( + vision_encoder=model.visual, + embed_multimodal_fn=model.embed_multimodal, + ) + def _capture_cudagraphs( self, compilation_cases: list[tuple[int, bool]], diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 013780479743..a15c547a7053 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -446,6 +446,12 @@ def compile_or_warm_up_model(self) -> None: # cuda graph capture. kernel_warmup(self) + # Warmup encoder piecewise cudagraph if enabled + # This pre-captures all encoder capture sizes to avoid compilation + # latency during actual execution + if hasattr(self.model_runner, "warmup_encoder_piecewise"): + self.model_runner.warmup_encoder_piecewise() + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model()
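
A minimal, self-contained sketch of the budget-packing scheme implemented
above: images are sorted by output-token count, greedily packed under the
largest captured budget and max_images_per_batch, and each packed batch
replays the smallest budget graph that fits. All names and sizes below are
illustrative; this is a sketch of the strategy, not part of the patch.

from dataclasses import dataclass


@dataclass
class ImageInfo:
    out_tokens: int  # output tokens after spatial merge
    orig_idx: int    # position in the original batch


def pack_budget_batches(
    out_tokens_per_image: list[int],
    budgets: list[int],
    max_images_per_batch: int,
) -> list[tuple[int, list[int]]]:
    """Greedily pack images (smallest first) under the largest budget, then
    pick the smallest captured budget that fits each packed batch.

    Returns (selected_budget, original_indices) pairs; raises if a batch
    exceeds every budget (the real code falls back to eager instead).
    """
    max_budget = max(budgets)
    images = sorted(
        (ImageInfo(tokens, i) for i, tokens in enumerate(out_tokens_per_image)),
        key=lambda im: im.out_tokens,
    )

    batches: list[list[ImageInfo]] = []
    current: list[ImageInfo] = []
    current_tokens = 0
    for im in images:
        if (
            current_tokens + im.out_tokens <= max_budget
            and len(current) < max_images_per_batch
        ):
            current.append(im)
            current_tokens += im.out_tokens
        else:
            if current:
                batches.append(current)
            current, current_tokens = [im], im.out_tokens
    if current:
        batches.append(current)

    packed: list[tuple[int, list[int]]] = []
    for batch in batches:
        total = sum(im.out_tokens for im in batch)
        fitting = [b for b in budgets if b >= total]
        if not fitting:
            raise ValueError(f"no budget graph fits {total} output tokens")
        packed.append((min(fitting), [im.orig_idx for im in batch]))
    return packed


if __name__ == "__main__":
    # With budgets {4096, 8192} and at most 2 images per batch, the four
    # images below pack into two replays of the 4096-token graph:
    # [(4096, [2, 1]), (4096, [0, 3])]
    print(pack_budget_batches([1200, 900, 700, 2500], [4096, 8192], 2))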