
[WebGPU] Add compact attention mask support#2022

Open
qjia7 wants to merge 7 commits into main from compact_attention_mask_v2

Conversation

@qjia7
Contributor

@qjia7 qjia7 commented Mar 13, 2026

Problem

The standard attention mask [batch_size, total_sequence_length] has the following disadvantages for WebGPU graph capture:

  1. Large pre-allocation — Graph capture requires all tensor shapes to be constant across runs. So the mask must be pre-allocated as [batch_size, max_length] (e.g., 4096) even when the actual sequence is short (e.g., 100 tokens). This wastes GPU memory.

  2. Full buffer copy every step — ORT WebGPU's CopyTensors supports creating a smaller tensor view on an existing buffer, but only starting from offset 0 — you can't copy to an arbitrary offset within the buffer. For the attention mask, each decode step appends a 1 at the end (position = total_seq_len - 1). Since we can't copy starting from that offset, we must copy all previous values from the start — i.e., copy total_seq_len elements every step instead of just the 1 new element. As the sequence grows, this copy grows linearly (100 elements at step 100, 1000 at step 1000, etc.).

  3. Additionally, the ONNX graph contains ReduceSum/ReduceMax subgraphs to derive seqlens_k/total_seq_len from the full binary mask, even though GenAI already knows these values on the CPU.

Solution

Compact attention mask [batch_size, 1] — a single scalar per batch containing the total sequence length. This eliminates the large buffer, replaces the full-buffer copy with a 4-byte scalar copy, and simplifies the ONNX graph.
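The per-step copy cost difference can be sketched with a toy simulation (hypothetical helper names, plain NumPy arrays standing in for GPU buffers — not the actual ORT API):

```python
import numpy as np

MAX_LENGTH = 4096  # static allocation required by graph capture

def standard_mask_step(mask: np.ndarray, total_seq_len: int) -> int:
    """Append a 1 at position total_seq_len - 1. Without offset support in
    CopyTensors, the whole prefix [0, total_seq_len) must be re-copied.
    Returns the number of elements copied this step."""
    mask[:total_seq_len] = 1
    return total_seq_len

def compact_mask_step(mask: np.ndarray, total_seq_len: int) -> int:
    """Compact mode: overwrite one scalar with the new length (a 4-byte copy)."""
    mask[0] = total_seq_len
    return 1

standard = np.zeros(MAX_LENGTH, dtype=np.int32)
compact = np.zeros(1, dtype=np.int32)

std_copied = sum(standard_mask_step(standard, t) for t in range(1, 1001))
cmp_copied = sum(compact_mask_step(compact, t) for t in range(1, 1001))
print(std_copied, cmp_copied)  # cumulative copy volume: quadratic vs linear in steps
```

Over 1000 decode steps the standard mask copies the sum 1 + 2 + … + 1000 elements, while the compact mask copies exactly one scalar per step.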

Result

Graph capture works with continuous decoding. Simpler ONNX graph. No large static mask allocation needed.

@qjia7 qjia7 marked this pull request as ready for review March 13, 2026 10:18
Copilot AI review requested due to automatic review settings March 13, 2026 10:18
Contributor

Copilot AI left a comment


Pull request overview

Adds an optional “compact attention mask” mode that represents attention mask inputs as per-batch sequence lengths ([batch_size, 1]) rather than a full binary mask ([batch_size, total_sequence_length]), targeting simpler graphs and improved WebGPU graph-capture behavior.

Changes:

  • Introduces compact_attention_mask configuration (Python model builder emits it; C++ config parses/uses it).
  • Extends DeviceInterface with UpdateCompactAttentionMask and implements it for CPU and WebGPU.
  • Updates runtime position/mask input handling to create/update the compact mask.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

  • src/webgpu/interface.cpp — Adds WebGPU implementation to upload compact mask values via ORT CopyTensors.
  • src/smartptrs.h — Adds DeviceInterface::UpdateCompactAttentionMask virtual hook.
  • src/python/py/models/builders/base.py — Switches attention_mask input shape in compact mode; adds compact GQA mask reformatting subgraph; writes config flag.
  • src/python/py/models/builder.py — Adds CLI/extra option plumbing and help text for compact_attention_mask.
  • src/models/position_inputs.h — Declares compact mask creation/update helpers.
  • src/models/position_inputs.cpp — Implements compact mask creation and update paths in DefaultPositionInputs.
  • src/cpu/interface.cpp — Implements CPU UpdateCompactAttentionMask (fills lengths).
  • src/config.h — Adds decoder.compact_attention_mask flag to config struct.
  • src/config.cpp — Parses compact_attention_mask from genai_config.json.


@qjia7 qjia7 force-pushed the compact_attention_mask_v2 branch from d0b7db2 to ac3fbfe Compare March 24, 2026 06:45
@qjia7 qjia7 requested a review from Copilot March 24, 2026 07:02
Contributor

Copilot AI left a comment


Pull request overview

This PR adds end-to-end support for a “compact attention mask” mode, where the decoder attention mask is represented as [batch_size, 1] (sequence lengths) instead of a full [batch_size, total_sequence_length] binary mask, improving graph simplicity and reducing work—especially for WebGPU graph capture and continuous decoding.

Changes:

  • Add a new model.decoder.compact_attention_mask config flag and wire it through config parsing and Python model building.
  • Extend the device interface with UpdateCompactAttentionMask, implementing it for CPU and WebGPU.
  • Update runtime position/mask input handling to initialize/update/rewind compact masks (including a WebGPU-friendly fast path for batch_size == 1).

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

  • src/webgpu/interface.cpp — Implements WebGPU device-side update path for compact attention mask via a single CPU→GPU copy.
  • src/cpu/interface.cpp — Implements CPU write logic for compact attention mask values.
  • src/smartptrs.h — Extends DeviceInterface with UpdateCompactAttentionMask.
  • src/models/position_inputs.h — Declares compact mask initialization/update helpers.
  • src/models/position_inputs.cpp — Adds compact mask initialization, per-step update logic, and rewind handling (incl. WebGPU fast path for batch=1).
  • src/python/py/models/builders/base.py — Adds GQA subgraph generation for compact attention mask mode and overrides attention_mask input shape.
  • src/python/py/models/builder.py — Exposes compact_attention_mask as a validated boolean extra option and documents it in CLI help.
  • src/config.h — Adds compact_attention_mask to the decoder config schema.
  • src/config.cpp — Parses compact_attention_mask from genai_config.json.

@qjia7 qjia7 force-pushed the compact_attention_mask_v2 branch from ce68bb1 to 4d1d66b Compare March 27, 2026 02:06
@qjia7 qjia7 force-pushed the compact_attention_mask_v2 branch from 4d1d66b to fd35276 Compare March 27, 2026 02:20
@qjia7
Contributor Author

qjia7 commented Mar 27, 2026

@kunal-vaishnavi @baijumeswani Please take a look, thanks.

mask_data[0] = static_cast<T>(prompt_length);
} else {
// Count non-pad tokens per batch to get the effective sequence length
const auto* word_id = const_cast<DeviceSpan<int32_t>&>(next_tokens).CpuSpan().data();
Contributor


This feels like it is doing the same thing as the ONNX model already performs within itself: count the number of tokens to get the sequence length. Why not let the ONNX model do this work with a boolean mask of size [batch_size, total_sequence_length]? If the issue is about dtypes, we can always convert the dtype internally while keeping the existing input/output (I/O) structure.

In other words, what is the benefit of doing this? We intentionally picked a boolean mask of shape [batch_size, total_sequence_length] to standardize the I/O structure across models, to support a wide range of attention ops across EPs, and to allow the produced ONNX models to match and run with other solutions such as Optimum and Transformers.js for direct comparisons.

Does graph capture not work with the existing attention mask? If it does work, then why do we have the following subgraph?

def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basename):
    # Make nodes for the attention mask subgraph that calculates
    # attributes about the 2D attention mask to use in GroupQueryAttention
    #
    # Key difference vs make_attention_mask_standard_reformatting_for_gqa:
    # - Standard mode: total_seq_len is calculated from Shape op (always runs on CPU)
    # - Graph capture mode: No Shape ops inserted to ensure all ops run on GPU (no CPU ops)
    #
    #       attention_mask
    #             |
    #        Cast to int32
    #             |
    #     ReduceSum (keepdims=0)
    #        /          \
    #      Sub       ReduceMax
    #       |            |
    #   seqlens_k   total_seq_len
    #     (1D)         (int)

Contributor Author


Why not let the ONNX model do this work with a boolean mask of size [batch_size, total_sequence_length]?

In the current GenAI logic, we already know total_sequence_length on the CPU, yet we still pass a boolean mask of size [batch_size, total_sequence_length] to the model and let the model derive total_sequence_length again. Since the purpose of the attention mask is essentially to let the model recover the total sequence length, why can't we pass total_sequence_length to the model directly? Then the model doesn't need to recompute total_sequence_length/seqlens_k from a boolean mask. The benefits are:

  1. We don't need to allocate a big [batch_size, max_length] static tensor to support graph capture, and the model itself can also be simplified.
  2. The interface's UpdateCompactAttentionMask method can be implemented easily with the CopyTensors API. For the boolean attention mask, by contrast, I would need CopyTensors with offset support to update a partial region of the mask. (We haven't reached agreement on how to support CopyTensors with offset in ORT.)

That's why I want to support the compact attention mask, which cleanly resolves both issues above.

Ideally, the compact attention mask would only be used when batch_size = 1, but we can't forbid users from using it when batch_size > 1. That's why I also added the else branch for batch_size > 1. What do you think?
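A rough Python rendering of the CPU-side update logic under discussion, with the batch_size == 1 fast path and the else branch that counts non-pad tokens (PAD_TOKEN_ID and all names are illustrative, not the actual C++ implementation):

```python
import numpy as np

PAD_TOKEN_ID = 0  # assumed pad id, for illustration only

def update_compact_attention_mask(mask: np.ndarray,
                                  next_tokens: np.ndarray,
                                  prompt_length: int) -> None:
    """Fill the [batch_size, 1] compact mask with per-batch sequence lengths."""
    batch_size = mask.shape[0]
    if batch_size == 1:
        # Fast path: the single sequence's length is known directly.
        mask[0, 0] = prompt_length
    else:
        # Count non-pad tokens per batch to get each effective length.
        for b in range(batch_size):
            mask[b, 0] = int(np.count_nonzero(next_tokens[b] != PAD_TOKEN_ID))

mask = np.zeros((2, 1), dtype=np.int32)
tokens = np.array([[5, 7, 9, 0],
                   [5, 7, 0, 0]], dtype=np.int32)
update_compact_attention_mask(mask, tokens, prompt_length=4)
print(mask.ravel().tolist())  # [3, 2]
```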


std::vector<PipelineModel> pipeline;

bool compact_attention_mask{false}; // When true, attention_mask has shape [batch_size, 1] with the total sequence length value
Contributor


Instead of having an attribute and then adding a boolean here, couldn't we check the shape of the attention mask tensor to know whether it is compact or not? I don't think a user should need to specify this; it should be inferable from the ONNX model.

Contributor Author


The config flag is needed because a [batch_size, 1] attention mask shape is ambiguous at runtime — it could be a compact mask (fixed shape, value = total_seq_len) or a standard binary mask where total_sequence_length happens to be 1.
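For illustration, the flag would sit under the decoder section of genai_config.json roughly like this (surrounding keys abbreviated; the exact structure is an assumption based on the decoder.compact_attention_mask path mentioned above):

```json
{
  "model": {
    "decoder": {
      "compact_attention_mask": true
    }
  }
}
```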

@qjia7 qjia7 requested a review from kunal-vaishnavi April 2, 2026 09:53
qjia7 added 2 commits April 3, 2026 10:14
When enable_webgpu_graph=1 is passed to the builder, automatically
enable compact_attention_mask unless the user explicitly set it.
This ensures graph capture models use the optimized compact mask
by default, while still allowing opt-out via compact_attention_mask=0.
@qjia7 qjia7 force-pushed the compact_attention_mask_v2 branch from 50289b6 to c9162ba Compare April 3, 2026 02:17