[tx] Per-layer gradient checkpointing with stacked decoder layers #996
raulchen wants to merge 128 commits into NovaSky-AI:main from
Conversation
Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor.

Key changes:
- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for the logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
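A minimal sketch of the chunking idea described in this commit, assuming a flattened [N, H] hidden-state matrix and an [H, V] lm_head weight; the function and argument names are illustrative, not the PR's actual API:

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head, target_ids, chunk_size=1024):
    """hidden: [N, H], lm_head: [H, V], target_ids: [N] -> per-token logprobs [N]."""
    n, h = hidden.shape
    pad = (-n) % chunk_size  # pad so the token dimension splits evenly into chunks
    hidden_p = jnp.pad(hidden, ((0, pad), (0, 0)))
    targets_p = jnp.pad(target_ids, (0, pad))

    def per_chunk(chunk):
        hs, ts = chunk                                    # hs: [C, H], ts: [C]
        logits = hs @ lm_head                             # only a [C, V] slice is live at a time
        logz = jax.nn.logsumexp(logits, axis=-1)          # [C]
        picked = jnp.take_along_axis(logits, ts[:, None], axis=-1)[:, 0]
        return picked - logz                              # log p(target token)

    xs = (hidden_p.reshape(-1, chunk_size, h), targets_p.reshape(-1, chunk_size))
    out = jax.lax.map(per_chunk, xs)                      # sequential over chunks
    return out.reshape(-1)[:n]
```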
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with the lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's activations are held at a time during the backward pass, reducing peak memory by a ~num_layers factor.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
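As a rough illustration of per-layer checkpointing (not the PR's actual code), treating each decoder layer as a pure function of its parameters:

```python
import jax

def forward_layers(layer_params_list, layer_apply, hidden_states, is_training=False):
    # layer_apply(params, x) -> x is an assumed pure per-layer forward pass.
    # Under jax.checkpoint, each layer's activations are recomputed in the
    # backward pass instead of being stored, so only one layer's activations
    # need to be live at a time.
    step = jax.checkpoint(layer_apply) if is_training else layer_apply
    for params in layer_params_list:
        hidden_states = step(params, hidden_states)
    return hidden_states
```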
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing, activations are stored anyway, so fori_loop's buffer reuse doesn't help and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
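A sketch of the fori_loop variant under the same assumptions as the sketch above (layer_apply is an assumed pure per-layer forward; stacked_params has a leading num_layers axis):

```python
import jax

def forward_layers_checkpointed(stacked_params, layer_apply, hidden_states):
    num_layers = jax.tree_util.tree_leaves(stacked_params)[0].shape[0]

    @jax.checkpoint
    def one_layer(params_i, h):
        return layer_apply(params_i, h)

    def body(i, h):
        # Slice out layer i's weights from the stacked arrays.
        params_i = jax.tree_util.tree_map(lambda p: p[i], stacked_params)
        return one_layer(params_i, h)

    # The trip count is a static Python int, so this stays reverse-mode differentiable.
    return jax.lax.fori_loop(0, num_layers, body, hidden_states)
```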
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to a single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now an internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in the batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
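A hedged sketch of the runtime switch described above; the names train_unembed_mask, chunked_fn, and full_fn are assumptions for illustration, not the backend's actual API:

```python
import jax
import jax.numpy as jnp

def select_logprob_path(hidden, target_ids, adapter_indices, train_unembed_mask,
                        chunked_fn, full_fn):
    # If any adapter used in this batch trains the unembedding, the chunked
    # matmul against the raw lm_head weight would skip its LoRA delta, so fall
    # back to the full-logits path. Both branches must return matching shapes.
    needs_lm_head_lora = jnp.any(train_unembed_mask[adapter_indices])
    return jax.lax.cond(
        needs_lm_head_lora,
        lambda: full_fn(hidden, target_ids),     # LoRA-aware, materializes full logits
        lambda: chunked_fn(hidden, target_ids),  # memory-efficient chunked path
    )
```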
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Apply softmax to all router logits before top-k selection, not after. This matches HF's implementation and fixes a ~1.3x output scaling error.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
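For illustration, the ordering this fix describes (softmax over all experts first, then top-k); the shapes and any renormalization afterwards are assumptions, not the repo's router code:

```python
import jax
import jax.numpy as jnp

def route_tokens(router_logits, top_k):
    # router_logits: [num_tokens, num_experts]
    probs = jax.nn.softmax(router_logits, axis=-1)       # softmax over ALL experts first
    topk_probs, topk_idx = jax.lax.top_k(probs, top_k)   # then select the top-k experts
    return topk_probs, topk_idx
```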
skyrl-tx/tx/layers/lora.py
Outdated
```python
return None


def _adapter_index(is_stacked: bool, adapter_index: int):
```
I think we should change this function to

```python
def get_adapter_idx(path: str, adapter_index: int):
    is_stacked = _is_stacked(path)
    return (slice(None), adapter_index) if is_stacked else (adapter_index,)
```

and then use it everywhere in the PR. I believe all the call sites of _adapter_index are already of this form, and there are lots more call sites where we can get rid of a pattern like

```python
if is_stacked:
    # Process stacked weights
else:
    # Process unstacked weights
```

This can always be done like

```python
idx = get_adapter_idx(path, adapter_index)
# Process weights with weights[idx, ...]
```
good idea. code could be simpler
skyrl-tx/tx/utils/models.py
Outdated
```python
raise ValueError("The 'learning_rate' key must be provided in optimizer_args.")


def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple:
```
Do we actually need this function? Can't we just do things like

```python
idx = get_adapter_idx(path, adapter_index)
p.at[idx, ..., :, :rank]
p.at[idx, ..., :rank, :]
```

- Refactor lora_test_utils.py to reduce duplication with _slice_out_of_rank helper
- Simplify DeepseekV3 decoder layers by passing mlp_cls instead of subclassing
- Add KVCache.split() and concatenate() methods for layer-wise cache operations

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
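A small sketch of what layer-wise splitting and concatenation of a stacked cache could look like. The stacked (num_layers, ...) layout comes from the PR description, but these helper names and signatures are assumptions rather than the actual KVCache API:

```python
import jax.numpy as jnp

def split_stacked(keys, values, group_sizes):
    """Split stacked caches (num_layers, ...) into per-group caches, e.g. dense vs MoE layers."""
    bounds, start = [], 0
    for size in group_sizes:
        bounds.append((start, start + size))
        start += size
    return [(keys[a:b], values[a:b]) for a, b in bounds]

def concatenate_stacked(groups):
    """Inverse of split_stacked: re-stack per-group caches along the layer axis."""
    keys = jnp.concatenate([k for k, _ in groups], axis=0)
    values = jnp.concatenate([v for _, v in groups], axis=0)
    return keys, values
```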
skyrl-tx/tx/utils/models.py
Outdated
| if "layers" not in path_strs: | ||
| return False | ||
| layers_idx = path_strs.index("layers") | ||
| if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit(): |
Do you know why this case is needed? If it isn't, we only need one method here and can always use the logic of is_stacked_lora_path.
My rationale for this question is, if we always stack the layers that are stackable, a case like ('model', 'layers', '0', 'self_attn', ...) should never happen, right?
this was added in the mid-ground state when both stacked and non-stacked were supported.
It should no longer be needed.
Introduces get_adapter_idx(path, adapter_index) that encapsulates the stacked vs non-stacked adapter indexing logic. Removes duplicate if/else patterns across the codebase. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This reverts commit 7d5bf5b.
Use is_stacked_lora_path directly since we always stack layers. The digit check for non-stacked format is no longer needed. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Make split() and concatenate() handle None for empty layer groups. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Force-pushed from aafa0c1 to 9635e4d
When a single layer is extracted from stacked layers via x[layer_idx], the tensor loses a dimension but the PartitionSpec metadata still has the extra leading None (from vmap transform_metadata). Truncate the spec from the beginning to match the actual tensor rank. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
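A sketch of the spec truncation, assuming a PartitionSpec whose leading entries correspond to the now-removed layer axis; the helper name is illustrative:

```python
from jax.sharding import PartitionSpec as P

def truncate_spec(spec: P, target_ndim: int) -> P:
    # Drop leading entries so the spec's length matches the sliced tensor's rank,
    # e.g. P(None, 'fsdp', None) for a (layers, rows, cols) param becomes
    # P('fsdp', None) once x[layer_idx] removes the layer axis.
    entries = tuple(spec)
    return P(*entries[len(entries) - target_ndim:])
```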
This is in preparation for merging #996, so we don't need to depend on the jax tracer. It is also slightly cleaner this way and the assert is not needed any more, since the error is "defined away". It also adds the FSDP sharding for llama3.
Replace nnx.vmap with individual layer creation + jnp.stack. vmap breaks eager sharding, causing ~4x memory overhead due to full model replication instead of tensor-parallel sharding.

The new approach:
- Creates layers individually with a Python loop (respects eager sharding)
- Stacks parameters using jit with donate_argnums to reduce peak memory
- Preserves correct sharding specs on stacked arrays

Memory improvement (per GPU, Qwen3-4B with tp=8):
- nnx.List baseline: 1461 MiB
- Old vmap approach: 4533 MiB
- New loop+stack: 2485 MiB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
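A rough sketch of the loop-then-stack step under the stated assumptions (per_layer_params is a list of identical per-layer parameter pytrees; creating each layer in a plain Python loop keeps eager sharding in effect, and donation lets XLA reuse the per-layer buffers):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def stack_layer_params(per_layer_params):
    # Stack corresponding leaves along a new leading num_layers axis,
    # donating the per-layer arrays so they can be freed/reused.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_layer_params)
```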
Instead of creating all layers then stacking (which requires holding both the original arrays and the stacked arrays simultaneously), pre-allocate the stacked arrays and copy each layer's params directly using dynamic_update_slice. This keeps only one layer in memory at a time.

Memory improvement during layer creation:
- Old: JAX peak ~2098 MiB (originals + stacked arrays)
- New: JAX peak ~1316 MiB (stacked arrays + 1 layer)

Also adds memory logging via nvidia-smi and JAX memory_stats for debugging memory usage throughout the layer creation process.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Pass KV cache keys/values as part of scan carry instead of xs, enabling JAX buffer donation for effective in-place updates. This reduces peak memory during decode from 10793 MiB to 6697 MiB (38% reduction) by avoiding duplicate cache allocation. Also unifies the body_fn and scan call for prefill/decode paths. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
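A sketch of the carry-based scan described above; layer_apply is an assumed pure per-layer forward, and kv_keys/kv_values are assumed to have a leading num_layers axis:

```python
import jax
import jax.numpy as jnp

def forward_layers(layer_apply, stacked_params, hidden_states, kv_keys, kv_values):
    num_layers = kv_keys.shape[0]

    def body(carry, xs):
        h, keys, values = carry
        params_i, i = xs
        h, k_i, v_i = layer_apply(params_i, h, keys[i], values[i])
        # Updating the cache inside the carry lets XLA donate/reuse its buffers
        # instead of materializing a second stacked cache as scan outputs.
        keys = keys.at[i].set(k_i)
        values = values.at[i].set(v_i)
        return (h, keys, values), None

    (hidden_states, kv_keys, kv_values), _ = jax.lax.scan(
        body, (hidden_states, kv_keys, kv_values),
        (stacked_params, jnp.arange(num_layers)))
    return hidden_states, kv_keys, kv_values
```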
- Remove logging/debugging code (_log_mem, subprocess calls)
- Use .at[idx].set() instead of dynamic_update_slice (cleaner syntax)
- Keep donate_argnums=(0,) for buffer reuse (key to memory efficiency)
- Reduce code from ~80 lines to ~40 lines

Memory benchmark unchanged at 6697 MiB (vs 10797 MiB with vmap).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
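A companion sketch of the per-layer copy into a pre-allocated stacked buffer, matching the approach described in the two commits above; the function name and argument layout are illustrative:

```python
import functools
import jax

@functools.partial(jax.jit, donate_argnums=(0,))
def write_layer(stacked, layer_params, layer_idx):
    # stacked leaves: (num_layers, ...); layer_params leaves: (...).
    # Donating `stacked` lets XLA update its buffers in place, so only the
    # stacked arrays plus one layer's params are live at a time.
    return jax.tree_util.tree_map(
        lambda s, p: s.at[layer_idx].set(p), stacked, layer_params)
```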
Force-pushed from c1118be to 993d6de
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring to improve memory efficiency during training by implementing per-layer gradient checkpointing using jax.lax.scan and stacking decoder layer weights at initialization. While the architectural changes are impressive and include a unified code path for training and inference with comprehensive test coverage, a security audit identified several critical vulnerabilities. These include arbitrary method execution in the distributed backend, potential SSRF/path traversal in checkpoint loading, and missing ownership checks in model management, which must be addressed for system robustness and security, especially in multi-tenant or exposed environments. Furthermore, a correctness issue was found related to the collection of hidden states, affecting a corresponding test.
```python
else:
    new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask)

all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
```
When output_hidden_states is true, the list of hidden states is missing the output of the final decoder layer. The current implementation list(all_hs[:-1]) excludes the last layer's output from the scan operation.
To ensure all intermediate hidden states are collected, you should include all outputs from the scan. The list of hidden states returned by forward_layers should contain the initial embeddings plus the output of every decoder layer.
After this change, you will also need to update the assertion in test_hidden_states_length_matches to expect num_hidden_layers + 2 states (embeddings + N layer outputs + final normed output).
```diff
-all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
+all_hidden_states = [hidden_states] + list(all_hs) if output_hidden_states else []
```
```python
hidden_states_ckpt = out.hidden_states
del out

assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1
```
This assertion incorrectly expects num_hidden_layers + 1 hidden states. With the recommended fix in tx/models/utils.py to correctly collect all hidden states, the model will return num_hidden_layers + 2 states when output_hidden_states=True (initial embeddings, N layer outputs, and the final normalized output). Please update this assertion to reflect the correct number of hidden states.
```diff
-assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1
+assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 2
```
tree_unflatten creates Variables with metadata from the original treedef, which doesn't include the stacked sharding. NNX APIs (get_partition_spec, Optimizer) read from 'sharding_names' metadata rather than array.sharding, so we sync them after unflatten. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary

Implement per-layer gradient checkpointing using jax.lax.scan with permanently stacked decoder layer weights via nnx.vmap. This reduces peak memory by a ~num_layers factor during training while maintaining a unified code path for training and inference.

Key Changes

1. Per-layer Gradient Checkpointing
- jax.lax.scan with jax.checkpoint to recompute activations during the backward pass
- Enabled via gradient_checkpointing=True in the model config

2. Stacked Layer Weights
- Layer weights stacked via nnx.vmap → shape (num_layers, ...)
- Unified forward_layers() function for both training and inference
- KV cache stacked with shape (num_layers, batch, seq, heads, dim)

3. DeepSeekV3 Split Stacking
- dense_layers for initial layers, moe_layers for MoE layers

Files Changed
- tx/models/utils.py - New: create_stacked_layers(), forward_layers()
- tx/models/{llama3,qwen3,deepseekv3}.py - Use stacked layers
- tx/layers/lora.py - Stacked LoRA indexing
- tx/utils/models.py - Stack/unstack for HF checkpoint compatibility
- tx/utils/generator.py - Stacked KV cache
- tx/tinker/backends/jax.py - Fix gradient accumulation for stacked params

Test plan