
Per weight constant cache #18901 (Draft)

Gasoonjia wants to merge 4 commits into main from per-weight-constant-cache

Conversation

@Gasoonjia
Contributor

Problem: Multi-method AOTI models (e.g., Qwen3.5 MoE with separate prefill/decode methods) load the full weight blob independently for each method, even when they share identical weights. This causes duplicate GPU allocations -- Qwen3.5 MoE peaked at ~35 GB, making it impossible to run on a single 24 GB GPU (e.g., 4090).

Solution: Introduce a per-weight FQN-keyed constant cache in CudaBackend. The first method loads its constants from the blob and caches them. Subsequent methods with matching FQNs skip blob loading entirely and reuse cached GPU tensors via update_user_managed_constant_buffer_pairs. A legacy fallback path is preserved for older AOTI models without constant management APIs.

Results
Peak GPU memory: 35.4 GB → 17.6 GB (-50%)

gasoonjia and others added 3 commits April 14, 2026 18:21
Replace the old update_constants_from_blob + cross-method sharing with a
unified per-weight caching approach. The first method to initialize loads
its constants from the blob and caches them by FQN. Subsequent methods
with matching FQNs reuse cached GPU tensors via
update_user_managed_constant_buffer_pairs, skipping blob loading entirely.

This eliminates duplicate GPU weight allocations for multi-method models
(e.g., prefill/decode), reducing peak GPU memory from ~35 GB to ~17.6 GB
for Qwen 3.5 MoE.

Also adds GPU peak memory reporting to the Qwen3.5 MoE runner and a
CI check (< 20 GB) in test_model_e2e.sh.
@pytorch-bot

pytorch-bot bot commented Apr 15, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18901

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 15 New Failures, 3 Unrelated Failures

As of commit 9aeea5f with merge base 87e65ac:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 15, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@digantdesai
Contributor

@claude review this PR.

@claude

claude bot commented Apr 15, 2026

Claude finished @digantdesai's task in 1m 47s.


Review of PR #18901 — Per-weight constant cache

  • Read changed files and PR diff
  • Review backends/cuda/runtime/cuda_backend.cpp
  • Review backends/aoti/aoti_delegate_handle.h
  • Review examples/models/qwen3_5_moe/main.cpp
  • Review .ci/scripts/test_model_e2e.sh
  • Post review summary

Overall: Good improvement. Replacing the share_kv_cache_across_methods compile-spec opt-in with automatic per-weight caching is the right direction — it's simpler for callers and handles the general case (shared weights, not just KV cache). The legacy fallback is a reasonable safety net. A few issues to address:


Issues

1. Dangling pointers in update_user_managed_constant_buffer_pairs call (cuda_backend.cpp:835-843)

The pairs vector stores AOTInductorConstantMapEntry entries whose const char* fields point into the key/value std::string objects of the local fqn_to_name map, which at first glance looks like a lifetime hazard, especially since the lock on shared_constants_mutex_ is released at line 843 before pairs is used at line 847. On closer inspection it is safe: the structured binding internal_name is a const reference into fqn_to_name (not a copy), and fqn_to_name outlives pairs in this scope, so the c_str() pointers remain valid; the AtenTensorHandle read from shared_constant_tensors_ inside the lock is an opaque pointer copied by value into pairs. No lifetime issue here — disregard.

2. allocate_constant_on_gpu is dead code (cuda_backend.cpp:669-698)

The allocate_constant_on_gpu static method is defined but never called anywhere in this PR. The three new function pointer types (get_constant_data_size, get_constant_from_folded, get_constants_blob_size) are loaded but also never used. If these are intended for a future per-weight granular loading optimization (loading individual constants instead of the full blob), that's understandable, but as-is they are unused code.

Suggested fix: remove the dead code.

3. First method still loads the full blob even when some constants are already cached (cuda_backend.cpp:765-793)

When uncached_fqns is non-empty, the code calls update_constants_from_blob which loads all constants from the blob into GPU memory, not just the uncached ones. It then extracts and caches the new ones. This means the first method still does a full blob load, and if a second method has partially overlapping constants, it will also do a full blob load for its uncached subset. The "full blob load + extract" approach works for the common case (100% overlap between prefill/decode), but for partial overlap it wastes GPU memory transiently.

This is likely acceptable for now since the primary use case is full overlap, but worth a comment noting this limitation.

4. GPU memory measurement is not peak memory (examples/models/qwen3_5_moe/main.cpp:253-260)

cudaMemGetInfo at the end of execution reports current GPU memory usage, not peak. If memory was allocated and freed during execution (e.g., the old blob approach), the peak would have been higher but would not be captured. The printf says "peak memory usage" which is misleading. Consider using cudaMemPoolGetAttribute with cudaMemPoolAttrUsedMemHigh or tracking allocations over time. For this PR's purpose (verifying the cache avoids duplicates) it works as a sanity check since the weights stay allocated, but the label is inaccurate.

5. Missing blank line between function and member variable (cuda_backend.cpp:900-901)

    return Error::Ok;
  }
  mutable std::mutex shared_constants_mutex_;

There should be a blank line between load_constants_legacy and the member variable declaration for readability (matching the style used elsewhere in this class).

6. Thread safety: get_constant_name / get_constant_original_fqn called under lock (cuda_backend.cpp:739-752)

The calls to handle->get_constant_name() and handle->get_constant_original_fqn() happen while holding shared_constants_mutex_. These are AOTI runtime calls that shouldn't be slow, but holding a mutex during external library calls is a pattern to be cautious about. Since the lock is needed to check shared_constant_tensors_, consider splitting: enumerate constants first (no lock needed), then take the lock only for the cache lookup.

7. CI test uses -oP (PCRE) in grep (.ci/scripts/test_model_e2e.sh:406)

PEAK_MEM=$(echo "$OUTPUT" | grep -oP 'GPU peak memory usage: \K[0-9.]+' || true)

grep -P requires PCRE support which may not be available in all CI environments (e.g., macOS grep doesn't support -P). Since this CI likely runs on Linux this is probably fine, but sed or grep -oE with a capture group would be more portable.


Minor / Nits

  • The comment at line 660 mentions allocate_constant_on_gpu() as "the allocation primitive" — if the method is removed per issue 2 above, update this comment block.
  • The comment block at lines 651-664 is thorough and well-written.
  • constants_extracted_ member removal is clean — the cache naturally handles first-vs-subsequent without a separate flag.

Looks Good

  • The legacy fallback (load_constants_legacy) correctly preserves the old behavior for models compiled with older torch versions.
  • Removing the share_kv_cache_across_methods compile spec simplifies the API surface — users no longer need to opt in.
  • The AOTIDelegateHandle additions follow the existing pattern.
  • Error handling in load_constants_with_cache properly returns errors up the call chain via ET_CHECK_OK_OR_RETURN_ERROR.

@Gasoonjia Gasoonjia temporarily deployed to upload-benchmark-results April 15, 2026 03:47 — with GitHub Actions Inactive
