
[tx] Per-layer gradient checkpointing with stacked decoder layers#996

Open
raulchen wants to merge 128 commits into NovaSky-AI:main from raulchen:stack-weights

Conversation

raulchen (Contributor) commented Jan 30, 2026

Summary

Implement per-layer gradient checkpointing using jax.lax.scan with permanently stacked decoder layer weights via
nnx.vmap. This reduces peak activation memory during training by a factor of roughly num_layers while maintaining a
unified code path for training and inference.

Key Changes

1. Per-layer Gradient Checkpointing

  • Use jax.lax.scan with jax.checkpoint to recompute activations during the backward pass (see the sketch below)
  • XLA compiles ONE loop body and reuses buffers, unlike Python loops which unroll N separate checkpoint regions
  • Enable via gradient_checkpointing=True in model config
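
A minimal, self-contained sketch of the pattern (pure JAX, with a toy stand-in for one decoder layer; the parameter layout and names are illustrative, not the PR's actual API):

import jax
import jax.numpy as jnp

def forward_layers_checkpointed(stacked_params, hidden_states):
    """Run all decoder layers with per-layer rematerialization.

    stacked_params: pytree whose leaves carry a leading (num_layers, ...) axis
    hidden_states:  (batch, seq, hidden) activations entering the first layer
    """
    @jax.checkpoint  # recompute this body's activations during the backward pass
    def body(carry, layer_params):
        # toy stand-in for one decoder layer (attention + MLP in the real model)
        new_carry = carry + jnp.tanh(carry @ layer_params["w"])
        return new_carry, None

    # jax.lax.scan compiles ONE loop body, so XLA can reuse its buffers across
    # layers instead of unrolling num_layers separate checkpoint regions
    final_hidden, _ = jax.lax.scan(body, hidden_states, stacked_params)
    return final_hidden

# toy usage
num_layers, batch, seq, hidden = 4, 2, 8, 16
params = {"w": jnp.zeros((num_layers, hidden, hidden))}
x = jnp.ones((batch, seq, hidden))
grads = jax.grad(lambda p: forward_layers_checkpointed(p, x).sum())(params)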

2. Stacked Layer Weights

  • Stack decoder layer weights at initialization using nnx.vmap → shape (num_layers, ...) (see the sketch below)
  • Eliminates runtime stacking overhead (weights already in stacked format)
  • Single forward_layers() function for both training and inference
  • KV cache uses stacked format (num_layers, batch, seq, heads, dim)
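
As a rough illustration of the stacked layout and the single forward path (a sketch following the Flax NNX scan-over-layers pattern with a toy block; the real create_stacked_layers()/forward_layers() in tx/models/utils.py will differ):

import jax.numpy as jnp
from flax import nnx

num_layers, dmodel = 4, 16

class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dmodel, dmodel, rngs=rngs)

    def __call__(self, x):
        return x + jnp.tanh(self.linear(x))

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_stacked_layers(rngs: nnx.Rngs):
    # every parameter comes out with a leading (num_layers, ...) axis,
    # e.g. the Linear kernel has shape (num_layers, dmodel, dmodel)
    return Block(rngs)

layers = create_stacked_layers(nnx.Rngs(0))

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward_layers(x, layer: Block):
    # one code path usable for both training and inference
    return layer(x)

y = forward_layers(jnp.ones((2, dmodel)), layers)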

3. DeepSeekV3 Split Stacking

  • Handles heterogeneous layers (dense MLP vs MoE) with separate stacks
  • dense_layers for initial layers, moe_layers for MoE layers
  • KV caches merged after the forward pass (see the sketch below)
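
A rough sketch of how the split stacks could be composed (forward_layers stands for the unified helper described above; all names here are illustrative, not the PR's actual API):

import jax
import jax.numpy as jnp

def forward_split_stacks(dense_stack, moe_stack, hidden, forward_layers):
    # run the homogeneous dense-MLP stack first, then the MoE stack
    hidden, dense_kv = forward_layers(dense_stack, hidden)
    hidden, moe_kv = forward_layers(moe_stack, hidden)
    # merge the per-stack KV caches back into one (num_layers, ...) cache
    merged_kv = jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=0), dense_kv, moe_kv)
    return hidden, merged_kv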

Files Changed

  • tx/models/utils.py - New: create_stacked_layers(), forward_layers()
  • tx/models/{llama3,qwen3,deepseekv3}.py - Use stacked layers
  • tx/layers/lora.py - Stacked LoRA indexing
  • tx/utils/models.py - Stack/unstack for HF checkpoint compatibility
  • tx/utils/generator.py - Stacked KV cache
  • tx/tinker/backends/jax.py - Fix gradient accumulation for stacked params

Test plan

  • Forward outputs match with/without checkpointing
  • Gradients match with/without checkpointing
  • All model tests pass (37)
  • All tinker tests pass (19)
  • DeepSeekV3 EP=2 tests pass

raulchen and others added 30 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
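
A minimal sketch of the chunked computation this commit describes (illustrative only; the helper name, argument layout, and padding handling are assumptions, not the PR's API):

import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, targets, chunk_size=1024):
    """Per-token logprobs without materializing the full [B*T, V] logits.

    hidden:    (num_tokens, hidden_dim) flattened activations
    lm_head_w: (hidden_dim, vocab_size) unembedding weight
    targets:   (num_tokens,) token ids whose logprob we want
    """
    def one_chunk(args):
        h_chunk, t_chunk = args
        logits = h_chunk @ lm_head_w                          # (chunk_size, V) only
        log_z = jax.nn.logsumexp(logits, axis=-1)
        tok = jnp.take_along_axis(logits, t_chunk[:, None], axis=-1)[:, 0]
        return tok - log_z

    n = hidden.shape[0]
    pad = (-n) % chunk_size
    hidden = jnp.pad(hidden, ((0, pad), (0, 0)))
    targets = jnp.pad(targets, (0, pad))
    hidden = hidden.reshape(-1, chunk_size, hidden.shape[-1])
    targets = targets.reshape(-1, chunk_size)
    # jax.lax.map runs the chunks sequentially, so roughly one chunk of logits
    # is materialized at a time instead of the full [B*T, V] tensor
    logprobs = jax.lax.map(one_chunk, (hidden, targets))
    return logprobs.reshape(-1)[:n]
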
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's
activations are held at a time during backward pass, reducing peak
memory by ~num_layers factor.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles
ONE loop body and reuses buffers during backward recomputation. With a
Python loop, XLA unrolls N separate checkpoint regions and can't optimize
buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing,
activations are stored anyway, so fori_loop's buffer reuse doesn't help
and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
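
A sketch of the runtime path selection this commit describes (chunked_fn and full_lora_fn are placeholder callables; both must return arrays of the same shape and dtype, as jax.lax.cond requires):

import jax
import jax.numpy as jnp

def compute_logprobs(hidden, lm_head_w, targets, adapter_ids,
                     train_unembed_mask, chunked_fn, full_lora_fn):
    # train_unembed_mask: (max_adapters,) bool, True where the adapter trains
    # lm_head and therefore needs LoRA applied to the unembedding
    needs_lora_on_lm_head = jnp.any(train_unembed_mask[adapter_ids])
    return jax.lax.cond(
        needs_lora_on_lm_head,
        lambda: full_lora_fn(hidden, lm_head_w, targets, adapter_ids),
        lambda: chunked_fn(hidden, lm_head_w, targets),
    )
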
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Apply softmax to all router logits before top-k selection, not after.
This matches HF's implementation and fixes ~1.3x output scaling error.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
return None


def _adapter_index(is_stacked: bool, adapter_index: int):
pcmoritz (Collaborator) commented Jan 30, 2026

I think we should change this function to

def get_adapter_idx(path: str, adapter_index: int):
    is_stacked = _is_stacked(path)
    return (slice(None), adapter_index) if is_stacked else (adapter_index,)

and then use it everywhere in the PR. I believe all the call sites of _adapter_index are already of this form, and there are many more call sites where we can get rid of a pattern like

if is_stacked:
     # Process stacked weights
else:
     # Process unstacked weights

This can always be done like

idx = get_adapter_idx(path, adapter_index)
# Process weights with weights[idx,...]

raulchen (Author) replied:

Good idea. The code could be simpler.

raise ValueError("The 'learning_rate' key must be provided in optimizer_args.")


def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple:
Collaborator

Do we actually need this function? Can't we just do things like

idx = get_adapter_idx(path, adapter_index)
p.at[idx, ..., :, :rank]
p.at[idx, ..., :rank, :]

- Refactor lora_test_utils.py to reduce duplication with _slice_out_of_rank helper
- Simplify DeepseekV3 decoder layers by passing mlp_cls instead of subclassing
- Add KVCache.split() and concatenate() methods for layer-wise cache operations

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
if "layers" not in path_strs:
    return False
layers_idx = path_strs.index("layers")
if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit():
Collaborator

Do you know why this case is needed? If it isn't, we only need one method here and can always use the logic of is_stacked_lora_path.

My rationale for this question is, if we always stack the layers that are stackable, a case like ('model', 'layers', '0', 'self_attn', ...) should never happen, right?

raulchen (Author) replied:

This was added in the intermediate state when both stacked and non-stacked formats were supported.
It should no longer be needed.

Introduces get_adapter_idx(path, adapter_index) that encapsulates
the stacked vs non-stacked adapter indexing logic. Removes duplicate
if/else patterns across the codebase.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
pcmoritz added the tx label Jan 30, 2026
raulchen and others added 2 commits January 30, 2026 16:00
Use is_stacked_lora_path directly since we always stack layers.
The digit check for non-stacked format is no longer needed.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
raulchen and others added 2 commits January 30, 2026 16:15
Make split() and concatenate() handle None for empty layer groups.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
raulchen and others added 3 commits January 30, 2026 17:32
When a single layer is extracted from stacked layers via x[layer_idx],
the tensor loses a dimension but the PartitionSpec metadata still has
the extra leading None (from vmap transform_metadata). Truncate the spec
from the beginning to match the actual tensor rank.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
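
A minimal illustration of the spec truncation this commit describes (assuming the spec is a plain jax.sharding.PartitionSpec whose stale leading entries came from the vmap transform_metadata):

from jax.sharding import PartitionSpec as P

def truncate_leading(spec, target_rank):
    # drop leading entries (e.g. the stacked layer axis) so the spec matches
    # the rank of the per-layer tensor sliced out via x[layer_idx]
    parts = tuple(spec)
    return P(*parts[len(parts) - target_rank:])

# e.g. a stacked weight annotated P(None, None, "tp") becomes P(None, "tp")
# once the leading layer axis is gone
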
pcmoritz added a commit that referenced this pull request Jan 31, 2026
This is in preparation for merging
#996, so we don't need to depend
on the jax tracer. It is also slightly cleaner this way and the assert
is not needed any more, since the error is "defined away".

It also adds the FSDP sharding for llama3.
pcmoritz and others added 3 commits January 30, 2026 21:23
Replace nnx.vmap with individual layer creation + jnp.stack.
vmap breaks eager sharding, causing ~4x memory overhead due to
full model replication instead of tensor-parallel sharding.

The new approach:
- Creates layers individually with a Python loop (respects eager sharding)
- Stacks parameters using jit with donate_argnums to reduce peak memory
- Preserves correct sharding specs on stacked arrays

Memory improvement (per GPU, Qwen3-4B with tp=8):
- nnx.List baseline: 1461 MiB
- Old vmap approach: 4533 MiB
- New loop+stack:    2485 MiB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
raulchen and others added 4 commits February 4, 2026 13:50
Instead of creating all layers then stacking (which requires holding
both original arrays and stacked arrays simultaneously), pre-allocate
the stacked arrays and copy each layer's params directly using
dynamic_update_slice. This keeps only one layer in memory at a time.

Memory improvement during layer creation:
- Old: JAX peak ~2098 MiB (originals + stacked arrays)
- New: JAX peak ~1316 MiB (stacked arrays + 1 layer)

Also adds memory logging via nvidia-smi and JAX memory_stats for
debugging memory usage throughout the layer creation process.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
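
A small sketch of the copy-into-preallocated-stack step (placeholder names; .at[...].set() is used here and is equivalent to dynamic_update_slice for this purpose):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, donate_argnums=(0,))
def write_layer(stacked, layer_params, layer_idx):
    # donating `stacked` lets XLA update the big (num_layers, ...) buffers in
    # place, so only the stacked arrays plus one layer are live at a time
    return jax.tree_util.tree_map(
        lambda s, p: s.at[layer_idx].set(p), stacked, layer_params)

# creation loop (make_one_layer and template are placeholder names):
# stacked = jax.tree_util.tree_map(
#     lambda p: jnp.zeros((num_layers, *p.shape), p.dtype), template)
# for i in range(num_layers):
#     stacked = write_layer(stacked, make_one_layer(i), i)
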
Pass KV cache keys/values as part of scan carry instead of xs,
enabling JAX buffer donation for effective in-place updates.
This reduces peak memory during decode from 10793 MiB to 6697 MiB
(38% reduction) by avoiding duplicate cache allocation.

Also unifies the body_fn and scan call for prefill/decode paths.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
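
A toy sketch of carrying the stacked KV cache through the scan so its buffers can be donated (shapes simplified to (L, B, S, D); the attention/MLP of a real decoder layer is stubbed out, and all names are illustrative):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, donate_argnums=(2, 3))  # donate the cache buffers
def decode_step(stacked_w, h, kv_keys, kv_values, pos):
    """stacked_w: {"wq": (L, D, D), "wk": (L, D, D), "wv": (L, D, D)}
    h: (B, D) hidden state of the current token
    kv_keys / kv_values: (L, B, S, D) stacked caches
    """
    def body(carry, w):
        h, keys, vals, i = carry
        k, v = h @ w["wk"], h @ w["wv"]          # new key/value for this layer
        keys = keys.at[i, :, pos].set(k)         # update this layer's cache slice
        vals = vals.at[i, :, pos].set(v)
        h = h + jnp.tanh(h @ w["wq"])            # stand-in for attention + MLP
        return (h, keys, vals, i + 1), None

    # the cache lives in the carry (not xs), so donation allows in-place updates
    (h, keys, vals, _), _ = jax.lax.scan(
        body, (h, kv_keys, kv_values, jnp.array(0)), stacked_w)
    return h, keys, vals
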
- Remove logging/debugging code (_log_mem, subprocess calls)
- Use .at[idx].set() instead of dynamic_update_slice (cleaner syntax)
- Keep donate_argnums=(0,) for buffer reuse (key to memory efficiency)
- Reduce code from ~80 lines to ~40 lines

Memory benchmark unchanged at 6697 MiB (vs 10797 MiB with vmap).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
raulchen (Author) commented Feb 5, 2026

/gemini review

gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-executed refactoring to improve memory efficiency during training by implementing per-layer gradient checkpointing using jax.lax.scan and stacking decoder layer weights at initialization. While the architectural changes are impressive and include a unified code path for training and inference with comprehensive test coverage, a security audit identified several critical vulnerabilities. These include arbitrary method execution in the distributed backend, potential SSRF/path traversal in checkpoint loading, and missing ownership checks in model management, which must be addressed for system robustness and security, especially in multi-tenant or exposed environments. Furthermore, a correctness issue was found related to the collection of hidden states, affecting a corresponding test.

else:
new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask)

all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
Contributor

high

When output_hidden_states is true, the list of hidden states is missing the output of the final decoder layer. The current implementation list(all_hs[:-1]) excludes the last layer's output from the scan operation.

To ensure all intermediate hidden states are collected, you should include all outputs from the scan. The list of hidden states returned by forward_layers should contain the initial embeddings plus the output of every decoder layer.

After this change, you will also need to update the assertion in test_hidden_states_length_matches to expect num_hidden_layers + 2 states (embeddings + N layer outputs + final normed output).

Suggested change
all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
all_hidden_states = [hidden_states] + list(all_hs) if output_hidden_states else []

hidden_states_ckpt = out.hidden_states
del out

assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1
Contributor

medium

This assertion incorrectly expects num_hidden_layers + 1 hidden states. With the recommended fix in tx/models/utils.py to correctly collect all hidden states, the model will return num_hidden_layers + 2 states when output_hidden_states=True (initial embeddings, N layer outputs, and the final normalized output). Please update this assertion to reflect the correct number of hidden states.

Suggested change
assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1
assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 2

tree_unflatten creates Variables with metadata from the original treedef,
which doesn't include the stacked sharding. NNX APIs (get_partition_spec,
Optimizer) read from 'sharding_names' metadata rather than array.sharding,
so we sync them after unflatten.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>