From 891362607859cd243a9d9e6f1d8592c93aa3ba0f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 12 Jan 2026 16:09:14 -0800 Subject: [PATCH 001/133] feat: add chunked lm_head for memory-efficient logprobs computation Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor. Key changes: - Add compute_logits flag to model.__call__ (skip lm_head when False) - Add lm_head weight to CausalLMOutput for external computation - Implement chunked logprobs with jax.lax.map (default chunk_size=1024) - Add loss_chunk_size config option Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/llama3.py | 24 +++++++++-- skyrl-tx/tx/models/qwen3.py | 24 +++++++++-- skyrl-tx/tx/models/types.py | 6 ++- skyrl-tx/tx/tinker/backends/jax.py | 65 +++++++++++++++++++++++++----- 4 files changed, 100 insertions(+), 19 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 2fb165290..b626f84ef 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -285,6 +285,14 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) + @property + def lm_head_weight(self) -> jax.Array: + """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + if self.config.tie_word_embeddings: + return self.model.embed_tokens.embedding.value.T + else: + return self.lm_head.kernel.value + def __call__( self, input_ids: jax.Array, @@ -294,6 +302,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -306,17 +315,24 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, ) - hidden_states = outputs.last_hidden_state - if self.config.tie_word_embeddings: - logits = hidden_states @ self.model.embed_tokens.embedding.value.T + + if is_training: + # Training: skip logits, return lm_head for chunked computation + logits = None else: - logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) + # Inference: compute logits normally + hidden_states = outputs.last_hidden_state + if self.config.tie_word_embeddings: + logits = hidden_states @ self.model.embed_tokens.embedding.value.T + else: + logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) return CausalLMOutput( logits=logits, last_hidden_state=outputs.last_hidden_state, kv_cache=outputs.kv_cache, hidden_states=outputs.hidden_states, + lm_head=self.lm_head_weight, ) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index cdc9c3a76..2a7c8581a 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -399,6 +399,14 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) + @property + def lm_head_weight(self) -> jax.Array: + """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + if self.config.tie_word_embeddings: + return self.model.embed_tokens.embedding.value.T + else: + return self.lm_head.kernel.value + def __call__( self, input_ids: jax.Array, @@ -408,6 +416,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -420,17 +429,24 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, ) - hidden_states = outputs.last_hidden_state - if self.config.tie_word_embeddings: - logits = hidden_states @ self.model.embed_tokens.embedding.value.T + + if is_training: + # Training: skip logits, return lm_head for chunked computation + logits = None else: - logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) + # Inference: compute logits normally + hidden_states = outputs.last_hidden_state + if self.config.tie_word_embeddings: + logits = hidden_states @ self.model.embed_tokens.embedding.value.T + else: + logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) return CausalLMOutput( logits=logits, last_hidden_state=outputs.last_hidden_state, kv_cache=outputs.kv_cache, hidden_states=outputs.hidden_states, + lm_head=self.lm_head_weight, ) diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index 0369a3750..2a5b167c8 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -36,13 +36,15 @@ class CausalLMOutput: """Output type for causal language models like Qwen3ForCausalLM. Attributes: - logits: The language modeling logits. + logits: The language modeling logits (None if is_training=True). last_hidden_state: The last hidden state from the model. kv_cache: The updated key-value cache. hidden_states: All hidden states, if output_hidden_states=True. + lm_head: The lm_head weight [H, V] for external logits computation. """ - logits: jax.Array + logits: jax.Array | None last_hidden_state: jax.Array kv_cache: KVCache hidden_states: list[jax.Array] | None = None + lm_head: jax.Array | None = None diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 720b760eb..447fe702e 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -83,6 +83,10 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=False, description="Whether to use gradient checkpointing (full recomputation strategy)", ) + loss_chunk_size: int = Field( + default=1024, + description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization.", + ) # Multi-node configuration coordinator_address: str | None = Field( default=None, @@ -236,16 +240,26 @@ def _model_forward( input_ids: jax.Array, attention_mask: jax.Array, adapter_indices: jax.Array, - ) -> jax.Array: + ) -> tuple[jax.Array, jax.Array]: + """Forward pass returning hidden states and lm_head weight for chunked cross-entropy.""" model = nnx.merge(graphdef, lora_params, non_lora_params) - output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - return output.logits + output = model( + input_ids, + attention_mask=attention_mask, + adapter_indices=adapter_indices, + is_training=True, + ) + return output.last_hidden_state, output.lm_head if self.config.gradient_checkpointing: # Wrap the model forward call to use jax.checkpoint for gradient checkpointing # policy=None corresponds to full activation recomputation _model_forward = jax.checkpoint(_model_forward, policy=None) + loss_chunk_size = self.config.loss_chunk_size + if loss_chunk_size <= 0: + raise ValueError(f"loss_chunk_size must be > 0, got {loss_chunk_size}") + def loss_for_lora( lora_params: nnx.State, non_lora_params: nnx.State, @@ -258,13 +272,46 @@ def loss_for_lora( sampling_logprobs: jax.Array, advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - logits = _model_forward( + # Fused chunked cross-entropy: compute lm_head inside the chunk loop + # This avoids materializing the full [B*T, V] logits tensor + hidden_states, lm_head_weight = _model_forward( self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices - ) # [B, T, V] - - log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) - target_logprobs = (target_logits - log_sum_exp).squeeze(-1) + ) # hidden_states: [B, T, H], lm_head_weight: [H, V] + + B, T, H = hidden_states.shape + + # Flatten batch and sequence dimensions + flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] + flat_target_ids = target_ids.reshape(-1) # [B*T] + total_tokens = B * T + + # Pad to multiple of chunk_size for clean slicing + num_chunks = (total_tokens + loss_chunk_size - 1) // loss_chunk_size + padded_size = num_chunks * loss_chunk_size + pad_amount = padded_size - total_tokens + + if pad_amount > 0: + flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) + flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + + # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] + chunked_hidden = flat_hidden.reshape(num_chunks, loss_chunk_size, H) + chunked_targets = flat_target_ids.reshape(num_chunks, loss_chunk_size) + + def compute_chunk_logprobs(args): + """Compute lm_head and log probabilities for a chunk of tokens.""" + chunk_hidden, chunk_targets = args + # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] + chunk_logits = chunk_hidden @ lm_head_weight + # Compute log probabilities + log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + # Process chunks sequentially using lax.map (not vmap) to reduce memory + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) + # Flatten and slice to original size, then reshape to [B, T] + target_logprobs = all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): return jax.lax.switch( From 9726415cf849b5e448cd7b5502c2dab775936e70 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 14 Jan 2026 15:38:48 -0800 Subject: [PATCH 002/133] fix: fallback to non-chunked loss when train_unembed=True or chunk_size<=0 The chunked cross-entropy path computes logits via direct matmul with lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head. Changes: - Rename is_training to skip_logits for clarity - Add _use_chunked_loss flag to backend - Automatically switch to non-chunked mode when: - train_unembed=True (requires LoRA on lm_head) - loss_chunk_size <= 0 (config-based disable) - Non-chunked path uses pre-computed logits with LoRA correctly applied --- skyrl-tx/tx/models/llama3.py | 8 +- skyrl-tx/tx/models/qwen3.py | 8 +- skyrl-tx/tx/models/types.py | 2 +- skyrl-tx/tx/tinker/backends/jax.py | 114 +++++++++++++++++------------ 4 files changed, 76 insertions(+), 56 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index b626f84ef..4f905dea8 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -302,7 +302,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, + skip_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -316,11 +316,11 @@ def __call__( kv_cache=kv_cache, ) - if is_training: - # Training: skip logits, return lm_head for chunked computation + if skip_logits: + # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) logits = None else: - # Inference: compute logits normally + # Compute logits with LoRA applied (required for train_unembed=True) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: logits = hidden_states @ self.model.embed_tokens.embedding.value.T diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 2a7c8581a..a387ffb82 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -416,7 +416,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, + skip_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -430,11 +430,11 @@ def __call__( kv_cache=kv_cache, ) - if is_training: - # Training: skip logits, return lm_head for chunked computation + if skip_logits: + # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) logits = None else: - # Inference: compute logits normally + # Compute logits with LoRA applied (required for train_unembed=True) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: logits = hidden_states @ self.model.embed_tokens.embedding.value.T diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index 2a5b167c8..ab9a32723 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -36,7 +36,7 @@ class CausalLMOutput: """Output type for causal language models like Qwen3ForCausalLM. Attributes: - logits: The language modeling logits (None if is_training=True). + logits: The language modeling logits (None if skip_logits=True). last_hidden_state: The last hidden state from the model. kv_cache: The updated key-value cache. hidden_states: All hidden states, if output_hidden_states=True. diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 447fe702e..0dbf8b692 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -85,7 +85,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): ) loss_chunk_size: int = Field( default=1024, - description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization.", + description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.", ) # Multi-node configuration coordinator_address: str | None = Field( @@ -204,6 +204,11 @@ def __init__(self, base_model: str, config: JaxBackendConfig): f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" ) + # Use chunked cross-entropy by default for memory efficiency. + # Falls back to non-chunked when: + # - loss_chunk_size <= 0 (disabled via config) + # - any model uses train_unembed=True (chunked path doesn't apply LoRA to lm_head) + self._use_chunked_loss = config.loss_chunk_size > 0 self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -232,6 +237,8 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" + use_chunked = self._use_chunked_loss + loss_chunk_size = self.config.loss_chunk_size def _model_forward( graphdef: nnx.GraphDef, @@ -241,25 +248,24 @@ def _model_forward( attention_mask: jax.Array, adapter_indices: jax.Array, ) -> tuple[jax.Array, jax.Array]: - """Forward pass returning hidden states and lm_head weight for chunked cross-entropy.""" + """Forward pass returning (hidden_states, lm_head) or (logits, None).""" model = nnx.merge(graphdef, lora_params, non_lora_params) output = model( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, - is_training=True, + skip_logits=use_chunked, ) - return output.last_hidden_state, output.lm_head + if use_chunked: + return output.last_hidden_state, output.lm_head + else: + return output.logits, None if self.config.gradient_checkpointing: # Wrap the model forward call to use jax.checkpoint for gradient checkpointing # policy=None corresponds to full activation recomputation _model_forward = jax.checkpoint(_model_forward, policy=None) - loss_chunk_size = self.config.loss_chunk_size - if loss_chunk_size <= 0: - raise ValueError(f"loss_chunk_size must be > 0, got {loss_chunk_size}") - def loss_for_lora( lora_params: nnx.State, non_lora_params: nnx.State, @@ -272,46 +278,54 @@ def loss_for_lora( sampling_logprobs: jax.Array, advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - # Fused chunked cross-entropy: compute lm_head inside the chunk loop - # This avoids materializing the full [B*T, V] logits tensor - hidden_states, lm_head_weight = _model_forward( + forward_out, lm_head_weight = _model_forward( self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices - ) # hidden_states: [B, T, H], lm_head_weight: [H, V] - - B, T, H = hidden_states.shape - - # Flatten batch and sequence dimensions - flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] - flat_target_ids = target_ids.reshape(-1) # [B*T] - total_tokens = B * T - - # Pad to multiple of chunk_size for clean slicing - num_chunks = (total_tokens + loss_chunk_size - 1) // loss_chunk_size - padded_size = num_chunks * loss_chunk_size - pad_amount = padded_size - total_tokens - - if pad_amount > 0: - flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) - flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) - - # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] - chunked_hidden = flat_hidden.reshape(num_chunks, loss_chunk_size, H) - chunked_targets = flat_target_ids.reshape(num_chunks, loss_chunk_size) - - def compute_chunk_logprobs(args): - """Compute lm_head and log probabilities for a chunk of tokens.""" - chunk_hidden, chunk_targets = args - # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] - chunk_logits = chunk_hidden @ lm_head_weight - # Compute log probabilities - log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) - return (target_logits - log_sum_exp).squeeze(-1) - - # Process chunks sequentially using lax.map (not vmap) to reduce memory - all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) - # Flatten and slice to original size, then reshape to [B, T] - target_logprobs = all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) + ) + + if use_chunked: + # Chunked cross-entropy: compute lm_head inside the chunk loop + # This avoids materializing the full [B*T, V] logits tensor + hidden_states = forward_out # [B, T, H] + B, T, H = hidden_states.shape + + # Flatten batch and sequence dimensions + flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] + flat_target_ids = target_ids.reshape(-1) # [B*T] + total_tokens = B * T + + # Pad to multiple of chunk_size for clean slicing + num_chunks = (total_tokens + loss_chunk_size - 1) // loss_chunk_size + padded_size = num_chunks * loss_chunk_size + pad_amount = padded_size - total_tokens + + if pad_amount > 0: + flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) + flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + + # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] + chunked_hidden = flat_hidden.reshape(num_chunks, loss_chunk_size, H) + chunked_targets = flat_target_ids.reshape(num_chunks, loss_chunk_size) + + def compute_chunk_logprobs(args): + """Compute lm_head and log probabilities for a chunk of tokens.""" + chunk_hidden, chunk_targets = args + # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] + chunk_logits = chunk_hidden @ lm_head_weight + # Compute log probabilities + log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + # Process chunks sequentially using lax.map (not vmap) to reduce memory + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) + # Flatten and slice to original size, then reshape to [B, T] + target_logprobs = all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) + else: + # Non-chunked: use pre-computed logits (with LoRA applied to lm_head) + logits = forward_out # [B, T, V] + log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) + target_logprobs = (target_logits - log_sum_exp).squeeze(-1) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): return jax.lax.switch( @@ -482,6 +496,12 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") + # Switch to non-chunked loss if train_unembed=True (chunked doesn't apply LoRA to lm_head) + if lora_config.train_unembed and self._use_chunked_loss: + logger.info("Switching to non-chunked loss mode (train_unembed=True requires LoRA on lm_head)") + self._use_chunked_loss = False + self._create_loss_and_grad_fn() + # Store model metadata self.models[model_id] = types.ModelMetadata( adapter_index=adapter_index, From 3fa6d2d1fb848b6cbad69411e0597a490cc08f09 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 10:49:16 -0800 Subject: [PATCH 003/133] add tests --- skyrl-tx/tests/tinker/test_jax_backend.py | 102 ++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 2edd9d82b..e76591484 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -556,3 +556,105 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_B is zeros assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" + + +class TestChunkedCrossEntropyLoss: + """Tests for chunked cross-entropy loss computation.""" + + def _create_backend(self, loss_chunk_size: int) -> JaxBackend: + """Create a backend with specified chunk size.""" + config = JaxBackendConfig( + max_lora_adapters=2, + max_lora_rank=32, + loss_chunk_size=loss_chunk_size, + ) + return JaxBackend(BASE_MODEL, config) + + def _create_inputs(self, backend: JaxBackend, batch_size: int, seq_len: int, adapter_idx: int = 0): + """Create test inputs for forward pass.""" + vocab = backend.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + adapter_indices = jnp.full((batch_size,), adapter_idx, dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + return (input_ids, attention_mask, adapter_indices, target_ids, + loss_mask, loss_fn_types, sampling_logprobs, advantages) + + def _run_forward(self, backend: JaxBackend, inputs: tuple): + """Run forward pass and return losses and logprobs.""" + (input_ids, attention_mask, adapter_indices, target_ids, + loss_mask, loss_fn_types, sampling_logprobs, advantages) = inputs + _, losses, logprobs = backend._forward( + backend.accumulated_grads, + backend.lora_params, + backend.non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + return losses, logprobs + + def test_fallback_on_train_unembed(self): + """Verify backend switches to non-chunked when train_unembed=True.""" + backend = self._create_backend(loss_chunk_size=1024) + assert backend._use_chunked_loss is True + + lora_config = LoraConfig(rank=8, alpha=16, seed=0, train_unembed=True) + backend.create_model("model_with_unembed", lora_config) + + assert backend._use_chunked_loss is False + + @pytest.mark.parametrize("chunk_size,expected", [ + (0, False), # Disabled + (-1, False), # Disabled + (1024, True), # Enabled + ]) + def test_use_chunked_loss_config(self, chunk_size, expected): + """Verify _use_chunked_loss is set correctly based on loss_chunk_size.""" + backend = self._create_backend(loss_chunk_size=chunk_size) + assert backend._use_chunked_loss is expected + + @pytest.mark.parametrize("batch_size,seq_len,chunk_size", [ + (2, 16, 8), # Multiple batches + (1, 16, 16), # Exact multiple (1 chunk) + (1, 17, 16), # One extra token (worst case padding) + (1, 8, 16), # Fewer tokens than chunk size + (1, 32, 16), # Exact 2 chunks + (1, 1, 16), # Single token + (1, 31, 16), # Almost 2 chunks + ]) + def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): + """Verify chunked and non-chunked loss produce identical logprobs.""" + backend_chunked = self._create_backend(loss_chunk_size=chunk_size) + backend_nonchunked = self._create_backend(loss_chunk_size=0) + + assert backend_chunked._use_chunked_loss is True + assert backend_nonchunked._use_chunked_loss is False + + inputs = self._create_inputs(backend_chunked, batch_size, seq_len) + losses_chunked, logprobs_chunked = self._run_forward(backend_chunked, inputs) + losses_nonchunked, logprobs_nonchunked = self._run_forward(backend_nonchunked, inputs) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Logprobs mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", + ) + np.testing.assert_allclose( + np.asarray(losses_chunked), + np.asarray(losses_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Losses mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", + ) From 801f1e929317e1b79704df22a75ee572b383aee1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 10:51:29 -0800 Subject: [PATCH 004/133] checkpoint --- skyrl-tx/tx/tinker/backends/jax.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 0dbf8b692..fdbefc23c 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -239,6 +239,7 @@ def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" use_chunked = self._use_chunked_loss loss_chunk_size = self.config.loss_chunk_size + gradient_checkpointing = self.config.gradient_checkpointing def _model_forward( graphdef: nnx.GraphDef, @@ -316,6 +317,9 @@ def compute_chunk_logprobs(args): target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) + if gradient_checkpointing: + compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) + # Process chunks sequentially using lax.map (not vmap) to reduce memory all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) # Flatten and slice to original size, then reshape to [B, T] From 07469ffdecd6c8c3784e338c0720632e243064b4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 10:54:50 -0800 Subject: [PATCH 005/133] deprecation warning --- skyrl-tx/tx/models/llama3.py | 6 +++--- skyrl-tx/tx/models/qwen3.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 4f905dea8..dffc310d1 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -289,9 +289,9 @@ def is_lora_param(path: tuple, _value) -> bool: def lm_head_weight(self) -> jax.Array: """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding.value.T + return self.model.embed_tokens.embedding[...].T else: - return self.lm_head.kernel.value + return self.lm_head.kernel[...] def __call__( self, @@ -323,7 +323,7 @@ def __call__( # Compute logits with LoRA applied (required for train_unembed=True) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: - logits = hidden_states @ self.model.embed_tokens.embedding.value.T + logits = hidden_states @ self.model.embed_tokens.embedding[...].T else: logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index a387ffb82..9fac0db64 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -403,9 +403,9 @@ def is_lora_param(path: tuple, _value) -> bool: def lm_head_weight(self) -> jax.Array: """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding.value.T + return self.model.embed_tokens.embedding[...].T else: - return self.lm_head.kernel.value + return self.lm_head.kernel[...] def __call__( self, @@ -437,7 +437,7 @@ def __call__( # Compute logits with LoRA applied (required for train_unembed=True) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: - logits = hidden_states @ self.model.embed_tokens.embedding.value.T + logits = hidden_states @ self.model.embed_tokens.embedding[...].T else: logits = self.lm_head(hidden_states, adapter_indices=adapter_indices) From 30f083ac7490593462b016c7d85a7c4785461a66 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 16:38:40 -0800 Subject: [PATCH 006/133] lint --- skyrl-tx/tests/tinker/test_jax_backend.py | 58 ++++++++++++++++------- skyrl-tx/tx/tinker/backends/jax.py | 3 +- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index e76591484..2b8d20e9e 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -581,13 +581,29 @@ def _create_inputs(self, backend: JaxBackend, batch_size: int, seq_len: int, ada loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - return (input_ids, attention_mask, adapter_indices, target_ids, - loss_mask, loss_fn_types, sampling_logprobs, advantages) + return ( + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) def _run_forward(self, backend: JaxBackend, inputs: tuple): """Run forward pass and return losses and logprobs.""" - (input_ids, attention_mask, adapter_indices, target_ids, - loss_mask, loss_fn_types, sampling_logprobs, advantages) = inputs + ( + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) = inputs _, losses, logprobs = backend._forward( backend.accumulated_grads, backend.lora_params, @@ -613,25 +629,31 @@ def test_fallback_on_train_unembed(self): assert backend._use_chunked_loss is False - @pytest.mark.parametrize("chunk_size,expected", [ - (0, False), # Disabled - (-1, False), # Disabled - (1024, True), # Enabled - ]) + @pytest.mark.parametrize( + "chunk_size,expected", + [ + (0, False), # Disabled + (-1, False), # Disabled + (1024, True), # Enabled + ], + ) def test_use_chunked_loss_config(self, chunk_size, expected): """Verify _use_chunked_loss is set correctly based on loss_chunk_size.""" backend = self._create_backend(loss_chunk_size=chunk_size) assert backend._use_chunked_loss is expected - @pytest.mark.parametrize("batch_size,seq_len,chunk_size", [ - (2, 16, 8), # Multiple batches - (1, 16, 16), # Exact multiple (1 chunk) - (1, 17, 16), # One extra token (worst case padding) - (1, 8, 16), # Fewer tokens than chunk size - (1, 32, 16), # Exact 2 chunks - (1, 1, 16), # Single token - (1, 31, 16), # Almost 2 chunks - ]) + @pytest.mark.parametrize( + "batch_size,seq_len,chunk_size", + [ + (2, 16, 8), # Multiple batches + (1, 16, 16), # Exact multiple (1 chunk) + (1, 17, 16), # One extra token (worst case padding) + (1, 8, 16), # Fewer tokens than chunk size + (1, 32, 16), # Exact 2 chunks + (1, 1, 16), # Single token + (1, 31, 16), # Almost 2 chunks + ], + ) def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): """Verify chunked and non-chunked loss produce identical logprobs.""" backend_chunked = self._create_backend(loss_chunk_size=chunk_size) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index fdbefc23c..1334492c6 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -209,6 +209,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): # - loss_chunk_size <= 0 (disabled via config) # - any model uses train_unembed=True (chunked path doesn't apply LoRA to lm_head) self._use_chunked_loss = config.loss_chunk_size > 0 + logger.info(f"Chunked cross-entropy loss: {self._use_chunked_loss} (chunk_size={config.loss_chunk_size})") self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -291,7 +292,7 @@ def loss_for_lora( # Flatten batch and sequence dimensions flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] - flat_target_ids = target_ids.reshape(-1) # [B*T] + flat_target_ids = target_ids.reshape(-1) # [B*T] total_tokens = B * T # Pad to multiple of chunk_size for clean slicing From f318cbb9a774dcff0fa9a3ab090dce16203c6619 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 12 Jan 2026 17:17:23 -0800 Subject: [PATCH 007/133] feat: add per-layer gradient checkpointing Recompute activations during backward to save memory. Only one layer's activations are held at a time during backward pass, reducing peak memory by ~num_layers factor. - Add gradient_checkpointing config to ModelConfig - Apply jax.checkpoint per-layer when is_training=True - Rename compute_logits to is_training (controls both logits and checkpointing) Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/configs.py | 4 ++++ skyrl-tx/tx/models/llama3.py | 9 ++++++++- skyrl-tx/tx/models/qwen3.py | 9 ++++++++- skyrl-tx/tx/tinker/backends/jax.py | 8 ++------ 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index adc2b57ab..c21ee80b9 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -14,12 +14,14 @@ class ModelConfig(PretrainedConfig): max_lora_adapters: Maximum number of concurrent LoRA adapters max_lora_rank: Maximum rank for LoRA adapters shard_attention_heads: Whether to shard attention across tensor parallel devices + gradient_checkpointing: Recompute activations during backward to save memory """ # Type hints for LoRA attributes max_lora_adapters: int max_lora_rank: int shard_attention_heads: bool + gradient_checkpointing: bool def __init__( self, @@ -28,6 +30,7 @@ def __init__( max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, + gradient_checkpointing: bool = False, ): # Copy all attributes from the base config super().__init__(**config.to_dict()) @@ -36,6 +39,7 @@ def __init__( self.max_lora_adapters = max_lora_adapters self.max_lora_rank = max_lora_rank self.shard_attention_heads = shard_attention_heads + self.gradient_checkpointing = gradient_checkpointing # Model-specific aliases for clarity and backwards compatibility diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 2fb165290..5a62fe022 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -224,6 +224,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -237,12 +238,16 @@ def __call__( if output_hidden_states: all_hidden_states.append(hidden_states) + layer_kv_cache = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) + if self.config.gradient_checkpointing and is_training: + layer = jax.checkpoint(layer) + hidden_states, (k, v) = layer( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position), + kv_cache=layer_kv_cache, ) updated_keys.append(k) updated_values.append(v) @@ -294,6 +299,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -305,6 +311,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index cdc9c3a76..2f63e7294 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -339,6 +339,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -352,12 +353,16 @@ def __call__( if output_hidden_states: all_hidden_states.append(hidden_states) + layer_kv_cache = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) + if self.config.gradient_checkpointing and is_training: + layer = jax.checkpoint(layer) + hidden_states, (k, v) = layer( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position), + kv_cache=layer_kv_cache, ) updated_keys.append(k) updated_values.append(v) @@ -408,6 +413,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -419,6 +425,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) hidden_states = outputs.last_hidden_state if self.config.tie_word_embeddings: diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 720b760eb..afb1268df 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -81,7 +81,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): ) gradient_checkpointing: bool = Field( default=False, - description="Whether to use gradient checkpointing (full recomputation strategy)", + description="Per-layer activation checkpointing: recompute activations during backward to save memory", ) # Multi-node configuration coordinator_address: str | None = Field( @@ -163,6 +163,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, shard_attention_heads=config.shard_attention_heads, + gradient_checkpointing=config.gradient_checkpointing, ) model_class = get_model_class(self.model_config) @@ -241,11 +242,6 @@ def _model_forward( output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) return output.logits - if self.config.gradient_checkpointing: - # Wrap the model forward call to use jax.checkpoint for gradient checkpointing - # policy=None corresponds to full activation recomputation - _model_forward = jax.checkpoint(_model_forward, policy=None) - def loss_for_lora( lora_params: nnx.State, non_lora_params: nnx.State, From a763fce240413c7e37f5171565cefaf0085cb747 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 13 Jan 2026 13:28:22 -0800 Subject: [PATCH 008/133] feat: use fori_loop for gradient checkpointing to enable XLA buffer reuse Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them. Only enabled when gradient_checkpointing=True. Without checkpointing, activations are stored anyway, so fori_loop's buffer reuse doesn't help and its weight stacking overhead makes it worse. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/llama3.py | 106 +++++++++++++++++++++++++++++------ skyrl-tx/tx/models/qwen3.py | 106 +++++++++++++++++++++++++++++------ 2 files changed, 180 insertions(+), 32 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 5a62fe022..18e3d850e 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -232,39 +232,113 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - layer_kv_cache = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) - if self.config.gradient_checkpointing and is_training: - layer = jax.checkpoint(layer) - hidden_states, (k, v) = layer( + # Checkpointing: use fori_loop so XLA compiles ONE loop body and reuses + # buffers during recomputation. Without checkpointing, activations are + # stored anyway, so fori_loop's buffer reuse doesn't help and its weight + # stacking overhead makes it worse. + if is_training and self.config.gradient_checkpointing: + hidden_states = self._forward_layers_checkpointed( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=layer_kv_cache, ) - updated_keys.append(k) - updated_values.append(v) + updated_keys, updated_values = [], [] + new_cache_position = input_ids.shape[1] + else: + hidden_states, updated_keys, updated_values = self._forward_layers( + hidden_states, + seq_lengths=seq_lengths, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + all_hidden_states=all_hidden_states, + ) + new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states.append(hidden_states) - # Increment cache_position if cache exists, or use sequence length for new cache - new_cache_position = kv_cache.cache_position + 1 if kv_cache is not None else input_ids.shape[1] - return ModelOutput( last_hidden_state=hidden_states, kv_cache=KVCache(keys=updated_keys, values=updated_values, cache_position=new_cache_position), hidden_states=all_hidden_states if output_hidden_states else None, ) + def _forward_layers_checkpointed( + self, + hidden_states: jax.Array, + *, + seq_lengths: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + ) -> jax.Array: + """Forward pass with gradient checkpointing using fori_loop. + + Uses fori_loop so XLA compiles ONE loop body and reuses buffers during + backward recomputation. With a Python loop, XLA unrolls N separate + checkpoint regions and can't optimize buffer reuse across them. + + Tradeoff: requires stacking all layer weights once per forward pass. + This is acceptable because checkpointing already trades compute for memory. + + TODO(haochen): Load weights directly into stacked format to avoid 2x memory. + Currently we have both self.layers (original) and stacked copy during forward. + """ + num_layers = len(self.layers) + + # Stack layer weights for dynamic indexing in fori_loop + layer_graphdef, _ = nnx.split(self.layers[0]) + stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) + + def body_fn(i, hs): + layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) + layer = nnx.merge(layer_graphdef, layer_weights) + hs, _ = layer( + hs, seq_lengths=seq_lengths, positions=positions, adapter_indices=adapter_indices, kv_cache=None + ) + return hs + + body_fn = jax.checkpoint(body_fn) + return jax.lax.fori_loop(0, num_layers, body_fn, hidden_states) + + def _forward_layers( + self, + hidden_states: jax.Array, + *, + seq_lengths: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + all_hidden_states: list[jax.Array], + ) -> tuple[jax.Array, list[jax.Array], list[jax.Array]]: + """Standard forward pass through decoder layers. + + Used for inference (with KV cache) and training without checkpointing. + """ + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) + hidden_states, (k, v) = layer( + hidden_states, + seq_lengths=seq_lengths, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + updated_keys.append(k) + updated_values.append(v) + + return hidden_states, updated_keys, updated_values + class Llama3ForCausalLM(nnx.Module, GeneratorMixin): diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 2f63e7294..d592687f1 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -347,39 +347,113 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - layer_kv_cache = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) - if self.config.gradient_checkpointing and is_training: - layer = jax.checkpoint(layer) - hidden_states, (k, v) = layer( + # Checkpointing: use fori_loop so XLA compiles ONE loop body and reuses + # buffers during recomputation. Without checkpointing, activations are + # stored anyway, so fori_loop's buffer reuse doesn't help and its weight + # stacking overhead makes it worse. + if is_training and self.config.gradient_checkpointing: + hidden_states = self._forward_layers_checkpointed( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=layer_kv_cache, ) - updated_keys.append(k) - updated_values.append(v) + updated_keys, updated_values = [], [] + new_cache_position = input_ids.shape[1] + else: + hidden_states, updated_keys, updated_values = self._forward_layers( + hidden_states, + seq_lengths=seq_lengths, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + all_hidden_states=all_hidden_states, + ) + new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states.append(hidden_states) - # Increment cache_position if cache exists, or use sequence length for new cache - new_cache_position = kv_cache.cache_position + 1 if kv_cache is not None else input_ids.shape[1] - return ModelOutput( last_hidden_state=hidden_states, kv_cache=KVCache(keys=updated_keys, values=updated_values, cache_position=new_cache_position), hidden_states=all_hidden_states if output_hidden_states else None, ) + def _forward_layers_checkpointed( + self, + hidden_states: jax.Array, + *, + seq_lengths: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + ) -> jax.Array: + """Forward pass with gradient checkpointing using fori_loop. + + Uses fori_loop so XLA compiles ONE loop body and reuses buffers during + backward recomputation. With a Python loop, XLA unrolls N separate + checkpoint regions and can't optimize buffer reuse across them. + + Tradeoff: requires stacking all layer weights once per forward pass. + This is acceptable because checkpointing already trades compute for memory. + + TODO(haochen): Load weights directly into stacked format to avoid 2x memory. + Currently we have both self.layers (original) and stacked copy during forward. + """ + num_layers = len(self.layers) + + # Stack layer weights for dynamic indexing in fori_loop + layer_graphdef, _ = nnx.split(self.layers[0]) + stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) + + def body_fn(i, hs): + layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) + layer = nnx.merge(layer_graphdef, layer_weights) + hs, _ = layer( + hs, seq_lengths=seq_lengths, positions=positions, adapter_indices=adapter_indices, kv_cache=None + ) + return hs + + body_fn = jax.checkpoint(body_fn) + return jax.lax.fori_loop(0, num_layers, body_fn, hidden_states) + + def _forward_layers( + self, + hidden_states: jax.Array, + *, + seq_lengths: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + all_hidden_states: list[jax.Array], + ) -> tuple[jax.Array, list[jax.Array], list[jax.Array]]: + """Standard forward pass through decoder layers. + + Used for inference (with KV cache) and training without checkpointing. + """ + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) + hidden_states, (k, v) = layer( + hidden_states, + seq_lengths=seq_lengths, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + updated_keys.append(k) + updated_values.append(v) + + return hidden_states, updated_keys, updated_values + class Qwen3ForCausalLM(nnx.Module, GeneratorMixin): From 3676aae84e82e998d99ef35db92c0db1d4603e5b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 14:50:44 -0800 Subject: [PATCH 009/133] fix: use attention_mask instead of seq_lengths in model forward --- skyrl-tx/tx/models/llama3.py | 10 +++++----- skyrl-tx/tx/models/qwen3.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 18e3d850e..e08adcfda 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -249,7 +249,7 @@ def __call__( else: hidden_states, updated_keys, updated_values = self._forward_layers( hidden_states, - seq_lengths=seq_lengths, + attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, @@ -272,7 +272,7 @@ def _forward_layers_checkpointed( self, hidden_states: jax.Array, *, - seq_lengths: jax.Array, + attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, ) -> jax.Array: @@ -298,7 +298,7 @@ def body_fn(i, hs): layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) layer = nnx.merge(layer_graphdef, layer_weights) hs, _ = layer( - hs, seq_lengths=seq_lengths, positions=positions, adapter_indices=adapter_indices, kv_cache=None + hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) return hs @@ -309,7 +309,7 @@ def _forward_layers( self, hidden_states: jax.Array, *, - seq_lengths: jax.Array, + attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, kv_cache: KVCache | None, @@ -329,7 +329,7 @@ def _forward_layers( layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) hidden_states, (k, v) = layer( hidden_states, - seq_lengths=seq_lengths, + attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=layer_kv, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index d592687f1..9fac658f7 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -364,7 +364,7 @@ def __call__( else: hidden_states, updated_keys, updated_values = self._forward_layers( hidden_states, - seq_lengths=seq_lengths, + attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, @@ -387,7 +387,7 @@ def _forward_layers_checkpointed( self, hidden_states: jax.Array, *, - seq_lengths: jax.Array, + attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, ) -> jax.Array: @@ -413,7 +413,7 @@ def body_fn(i, hs): layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) layer = nnx.merge(layer_graphdef, layer_weights) hs, _ = layer( - hs, seq_lengths=seq_lengths, positions=positions, adapter_indices=adapter_indices, kv_cache=None + hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) return hs @@ -424,7 +424,7 @@ def _forward_layers( self, hidden_states: jax.Array, *, - seq_lengths: jax.Array, + attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, kv_cache: KVCache | None, @@ -444,7 +444,7 @@ def _forward_layers( layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) hidden_states, (k, v) = layer( hidden_states, - seq_lengths=seq_lengths, + attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=layer_kv, From cb083ae50e1970a7038749e76ee10413f7cf7678 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 14:50:53 -0800 Subject: [PATCH 010/133] fix: pass is_training=True to enable gradient checkpointing --- skyrl-tx/tx/tinker/backends/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index afb1268df..f7ce76c7b 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -239,7 +239,7 @@ def _model_forward( adapter_indices: jax.Array, ) -> jax.Array: model = nnx.merge(graphdef, lora_params, non_lora_params) - output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, is_training=True) return output.logits def loss_for_lora( From c368f237d022990e9d688da1e3bb74d689cd9b4b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 14:51:28 -0800 Subject: [PATCH 011/133] feat: use scan instead of fori_loop to support output_hidden_states --- skyrl-tx/tx/models/llama3.py | 30 ++++++++++++++++++++---------- skyrl-tx/tx/models/qwen3.py | 30 ++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index e08adcfda..bdf503a83 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -233,16 +233,17 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) all_hidden_states: list[jax.Array] = [] - # Checkpointing: use fori_loop so XLA compiles ONE loop body and reuses + # Checkpointing: use scan so XLA compiles ONE loop body and reuses # buffers during recomputation. Without checkpointing, activations are - # stored anyway, so fori_loop's buffer reuse doesn't help and its weight + # stored anyway, so scan's buffer reuse doesn't help and its weight # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: - hidden_states = self._forward_layers_checkpointed( + hidden_states, all_hidden_states = self._forward_layers_checkpointed( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, + output_hidden_states=output_hidden_states, ) updated_keys, updated_values = [], [] new_cache_position = input_ids.shape[1] @@ -275,10 +276,11 @@ def _forward_layers_checkpointed( attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, - ) -> jax.Array: - """Forward pass with gradient checkpointing using fori_loop. + output_hidden_states: bool, + ) -> tuple[jax.Array, list[jax.Array]]: + """Forward pass with gradient checkpointing using scan. - Uses fori_loop so XLA compiles ONE loop body and reuses buffers during + Uses scan so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them. @@ -290,20 +292,28 @@ def _forward_layers_checkpointed( """ num_layers = len(self.layers) - # Stack layer weights for dynamic indexing in fori_loop + # Stack layer weights for dynamic indexing in scan layer_graphdef, _ = nnx.split(self.layers[0]) stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) - def body_fn(i, hs): + def body_fn(hs, i): layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) layer = nnx.merge(layer_graphdef, layer_weights) hs, _ = layer( hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) - return hs + return hs, hs # carry, output (collected if output_hidden_states) body_fn = jax.checkpoint(body_fn) - return jax.lax.fori_loop(0, num_layers, body_fn, hidden_states) + final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) + + if output_hidden_states: + # all_hs is [num_layers, batch, seq, hidden], convert to list and prepend input + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers)] + else: + all_hidden_states = [] + + return final_hs, all_hidden_states def _forward_layers( self, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 9fac658f7..698fffb0c 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -348,16 +348,17 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) all_hidden_states: list[jax.Array] = [] - # Checkpointing: use fori_loop so XLA compiles ONE loop body and reuses + # Checkpointing: use scan so XLA compiles ONE loop body and reuses # buffers during recomputation. Without checkpointing, activations are - # stored anyway, so fori_loop's buffer reuse doesn't help and its weight + # stored anyway, so scan's buffer reuse doesn't help and its weight # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: - hidden_states = self._forward_layers_checkpointed( + hidden_states, all_hidden_states = self._forward_layers_checkpointed( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, + output_hidden_states=output_hidden_states, ) updated_keys, updated_values = [], [] new_cache_position = input_ids.shape[1] @@ -390,10 +391,11 @@ def _forward_layers_checkpointed( attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, - ) -> jax.Array: - """Forward pass with gradient checkpointing using fori_loop. + output_hidden_states: bool, + ) -> tuple[jax.Array, list[jax.Array]]: + """Forward pass with gradient checkpointing using scan. - Uses fori_loop so XLA compiles ONE loop body and reuses buffers during + Uses scan so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them. @@ -405,20 +407,28 @@ def _forward_layers_checkpointed( """ num_layers = len(self.layers) - # Stack layer weights for dynamic indexing in fori_loop + # Stack layer weights for dynamic indexing in scan layer_graphdef, _ = nnx.split(self.layers[0]) stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) - def body_fn(i, hs): + def body_fn(hs, i): layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) layer = nnx.merge(layer_graphdef, layer_weights) hs, _ = layer( hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) - return hs + return hs, hs # carry, output (collected if output_hidden_states) body_fn = jax.checkpoint(body_fn) - return jax.lax.fori_loop(0, num_layers, body_fn, hidden_states) + final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) + + if output_hidden_states: + # all_hs is [num_layers, batch, seq, hidden], convert to list and prepend input + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers)] + else: + all_hidden_states = [] + + return final_hs, all_hidden_states def _forward_layers( self, From 9ef7e1762eb2eab954705d60bf054f228dcad854 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 14:53:46 -0800 Subject: [PATCH 012/133] perf: return None from scan when output_hidden_states=False to save memory --- skyrl-tx/tx/models/llama3.py | 2 +- skyrl-tx/tx/models/qwen3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index bdf503a83..0fa3380fc 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -302,7 +302,7 @@ def body_fn(hs, i): hs, _ = layer( hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) - return hs, hs # carry, output (collected if output_hidden_states) + return hs, hs if output_hidden_states else None body_fn = jax.checkpoint(body_fn) final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 698fffb0c..98dcecdd4 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -417,7 +417,7 @@ def body_fn(hs, i): hs, _ = layer( hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None ) - return hs, hs # carry, output (collected if output_hidden_states) + return hs, hs if output_hidden_states else None body_fn = jax.checkpoint(body_fn) final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) From 03f64fb1d634ab74ebe7454b50f672cb9776257e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 14:59:28 -0800 Subject: [PATCH 013/133] fix: exclude last layer output from all_hidden_states to match non-checkpointed path --- skyrl-tx/tx/models/llama3.py | 5 +++-- skyrl-tx/tx/models/qwen3.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 0fa3380fc..52c5846a3 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -308,8 +308,9 @@ def body_fn(hs, i): final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) if output_hidden_states: - # all_hs is [num_layers, batch, seq, hidden], convert to list and prepend input - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers)] + # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since + # it gets normed and appended in __call__ (matching non-checkpointed path). + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] else: all_hidden_states = [] diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 98dcecdd4..11003161d 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -423,8 +423,9 @@ def body_fn(hs, i): final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) if output_hidden_states: - # all_hs is [num_layers, batch, seq, hidden], convert to list and prepend input - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers)] + # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since + # it gets normed and appended in __call__ (matching non-checkpointed path). + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] else: all_hidden_states = [] From 94a5a56dc5d0c8cb415c8ddd8d24ee8ec19b4f2d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 15:46:25 -0800 Subject: [PATCH 014/133] test: add gradient checkpointing tests - test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match - test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases) --- skyrl-tx/tests/models/test_models_common.py | 133 ++++++++++++++++++++ skyrl-tx/tests/tinker/test_jax_backend.py | 11 +- 2 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 skyrl-tx/tests/models/test_models_common.py diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py new file mode 100644 index 000000000..277a791b8 --- /dev/null +++ b/skyrl-tx/tests/models/test_models_common.py @@ -0,0 +1,133 @@ +"""Common tests for Llama3 and Qwen3 models.""" + +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from transformers import PretrainedConfig + +from tx.models.configs import Llama3Config, Qwen3Config +from tx.models.llama3 import Llama3ForCausalLM +from tx.models.qwen3 import Qwen3ForCausalLM + + +def make_small_config(config_class, gradient_checkpointing=False, num_hidden_layers=2): + """Create a minimal config for fast testing.""" + base_config = PretrainedConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=num_hidden_layers, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=1000, + max_position_embeddings=128, + rms_norm_eps=1e-6, + ) + return config_class( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=False, + gradient_checkpointing=gradient_checkpointing, + ) + + +@pytest.fixture +def input_batch(): + """Common test inputs.""" + batch_size, seq_len = 2, 16 + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, 1000) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + return input_ids, attention_mask + + +@pytest.mark.parametrize("model_class,config_class", [ + (Llama3ForCausalLM, Llama3Config), + (Qwen3ForCausalLM, Qwen3Config), +]) +class TestGradientCheckpointing: + + def test_output_matches_non_checkpointed(self, model_class, config_class, input_batch): + """Forward pass should produce identical outputs with/without checkpointing.""" + input_ids, attention_mask = input_batch + + # Create model without checkpointing + config = make_small_config(config_class, gradient_checkpointing=False) + model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + out_no_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + + # Enable checkpointing + config.gradient_checkpointing = True + out_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + + np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-5) + + def test_hidden_states_length_matches(self, model_class, config_class, input_batch): + """Both paths should return same number of hidden states.""" + input_ids, attention_mask = input_batch + config = make_small_config(config_class, gradient_checkpointing=False) + model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + + out_no_ckpt = model( + input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True + ) + + config.gradient_checkpointing = True + out_ckpt = model( + input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True + ) + + assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) + assert len(out_ckpt.hidden_states) == config.num_hidden_layers + 1 + + for i, (hs_no_ckpt, hs_ckpt) in enumerate( + zip(out_no_ckpt.hidden_states, out_ckpt.hidden_states) + ): + np.testing.assert_allclose( + hs_no_ckpt, hs_ckpt, rtol=1e-5, err_msg=f"Mismatch at hidden state {i}" + ) + + def test_is_training_false_uses_standard_path(self, model_class, config_class, input_batch): + """is_training=False should use standard path with KV cache support.""" + input_ids, attention_mask = input_batch + config = make_small_config(config_class, gradient_checkpointing=True) + model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + + out = model(input_ids, attention_mask=attention_mask, is_training=False) + + # KV cache should be populated (checkpointed path returns empty) + assert len(out.kv_cache.keys) == config.num_hidden_layers + + def test_single_layer_model(self, model_class, config_class, input_batch): + """Checkpointing should work with single layer.""" + input_ids, attention_mask = input_batch + + config = make_small_config(config_class, gradient_checkpointing=True, num_hidden_layers=1) + model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + + out = model( + input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True + ) + + # [embed, normed_output] + assert len(out.hidden_states) == 2 + + def test_single_layer_output_matches(self, model_class, config_class, input_batch): + """Single layer model outputs should match with/without checkpointing.""" + input_ids, attention_mask = input_batch + + config = make_small_config(config_class, gradient_checkpointing=False, num_hidden_layers=1) + model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + + out_no_ckpt = model( + input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True + ) + + config.gradient_checkpointing = True + out_ckpt = model( + input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True + ) + + np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-5) + assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 2edd9d82b..bc176a464 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -330,9 +330,10 @@ def apply_step(request_id: int, model_id: str, optim_input: OptimStepInput) -> f def test_gradient_checkpointing(): """ - Verify gradient checkpointing doesn't affect loss values. + Verify gradient checkpointing doesn't affect loss values or gradients. """ losses = [] + grads_list = [] for use_gradient_checkpointing in (False, True): config = JaxBackendConfig( max_lora_adapters=1, @@ -354,8 +355,8 @@ def test_gradient_checkpointing(): sampling_logprobs = jnp.zeros((B, T), dtype=jnp.float32) advantages = jnp.zeros((B, T), dtype=jnp.float32) - # Compute loss, using gradient checkpointing if enabled - _, per_token_losses, _ = backend._forward_backward_and_accumulate( + # Compute loss and gradients, using gradient checkpointing if enabled + accumulated_grads, per_token_losses, _ = backend._forward_backward_and_accumulate( backend.accumulated_grads, backend.lora_params, backend.non_lora_params, @@ -369,10 +370,14 @@ def test_gradient_checkpointing(): advantages, ) losses.append(float(per_token_losses.mean())) + grads_list.append(accumulated_grads.grad_sum) # Check relative difference between losses is small assert abs(losses[0] - losses[1]) / abs(losses[0]) < 5e-3 + # Check gradients match + _assert_tree_allclose(grads_list[0], grads_list[1], rtol=1e-3, atol=1e-3, min_match_pct=99.0) + def make_sample_input(tokens: list[int], prompt_logprobs: bool = False, max_tokens: int = 16) -> types.SampleInput: """Build a SampleInput for testing.""" From 9ec6b17524512b2cf987a0b8d7c73ebd3b5604b2 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 16:20:54 -0800 Subject: [PATCH 015/133] fix --- skyrl-tx/tests/models/test_models_common.py | 137 ++++++++------------ 1 file changed, 52 insertions(+), 85 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 277a791b8..64a7f622c 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -1,74 +1,71 @@ -"""Common tests for Llama3 and Qwen3 models.""" +"""Common tests for gradient checkpointing.""" from flax import nnx import jax import jax.numpy as jnp import numpy as np import pytest -from transformers import PretrainedConfig +from transformers import AutoConfig, PretrainedConfig from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM -def make_small_config(config_class, gradient_checkpointing=False, num_hidden_layers=2): - """Create a minimal config for fast testing.""" - base_config = PretrainedConfig( - hidden_size=64, - intermediate_size=128, - num_hidden_layers=num_hidden_layers, - num_attention_heads=2, - num_key_value_heads=2, - vocab_size=1000, - max_position_embeddings=128, - rms_norm_eps=1e-6, - ) - return config_class( - base_config, - max_lora_adapters=1, - max_lora_rank=1, - shard_attention_heads=False, - gradient_checkpointing=gradient_checkpointing, - ) - - -@pytest.fixture -def input_batch(): - """Common test inputs.""" - batch_size, seq_len = 2, 16 - input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, 1000) - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - return input_ids, attention_mask - - -@pytest.mark.parametrize("model_class,config_class", [ - (Llama3ForCausalLM, Llama3Config), - (Qwen3ForCausalLM, Qwen3Config), -]) +QWEN3_MODEL = "Qwen/Qwen3-0.6B" +LLAMA3_MODEL = "unsloth/Llama-3.2-1B" + + +def create_qwen3_model(): + """Create Qwen3 model for testing.""" + base_config = PretrainedConfig.from_pretrained(QWEN3_MODEL) + config = Qwen3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) + mesh = jax.make_mesh((1, 1), ("fsdp", "tp")) + with jax.set_mesh(mesh): + model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + return model, config + + +def create_llama3_model(): + """Create Llama3 model for testing.""" + base_config = AutoConfig.from_pretrained(LLAMA3_MODEL) + config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) + mesh = jax.make_mesh((1, 1), ("dp", "tp")) + with jax.set_mesh(mesh): + model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + return model, config + + +@pytest.mark.parametrize("create_model", [create_qwen3_model, create_llama3_model], ids=["qwen3", "llama3"]) class TestGradientCheckpointing: - def test_output_matches_non_checkpointed(self, model_class, config_class, input_batch): + def test_output_matches_non_checkpointed(self, create_model): """Forward pass should produce identical outputs with/without checkpointing.""" - input_ids, attention_mask = input_batch + model, config = create_model() - # Create model without checkpointing - config = make_small_config(config_class, gradient_checkpointing=False) - model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + batch_size, seq_len = 2, 8 + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + # Run without checkpointing + config.gradient_checkpointing = False out_no_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) - # Enable checkpointing + # Run with checkpointing config.gradient_checkpointing = True out_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) - np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-5) + np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-4, atol=1e-6) - def test_hidden_states_length_matches(self, model_class, config_class, input_batch): + def test_hidden_states_length_matches(self, create_model): """Both paths should return same number of hidden states.""" - input_ids, attention_mask = input_batch - config = make_small_config(config_class, gradient_checkpointing=False) - model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + model, config = create_model() + + batch_size, seq_len = 2, 8 + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + config.gradient_checkpointing = False out_no_ckpt = model( input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True ) @@ -85,49 +82,19 @@ def test_hidden_states_length_matches(self, model_class, config_class, input_bat zip(out_no_ckpt.hidden_states, out_ckpt.hidden_states) ): np.testing.assert_allclose( - hs_no_ckpt, hs_ckpt, rtol=1e-5, err_msg=f"Mismatch at hidden state {i}" + hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" ) - def test_is_training_false_uses_standard_path(self, model_class, config_class, input_batch): + def test_is_training_false_uses_standard_path(self, create_model): """is_training=False should use standard path with KV cache support.""" - input_ids, attention_mask = input_batch - config = make_small_config(config_class, gradient_checkpointing=True) - model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + model, config = create_model() + config.gradient_checkpointing = True + + batch_size, seq_len = 2, 8 + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) out = model(input_ids, attention_mask=attention_mask, is_training=False) # KV cache should be populated (checkpointed path returns empty) assert len(out.kv_cache.keys) == config.num_hidden_layers - - def test_single_layer_model(self, model_class, config_class, input_batch): - """Checkpointing should work with single layer.""" - input_ids, attention_mask = input_batch - - config = make_small_config(config_class, gradient_checkpointing=True, num_hidden_layers=1) - model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) - - out = model( - input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True - ) - - # [embed, normed_output] - assert len(out.hidden_states) == 2 - - def test_single_layer_output_matches(self, model_class, config_class, input_batch): - """Single layer model outputs should match with/without checkpointing.""" - input_ids, attention_mask = input_batch - - config = make_small_config(config_class, gradient_checkpointing=False, num_hidden_layers=1) - model = model_class(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) - - out_no_ckpt = model( - input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True - ) - - config.gradient_checkpointing = True - out_ckpt = model( - input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True - ) - - np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-5) - assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) From f3cda4fba63aaa62c1882744a866f6f0fd013fd3 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 16:35:34 -0800 Subject: [PATCH 016/133] lint --- skyrl-tx/tests/models/test_models_common.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 64a7f622c..eb792a4a6 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -66,21 +66,15 @@ def test_hidden_states_length_matches(self, create_model): attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) config.gradient_checkpointing = False - out_no_ckpt = model( - input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True - ) + out_no_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True) config.gradient_checkpointing = True - out_ckpt = model( - input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True - ) + out_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True) assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) assert len(out_ckpt.hidden_states) == config.num_hidden_layers + 1 - for i, (hs_no_ckpt, hs_ckpt) in enumerate( - zip(out_no_ckpt.hidden_states, out_ckpt.hidden_states) - ): + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(out_no_ckpt.hidden_states, out_ckpt.hidden_states)): np.testing.assert_allclose( hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" ) From 5cf1c666b2315564e3d747c29606a8d556558687 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 20 Jan 2026 16:45:22 -0800 Subject: [PATCH 017/133] fix: add guard for empty layers in checkpointed forward Handle edge case where self.layers is empty to prevent IndexError. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/llama3.py | 2 ++ skyrl-tx/tx/models/qwen3.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 52c5846a3..5b2f06daa 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -291,6 +291,8 @@ def _forward_layers_checkpointed( Currently we have both self.layers (original) and stacked copy during forward. """ num_layers = len(self.layers) + if num_layers == 0: + return hidden_states, [] # Stack layer weights for dynamic indexing in scan layer_graphdef, _ = nnx.split(self.layers[0]) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 11003161d..dd37acf6c 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -406,6 +406,8 @@ def _forward_layers_checkpointed( Currently we have both self.layers (original) and stacked copy during forward. """ num_layers = len(self.layers) + if num_layers == 0: + return hidden_states, [] # Stack layer weights for dynamic indexing in scan layer_graphdef, _ = nnx.split(self.layers[0]) From cb0e72e50e1f85ca40fd615d12410183a1eec0fc Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 13:35:46 -0800 Subject: [PATCH 018/133] Unify logprobs computation in LogitsProcessor - Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths - Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers - Simplify jax.py to single compute_logprobs call --- skyrl-tx/tx/layers/logits_processor.py | 58 +++++++++++++++++++------- skyrl-tx/tx/tinker/backends/jax.py | 19 ++++----- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/skyrl-tx/tx/layers/logits_processor.py b/skyrl-tx/tx/layers/logits_processor.py index 601ea7c2f..0a837555b 100644 --- a/skyrl-tx/tx/layers/logits_processor.py +++ b/skyrl-tx/tx/layers/logits_processor.py @@ -1,11 +1,11 @@ -"""LogitsProcessor for computing logits from hidden states.""" +"""LogitsProcessor for computing logits and logprobs from hidden states.""" import jax import jax.numpy as jnp class LogitsProcessor: - """Computes logits from hidden states using lm_head.""" + """Handles logits and log probability computation from hidden states.""" def __init__(self, config) -> None: self.config = config @@ -17,7 +17,7 @@ def __call__( adapter_indices: jax.Array | None = None, skip_prompt_logits: bool = False, ) -> jax.Array: - """Compute logits from hidden states. + """Compute logits from hidden states (for sampling). Args: hidden_states: Hidden states from the model backbone. @@ -30,28 +30,58 @@ def __call__( return lm_head(hidden_states, adapter_indices) @staticmethod - def compute_chunked_logprobs( - hidden_states: jax.Array, - lm_head_weight: jax.Array, + def compute_logprobs( + forward_output: jax.Array, target_ids: jax.Array, - chunk_size: int, + lm_head_weight: jax.Array | None = None, + chunk_size: int = 0, gradient_checkpointing: bool = False, ) -> jax.Array: - """Compute log probabilities using chunked lm_head computation. + """Compute log probabilities from model forward output. - This avoids materializing the full [B*T, V] logits tensor by computing - lm_head and log probabilities for each chunk sequentially. + Supports two modes: + - Chunked: forward_output is hidden_states [B, T, H], requires lm_head_weight + - Non-chunked: forward_output is logits [B, T, V] Args: - hidden_states: Hidden states from the model backbone [B, T, H]. - lm_head_weight: Language model head weight matrix [H, V]. + forward_output: Either hidden_states [B, T, H] (chunked) or logits [B, T, V]. target_ids: Target token IDs [B, T]. - chunk_size: Number of tokens to process per chunk. - gradient_checkpointing: Whether to checkpoint each chunk for memory savings. + lm_head_weight: LM head weight matrix [H, V] for chunked mode (None for non-chunked). + chunk_size: Chunk size for chunked computation (0 or negative = non-chunked). + gradient_checkpointing: Whether to checkpoint each chunk (chunked mode only). Returns: Log probabilities for target tokens [B, T]. """ + use_chunked = lm_head_weight is not None and chunk_size > 0 + + if use_chunked: + return LogitsProcessor._compute_chunked_logprobs( + forward_output, lm_head_weight, target_ids, chunk_size, gradient_checkpointing + ) + else: + return LogitsProcessor._logits_to_logprobs(forward_output, target_ids) + + @staticmethod + def _logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: + """Convert logits to log probabilities for target tokens.""" + log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + @staticmethod + def _compute_chunked_logprobs( + hidden_states: jax.Array, + lm_head_weight: jax.Array, + target_ids: jax.Array, + chunk_size: int, + gradient_checkpointing: bool, + ) -> jax.Array: + """Compute log probabilities using chunked lm_head computation. + + This avoids materializing the full [B*T, V] logits tensor by computing + lm_head and log probabilities for each chunk sequentially. + """ B, T, H = hidden_states.shape total_tokens = B * T diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 3511b9399..455b90af9 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -285,18 +285,13 @@ def loss_for_lora( self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices ) - if use_chunked: - # Chunked cross-entropy using LogitsProcessor - hidden_states = forward_out # [B, T, H] - target_logprobs = LogitsProcessor.compute_chunked_logprobs( - hidden_states, lm_head_weight, target_ids, loss_chunk_size, gradient_checkpointing - ) - else: - # Non-chunked: use pre-computed logits (with LoRA applied to lm_head) - logits = forward_out # [B, T, V] - log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) - target_logprobs = (target_logits - log_sum_exp).squeeze(-1) + target_logprobs = LogitsProcessor.compute_logprobs( + forward_out, + target_ids, + lm_head_weight if use_chunked else None, + loss_chunk_size if use_chunked else 0, + gradient_checkpointing, + ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): return jax.lax.switch( From dc6f2a48fd554927477c36de8be883645258bc72 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 14:32:52 -0800 Subject: [PATCH 019/133] fix: restore skip_prompt_logits parameter (separate from skip_logits) --- skyrl-tx/tx/models/llama3.py | 3 ++- skyrl-tx/tx/models/qwen3.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 2e1cae901..55c7c76eb 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -307,6 +307,7 @@ def __call__( adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, skip_logits: bool = False, + skip_prompt_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -324,7 +325,7 @@ def __call__( # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) logits = None else: - logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices) + logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices, skip_prompt_logits) return CausalLMOutput( logits=logits, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 574506b15..126b9e55b 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -422,6 +422,7 @@ def __call__( adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, skip_logits: bool = False, + skip_prompt_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -439,7 +440,7 @@ def __call__( # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) logits = None else: - logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices) + logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices, skip_prompt_logits) return CausalLMOutput( logits=logits, From 1e4b246055eaa35f522cccf4af47adab4597640e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 15:49:05 -0800 Subject: [PATCH 020/133] docs: add LogitsProcessor design document --- skyrl-tx/docs/design/logits_processor.md | 199 +++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 skyrl-tx/docs/design/logits_processor.md diff --git a/skyrl-tx/docs/design/logits_processor.md b/skyrl-tx/docs/design/logits_processor.md new file mode 100644 index 000000000..e82a37a3d --- /dev/null +++ b/skyrl-tx/docs/design/logits_processor.md @@ -0,0 +1,199 @@ +# LogitsProcessor Design + +## Overview + +This document proposes a design for `LogitsProcessor` - a utility for computing logits and log probabilities from model hidden states. + +## Background + +In causal language models, the forward pass produces hidden states `[B, T, H]` which must be projected to vocabulary logits `[B, T, V]` via the `lm_head` layer. Different scenarios have different requirements: + +### Training + +Compute logprobs for all positions to calculate loss. + +``` +hidden_states [B, T, H] → logprobs [B, T] → loss +``` + +Full logits `[B, T, V]` are not needed - we only need logprobs of target tokens. This enables **chunked computation**: process tokens in chunks, compute logits and extract logprobs per chunk, avoiding full `[B*T, V]` materialization. + +### Inference: Prefill + +Process the prompt. Return logits for the last position (to start decoding). Optionally return logprobs of prompt tokens. + +``` +hidden_states [B, T, H] → logits [B, 1, V] (last position, for sampling) + → logprobs [B, T-1] (optional, for prompt logprobs) +``` + +For prompt logprobs, same as training - full logits not needed, can use chunked computation. + +### Inference: Decode + +Generate one token at a time. + +1. **Compute logits:** `hidden_states [B, 1, H] → logits [B, 1, V]` +2. **Apply sampling transforms:** temperature scaling, top_k filtering, top_p filtering on logits +3. **Sample:** draw next_token from the transformed distribution +4. **Extract logprob:** get log probability of the sampled token from original logits + +**Full logits required** because step 2 operates on the full vocabulary distribution. + +## Existing Designs + +### SGLang + +**Pattern:** LogitsProcessor as a model attribute, called inside `model.forward()`. + +**Key files:** +- [LogitsProcessor class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/logits_processor.py#L235) +- [LlamaForCausalLM.forward()](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L499) calls [logits_processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L522) + +```python +class LlamaForCausalLM(nn.Module): + def __init__(self, ...): + self.logits_processor = LogitsProcessor(config) + + def forward(self, input_ids, positions, forward_batch, ...) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, ...) + return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch, ...) +``` + +**Problems:** + +1. **Wrapper pattern:** `forward()` just returns `logits_processor(...)` output. No encapsulation benefit. + +2. **Inconsistent return types:** `forward()` returns [different types](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L520-L532) based on runtime conditions (LogitsProcessorOutput, PoolerOutput, or Tensor). + +3. **God object:** [LogitsProcessor.forward()](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/logits_processor.py#L379) is 500+ lines handling many modes through complex branching. + +### vLLM + +**Pattern:** LogitsProcessor as a model attribute, called via separate `compute_logits()` method. + +**Key files:** +- [LogitsProcessor class](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/logits_processor.py#L18) +- [LlamaForCausalLM.compute_logits()](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L640) +- [model_runner calls compute_logits()](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L3336) + +```python +class LlamaForCausalLM(nn.Module): + def __init__(self, ...): + self.logits_processor = LogitsProcessor(vocab_size, scale=logit_scale) + + def forward(self, input_ids, positions, ...) -> Tensor: + return self.model(input_ids, positions, ...) # returns hidden_states + + def compute_logits(self, hidden_states) -> Tensor: + return self.logits_processor(self.lm_head, hidden_states) +``` + +**Improvements over SGLang:** +- `forward()` has single responsibility (returns hidden_states) +- Logits computation is explicit via separate method + +**Remaining Problems:** + +1. **Still a wrapper:** `compute_logits()` just wraps `self.logits_processor(...)`. + +2. **Unnecessary model attribute:** `logits_processor` stores minimal state. Could be a static utility. + +3. **No logprobs support:** Only computes logits. Logprobs computation happens elsewhere. + +## Proposed Design + +### Principles + +1. **Standalone utility** - Not a model attribute +2. **Model returns hidden_states** - Single responsibility, consistent return type +3. **Caller decides what to compute** - Logits for sampling, logprobs for training +4. **Unified logprobs API** - Same method for training and prompt logprobs + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Caller │ +│ (JaxBackend for training, Generator for sampling) │ +└─────────────────────────────────────────────────────────────────┘ + │ │ + │ model(input_ids, ...) │ LogitsProcessor.*() + ▼ ▼ +┌───────────────────────────┐ ┌───────────────────────────────┐ +│ CausalLM Model │ │ LogitsProcessor │ +│ │ │ │ +│ forward() → hidden_states│ │ compute_logits() │ +│ lm_head property │ │ compute_logprobs() │ +└───────────────────────────┘ │ logits_to_logprobs() │ + └───────────────────────────────┘ +``` + +### API + +```python +class LogitsProcessor: + """Utility for computing logits and logprobs from hidden states.""" + + @staticmethod + def compute_logits(hidden_states, lm_head, adapter_indices=None) -> jax.Array: + """Compute logits from hidden states. For sampling.""" + + @staticmethod + def compute_logprobs(hidden_states, lm_head, target_ids, adapter_indices=None, + chunk_size=0, gradient_checkpointing=False) -> jax.Array: + """Compute logprobs from hidden states. For training and prompt logprobs. + + Supports chunked computation to avoid materializing full [B*T, V] logits. + """ + + @staticmethod + def logits_to_logprobs(logits, target_ids) -> jax.Array: + """Convert logits to logprobs. For decode logprobs when logits already computed.""" +``` + +### Usage + +**Training:** +```python +output = model(input_ids, attention_mask=attention_mask, ...) +logprobs = LogitsProcessor.compute_logprobs( + output.last_hidden_state, model.lm_head, target_ids, + chunk_size=1024, gradient_checkpointing=True +) +loss = compute_loss(logprobs, ...) +``` + +**Sampling (prompt logprobs):** +```python +output = model(input_ids, attention_mask=attention_mask, ...) +prompt_logprobs = LogitsProcessor.compute_logprobs( + output.last_hidden_state, model.lm_head, input_ids[:, 1:], + chunk_size=1024 +) +``` + +**Sampling (decode):** +```python +output = model(next_token, kv_cache=kv_cache, ...) +logits = LogitsProcessor.compute_logits(output.last_hidden_state, model.lm_head) +next_token = sample(logits, temperature, top_k, top_p) +logprob = LogitsProcessor.logits_to_logprobs(logits, next_token) +``` + +### Benefits + +1. **Separation of concerns** - Model produces hidden states, LogitsProcessor transforms them +2. **Consistent model interface** - forward() always returns hidden_states +3. **Unified logprobs** - Same API for training and prompt logprobs +4. **Reduced code duplication** - Currently, logprobs computation is duplicated in `generator.py` (`compute_prompt_logprobs`) and `jax.py` backend (chunked loss). This design consolidates both into `LogitsProcessor.compute_logprobs()` +5. **Testable** - Easy to unit test with mock inputs + +### Migration Path + +1. Update `LogitsProcessor` to standalone utility with three methods +2. Update model to return hidden_states only (remove `skip_logits`, `skip_prompt_logits` flags) +3. Update generator to use `LogitsProcessor.compute_logits()` and `compute_logprobs()` +4. Update backend to use `LogitsProcessor.compute_logprobs()` +5. Remove `logits_processor` attribute from model classes +6. Simplify `CausalLMOutput` (remove `logits`, `lm_head` fields) From 5e2d93731260b1cde9c9511a8dc20a94ff40a544 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 17:49:26 -0800 Subject: [PATCH 021/133] refactor: implement LogitsProcessor design - LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs() - Model forward() returns only hidden_states (removed logits computation) - Simplified CausalLMOutput: removed logits and lm_head fields - Generator uses LogitsProcessor for all logits/logprobs computation - Backend uses LogitsProcessor.compute_logprobs() with chunking - Updated tests to use new LogitsProcessor API Co-Authored-By: Claude Opus 4.5 --- .../tests/models/test_llama3_lora_training.py | 3 +- skyrl-tx/tests/models/test_models_common.py | 28 ++++----- skyrl-tx/tests/models/test_qwen3.py | 6 +- .../tests/models/test_qwen3_lora_training.py | 3 +- skyrl-tx/tests/utils/test_generator.py | 32 +++++++--- skyrl-tx/tx/layers/logits_processor.py | 61 ++++++++++--------- skyrl-tx/tx/models/llama3.py | 12 ---- skyrl-tx/tx/models/qwen3.py | 12 ---- skyrl-tx/tx/models/types.py | 4 -- skyrl-tx/tx/tinker/backends/jax.py | 19 +++--- 10 files changed, 82 insertions(+), 98 deletions(-) diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index 012878af2..d01cbfc00 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -7,6 +7,7 @@ from tx.models.configs import Llama3Config from tx.models.llama3 import Llama3ForCausalLM +from tx.layers.logits_processor import LogitsProcessor from tx.utils.models import get_dtype, load_safetensors from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig @@ -38,7 +39,7 @@ def test_lora_training(): def loss_fn(model, input_ids, target_ids, attention_mask): outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - logits = outputs.logits + logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() # Compute gradients - we need to use nnx.split to separate parameters diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 28b710366..2707f464b 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -7,6 +7,7 @@ import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from tx.layers.logits_processor import LogitsProcessor from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM @@ -22,8 +23,8 @@ ], ids=["llama3", "qwen3"], ) -def test_skip_prompt_logits(model_name, config_cls, model_cls, mesh_axes): - """Test that skip_prompt_logits returns correct shape and values.""" +def test_logits_processor(model_name, config_cls, model_cls, mesh_axes): + """Test that LogitsProcessor computes correct logits and logprobs.""" tokenizer = AutoTokenizer.from_pretrained(model_name) hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) @@ -41,22 +42,19 @@ def test_skip_prompt_logits(model_name, config_cls, model_cls, mesh_axes): model = model_cls(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) - # Get full logits - outputs_full = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - assert outputs_full.logits.shape == (batch_size, seq_len, config.vocab_size) + # Get hidden states from model + outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - # Get last token logits only - outputs_last = model( - batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), skip_prompt_logits=True - ) - assert outputs_last.logits.shape == ( - batch_size, - 1, - config.vocab_size, - ), f"Expected shape ({batch_size}, 1, {config.vocab_size}), got {outputs_last.logits.shape}" + # Compute full logits using LogitsProcessor + full_logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head) + assert full_logits.shape == (batch_size, seq_len, config.vocab_size) + + # Compute last token logits only + last_logits = LogitsProcessor.compute_logits(outputs.last_hidden_state[:, -1:, :], model.lm_head) + assert last_logits.shape == (batch_size, 1, config.vocab_size) # Last token logits should match - assert np.allclose(outputs_full.logits[:, -1:, :], outputs_last.logits, rtol=1e-5, atol=1e-5) + assert np.allclose(full_logits[:, -1:, :], last_logits, rtol=1e-5, atol=1e-5) # Test generation equivalence with and without prompt_logprobs input_ids = jnp.array(batch.input_ids.numpy()) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index c450efbf8..653a31539 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -11,6 +11,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock as HFQwen3MoeSparseMoeBlock +from tx.layers.logits_processor import LogitsProcessor from tx.layers.lora import LoRAMixin from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM, Qwen3MoeSparseMoeBlock @@ -272,6 +273,9 @@ def test_qwen3_lora(): adapter_indices=adapter_indices, ) + # Compute logits using LogitsProcessor + logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) + # Compare outputs with corresponding adapters for idx in range(len(lora_adapters)): - assert np.allclose(hf_outputs_list[idx].logits[0], outputs.logits[idx], rtol=1e-3, atol=1e-3) + assert np.allclose(hf_outputs_list[idx].logits[0], logits[idx], rtol=1e-3, atol=1e-3) diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index 88d41f433..5eb84e5ac 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -7,6 +7,7 @@ from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM +from tx.layers.logits_processor import LogitsProcessor from tx.utils.models import get_dtype, load_safetensors from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig @@ -38,7 +39,7 @@ def test_lora_training(): def loss_fn(model, input_ids, target_ids, attention_mask): outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - logits = outputs.logits + logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() # Compute gradients - we need to use nnx.split to separate parameters diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 4525463e4..4c7ad6e8b 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -5,9 +5,22 @@ from tx.utils.generator import GenerateOutput, GeneratorMixin, KVCache, apply_top_k_batch, apply_top_p_batch +class DummyLMHead: + """Dummy lm_head that acts as identity (hidden_states are already logits).""" + + def __call__(self, hidden_states, adapter_indices=None): + return hidden_states + + class DummyModel(GeneratorMixin, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size + self.lm_head = DummyLMHead() + self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) + + @property + def lm_head_weight(self): + return self._lm_head_weight def __call__( self, @@ -16,27 +29,26 @@ def __call__( positions=None, kv_cache=None, adapter_indices=None, - skip_prompt_logits=False, ): - """Simple dummy model for testing generator behavior.""" + """Simple dummy model for testing generator behavior. + + In this dummy model, hidden_states directly equal logits (lm_head is identity). + """ batch_size, seq_len = input_ids.shape base = jnp.arange(self.vocab_size, dtype=jnp.float32) if kv_cache is None: - # Prefill: deterministic logits - logits = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) - # Only return last token logits if requested (saves memory during prefill) - if skip_prompt_logits: - logits = logits[:, -1:, :] + # Prefill: deterministic hidden_states (which equal logits through identity lm_head) + hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] kv_cache = KVCache(keys=keys, values=values, cache_position=seq_len) else: - # Step: logits vary with cache_position - logits = jnp.tile(base[None, None, :] + kv_cache.cache_position, (batch_size, 1, 1)) + # Step: hidden_states vary with cache_position + hidden_states = jnp.tile(base[None, None, :] + kv_cache.cache_position, (batch_size, 1, 1)) kv_cache = KVCache(keys=kv_cache.keys, values=kv_cache.values, cache_position=kv_cache.cache_position + 1) - return CausalLMOutput(logits=logits, last_hidden_state=logits, kv_cache=kv_cache) + return CausalLMOutput(last_hidden_state=hidden_states, kv_cache=kv_cache) def make_inputs(batch_size: int, prompt_length: int): diff --git a/skyrl-tx/tx/layers/logits_processor.py b/skyrl-tx/tx/layers/logits_processor.py index 0a837555b..7555e9871 100644 --- a/skyrl-tx/tx/layers/logits_processor.py +++ b/skyrl-tx/tx/layers/logits_processor.py @@ -5,66 +5,67 @@ class LogitsProcessor: - """Handles logits and log probability computation from hidden states.""" + """Utility for computing logits and logprobs from hidden states.""" - def __init__(self, config) -> None: - self.config = config - - def __call__( - self, + @staticmethod + def compute_logits( hidden_states: jax.Array, lm_head, adapter_indices: jax.Array | None = None, - skip_prompt_logits: bool = False, ) -> jax.Array: - """Compute logits from hidden states (for sampling). + """Compute logits from hidden states. For sampling. Args: - hidden_states: Hidden states from the model backbone. - lm_head: Language model head (LoRALinear or embed_tokens.T). + hidden_states: Hidden states from the model backbone [B, T, H]. + lm_head: Language model head (LoRALinear or transposed embedding). adapter_indices: Optional adapter indices for LoRA. - skip_prompt_logits: If True, only compute logits for the last token (saves memory). + + Returns: + Logits [B, T, V]. """ - if skip_prompt_logits: - hidden_states = hidden_states[:, -1:, :] return lm_head(hidden_states, adapter_indices) @staticmethod def compute_logprobs( - forward_output: jax.Array, + hidden_states: jax.Array, + lm_head_weight: jax.Array, target_ids: jax.Array, - lm_head_weight: jax.Array | None = None, chunk_size: int = 0, gradient_checkpointing: bool = False, ) -> jax.Array: - """Compute log probabilities from model forward output. + """Compute logprobs from hidden states. For training and prompt logprobs. - Supports two modes: - - Chunked: forward_output is hidden_states [B, T, H], requires lm_head_weight - - Non-chunked: forward_output is logits [B, T, V] + Supports chunked computation to avoid materializing full [B*T, V] logits. Args: - forward_output: Either hidden_states [B, T, H] (chunked) or logits [B, T, V]. + hidden_states: Hidden states [B, T, H]. + lm_head_weight: LM head weight matrix [H, V]. target_ids: Target token IDs [B, T]. - lm_head_weight: LM head weight matrix [H, V] for chunked mode (None for non-chunked). - chunk_size: Chunk size for chunked computation (0 or negative = non-chunked). - gradient_checkpointing: Whether to checkpoint each chunk (chunked mode only). + chunk_size: Chunk size for chunked computation (0 = non-chunked). + gradient_checkpointing: Whether to checkpoint each chunk. Returns: Log probabilities for target tokens [B, T]. """ - use_chunked = lm_head_weight is not None and chunk_size > 0 - - if use_chunked: + if chunk_size > 0: return LogitsProcessor._compute_chunked_logprobs( - forward_output, lm_head_weight, target_ids, chunk_size, gradient_checkpointing + hidden_states, lm_head_weight, target_ids, chunk_size, gradient_checkpointing ) else: - return LogitsProcessor._logits_to_logprobs(forward_output, target_ids) + logits = hidden_states @ lm_head_weight + return LogitsProcessor.logits_to_logprobs(logits, target_ids) @staticmethod - def _logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: - """Convert logits to log probabilities for target tokens.""" + def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: + """Convert logits to logprobs. For decode logprobs when logits already computed. + + Args: + logits: Logits [B, T, V] or [B, V]. + target_ids: Target token IDs [B, T] or [B]. + + Returns: + Log probabilities for target tokens [B, T] or [B]. + """ log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 55c7c76eb..d838ffc97 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -5,7 +5,6 @@ from transformers import LlamaConfig from tx.layers.lora import LoRAEmbed, LoRALinear -from tx.layers.logits_processor import LogitsProcessor from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm from tx.models.types import CausalLMOutput, ModelOutput @@ -282,7 +281,6 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - self.logits_processor = LogitsProcessor(config) @staticmethod def is_lora_param(path: tuple, _value) -> bool: @@ -306,8 +304,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - skip_logits: bool = False, - skip_prompt_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -321,18 +317,10 @@ def __call__( kv_cache=kv_cache, ) - if skip_logits: - # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) - logits = None - else: - logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices, skip_prompt_logits) - return CausalLMOutput( - logits=logits, last_hidden_state=outputs.last_hidden_state, kv_cache=outputs.kv_cache, hidden_states=outputs.hidden_states, - lm_head=self.lm_head_weight, ) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 126b9e55b..bc58bea83 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -4,7 +4,6 @@ from jax.sharding import get_abstract_mesh from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear -from tx.layers.logits_processor import LogitsProcessor from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope from tx.models.configs import Qwen3Config @@ -397,7 +396,6 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - self.logits_processor = LogitsProcessor(config) @staticmethod def is_lora_param(path: tuple, _value) -> bool: @@ -421,8 +419,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - skip_logits: bool = False, - skip_prompt_logits: bool = False, ) -> CausalLMOutput: if positions is None: positions = compute_positions(attention_mask) @@ -436,18 +432,10 @@ def __call__( kv_cache=kv_cache, ) - if skip_logits: - # Skip logits computation for chunked cross-entropy (uses lm_head weight directly) - logits = None - else: - logits = self.logits_processor(outputs.last_hidden_state, self.lm_head, adapter_indices, skip_prompt_logits) - return CausalLMOutput( - logits=logits, last_hidden_state=outputs.last_hidden_state, kv_cache=outputs.kv_cache, hidden_states=outputs.hidden_states, - lm_head=self.lm_head_weight, ) diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index ab9a32723..be60f6ec9 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -36,15 +36,11 @@ class CausalLMOutput: """Output type for causal language models like Qwen3ForCausalLM. Attributes: - logits: The language modeling logits (None if skip_logits=True). last_hidden_state: The last hidden state from the model. kv_cache: The updated key-value cache. hidden_states: All hidden states, if output_hidden_states=True. - lm_head: The lm_head weight [H, V] for external logits computation. """ - logits: jax.Array | None last_hidden_state: jax.Array kv_cache: KVCache hidden_states: list[jax.Array] | None = None - lm_head: jax.Array | None = None diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 455b90af9..7abcebfb8 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -239,8 +239,7 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" - use_chunked = self._use_chunked_loss - loss_chunk_size = self.config.loss_chunk_size + loss_chunk_size = self.config.loss_chunk_size if self._use_chunked_loss else 0 gradient_checkpointing = self.config.gradient_checkpointing def _model_forward( @@ -251,18 +250,14 @@ def _model_forward( attention_mask: jax.Array, adapter_indices: jax.Array, ) -> tuple[jax.Array, jax.Array]: - """Forward pass returning (hidden_states, lm_head) or (logits, None).""" + """Forward pass returning (hidden_states, lm_head_weight).""" model = nnx.merge(graphdef, lora_params, non_lora_params) output = model( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, - skip_logits=use_chunked, ) - if use_chunked: - return output.last_hidden_state, output.lm_head - else: - return output.logits, None + return output.last_hidden_state, model.lm_head_weight if self.config.gradient_checkpointing: # Wrap the model forward call to use jax.checkpoint for gradient checkpointing @@ -281,15 +276,15 @@ def loss_for_lora( sampling_logprobs: jax.Array, advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - forward_out, lm_head_weight = _model_forward( + hidden_states, lm_head_weight = _model_forward( self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices ) target_logprobs = LogitsProcessor.compute_logprobs( - forward_out, + hidden_states, + lm_head_weight, target_ids, - lm_head_weight if use_chunked else None, - loss_chunk_size if use_chunked else 0, + loss_chunk_size, gradient_checkpointing, ) From 7f9a762e61b501bbd3810a6cc7e94c4e491109b7 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 18:28:15 -0800 Subject: [PATCH 022/133] refactor: encapsulate LogitsProcessor in CausalLMBase - Create CausalLMBase class with compute_logits/compute_logprobs methods - Models expose wrapper methods instead of direct LogitsProcessor access - Update generator and jax.py backend to use model methods - LogitsProcessor is now internal implementation detail Co-Authored-By: Claude Opus 4.5 --- .../tests/models/test_llama3_lora_training.py | 3 +- skyrl-tx/tests/models/test_models_common.py | 7 +- skyrl-tx/tests/models/test_qwen3.py | 5 +- .../tests/models/test_qwen3_lora_training.py | 3 +- skyrl-tx/tests/utils/test_generator.py | 32 +++++---- skyrl-tx/tx/models/base.py | 67 +++++++++++++++++++ skyrl-tx/tx/models/llama3.py | 3 +- skyrl-tx/tx/models/qwen3.py | 3 +- skyrl-tx/tx/tinker/backends/jax.py | 28 +++----- skyrl-tx/tx/utils/generator.py | 31 ++++----- 10 files changed, 118 insertions(+), 64 deletions(-) create mode 100644 skyrl-tx/tx/models/base.py diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index d01cbfc00..fb3ecce39 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -7,7 +7,6 @@ from tx.models.configs import Llama3Config from tx.models.llama3 import Llama3ForCausalLM -from tx.layers.logits_processor import LogitsProcessor from tx.utils.models import get_dtype, load_safetensors from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig @@ -39,7 +38,7 @@ def test_lora_training(): def loss_fn(model, input_ids, target_ids, attention_mask): outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() # Compute gradients - we need to use nnx.split to separate parameters diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 2707f464b..247856665 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -7,7 +7,6 @@ import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from tx.layers.logits_processor import LogitsProcessor from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM @@ -45,12 +44,12 @@ def test_logits_processor(model_name, config_cls, model_cls, mesh_axes): # Get hidden states from model outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - # Compute full logits using LogitsProcessor - full_logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head) + # Compute full logits using model.compute_logits + full_logits = model.compute_logits(outputs.last_hidden_state) assert full_logits.shape == (batch_size, seq_len, config.vocab_size) # Compute last token logits only - last_logits = LogitsProcessor.compute_logits(outputs.last_hidden_state[:, -1:, :], model.lm_head) + last_logits = model.compute_logits(outputs.last_hidden_state[:, -1:, :]) assert last_logits.shape == (batch_size, 1, config.vocab_size) # Last token logits should match diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 653a31539..8a3d5d2a7 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -11,7 +11,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock as HFQwen3MoeSparseMoeBlock -from tx.layers.logits_processor import LogitsProcessor from tx.layers.lora import LoRAMixin from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM, Qwen3MoeSparseMoeBlock @@ -273,8 +272,8 @@ def test_qwen3_lora(): adapter_indices=adapter_indices, ) - # Compute logits using LogitsProcessor - logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) + # Compute logits using model.compute_logits + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) # Compare outputs with corresponding adapters for idx in range(len(lora_adapters)): diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index 5eb84e5ac..46bc368d7 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -7,7 +7,6 @@ from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM -from tx.layers.logits_processor import LogitsProcessor from tx.utils.models import get_dtype, load_safetensors from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig @@ -39,7 +38,7 @@ def test_lora_training(): def loss_fn(model, input_ids, target_ids, attention_mask): outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - logits = LogitsProcessor.compute_logits(outputs.last_hidden_state, model.lm_head, adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() # Compute gradients - we need to use nnx.split to separate parameters diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 4c7ad6e8b..e2b973e25 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,26 +1,20 @@ from flax import nnx +import jax import jax.numpy as jnp +from tx.models.base import CausalLMBase from tx.models.types import CausalLMOutput from tx.tinker.types import SamplingParams from tx.utils.generator import GenerateOutput, GeneratorMixin, KVCache, apply_top_k_batch, apply_top_p_batch -class DummyLMHead: - """Dummy lm_head that acts as identity (hidden_states are already logits).""" - - def __call__(self, hidden_states, adapter_indices=None): - return hidden_states +class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): + """Dummy model for testing generator behavior. + In this dummy model, hidden_states directly equal logits (identity transformation). + """ -class DummyModel(GeneratorMixin, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size - self.lm_head = DummyLMHead() - self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) - - @property - def lm_head_weight(self): - return self._lm_head_weight def __call__( self, @@ -30,15 +24,11 @@ def __call__( kv_cache=None, adapter_indices=None, ): - """Simple dummy model for testing generator behavior. - - In this dummy model, hidden_states directly equal logits (lm_head is identity). - """ batch_size, seq_len = input_ids.shape base = jnp.arange(self.vocab_size, dtype=jnp.float32) if kv_cache is None: - # Prefill: deterministic hidden_states (which equal logits through identity lm_head) + # Prefill: deterministic hidden_states (which equal logits) hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] @@ -50,6 +40,14 @@ def __call__( return CausalLMOutput(last_hidden_state=hidden_states, kv_cache=kv_cache) + def compute_logits(self, hidden_states, adapter_indices=None): + """In dummy model, hidden_states are already logits.""" + return hidden_states + + def compute_logprobs(self, hidden_states, target_ids, chunk_size=0, gradient_checkpointing=False): + """Compute logprobs from hidden_states (which are already logits in dummy model).""" + return self.logits_to_logprobs(hidden_states, target_ids) + def make_inputs(batch_size: int, prompt_length: int): input_ids = jnp.tile(jnp.arange(prompt_length, dtype=jnp.int32)[None, :], (batch_size, 1)) diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py new file mode 100644 index 000000000..6d8652189 --- /dev/null +++ b/skyrl-tx/tx/models/base.py @@ -0,0 +1,67 @@ +"""Base class for causal language models.""" + +import jax + +from tx.layers.logits_processor import LogitsProcessor + + +class CausalLMBase: + """Base class providing logits/logprobs computation for causal language models. + + Subclasses must set: + - lm_head: The language model head (callable) + - lm_head_weight: The lm_head weight matrix [H, V] + """ + + def compute_logits( + self, + hidden_states: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + """Compute logits from hidden states. For sampling. + + Args: + hidden_states: Hidden states from model forward [B, T, H]. + adapter_indices: Optional adapter indices for LoRA. + + Returns: + Logits [B, T, V]. + """ + return LogitsProcessor.compute_logits(hidden_states, self.lm_head, adapter_indices) + + def compute_logprobs( + self, + hidden_states: jax.Array, + target_ids: jax.Array, + chunk_size: int = 0, + gradient_checkpointing: bool = False, + ) -> jax.Array: + """Compute logprobs from hidden states. For training and prompt logprobs. + + Supports chunked computation to avoid materializing full [B*T, V] logits. + + Args: + hidden_states: Hidden states [B, T, H]. + target_ids: Target token IDs [B, T]. + chunk_size: Chunk size for chunked computation (0 = non-chunked). + gradient_checkpointing: Whether to checkpoint each chunk. + + Returns: + Log probabilities for target tokens [B, T]. + """ + return LogitsProcessor.compute_logprobs( + hidden_states, self.lm_head_weight, target_ids, chunk_size, gradient_checkpointing + ) + + @staticmethod + def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: + """Convert logits to logprobs. For decode logprobs when logits already computed. + + Args: + logits: Logits [B, T, V] or [B, V]. + target_ids: Target token IDs [B, T] or [B]. + + Returns: + Log probabilities for target tokens [B, T] or [B]. + """ + return LogitsProcessor.logits_to_logprobs(logits, target_ids) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index d838ffc97..3ede0d727 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -7,6 +7,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm +from tx.models.base import CausalLMBase from tx.models.types import CausalLMOutput, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache, compute_positions @@ -261,7 +262,7 @@ def __call__( ) -class Llama3ForCausalLM(nnx.Module, GeneratorMixin): +class Llama3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index bc58bea83..0c1706857 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -6,6 +6,7 @@ from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope +from tx.models.base import CausalLMBase from tx.models.configs import Qwen3Config from tx.layers.layernorm import RMSNorm from tx.models.types import CausalLMOutput, ModelOutput @@ -376,7 +377,7 @@ def __call__( ) -class Qwen3ForCausalLM(nnx.Module, GeneratorMixin): +class Qwen3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 7abcebfb8..df35afebe 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -38,7 +38,6 @@ from tx.models.configs import Qwen3Config from tx.layers.lora import clear_lora_adapter, init_lora_adapter -from tx.layers.logits_processor import LogitsProcessor from tx.tinker import types from tx.tinker.backends.backend import AbstractBackend from tx.tinker.backends.utils import pad, pad_batch, pad_to_fsdp @@ -242,27 +241,30 @@ def _create_loss_and_grad_fn(self): loss_chunk_size = self.config.loss_chunk_size if self._use_chunked_loss else 0 gradient_checkpointing = self.config.gradient_checkpointing - def _model_forward( + def _forward_and_logprobs( graphdef: nnx.GraphDef, lora_params: nnx.State, non_lora_params: nnx.State, input_ids: jax.Array, attention_mask: jax.Array, adapter_indices: jax.Array, - ) -> tuple[jax.Array, jax.Array]: - """Forward pass returning (hidden_states, lm_head_weight).""" + target_ids: jax.Array, + ) -> jax.Array: + """Forward pass and logprobs computation.""" model = nnx.merge(graphdef, lora_params, non_lora_params) output = model( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, ) - return output.last_hidden_state, model.lm_head_weight + return model.compute_logprobs( + output.last_hidden_state, target_ids, loss_chunk_size, gradient_checkpointing + ) if self.config.gradient_checkpointing: - # Wrap the model forward call to use jax.checkpoint for gradient checkpointing + # Wrap the forward + logprobs call to use jax.checkpoint for gradient checkpointing # policy=None corresponds to full activation recomputation - _model_forward = jax.checkpoint(_model_forward, policy=None) + _forward_and_logprobs = jax.checkpoint(_forward_and_logprobs, policy=None) def loss_for_lora( lora_params: nnx.State, @@ -276,16 +278,8 @@ def loss_for_lora( sampling_logprobs: jax.Array, advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - hidden_states, lm_head_weight = _model_forward( - self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices - ) - - target_logprobs = LogitsProcessor.compute_logprobs( - hidden_states, - lm_head_weight, - target_ids, - loss_chunk_size, - gradient_checkpointing, + target_logprobs = _forward_and_logprobs( + self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices, target_ids ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 431396605..cb83a5cbb 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -113,16 +113,6 @@ def find_string_stop_position( return None -def compute_prompt_logprobs(prefill_logits: jax.Array, input_ids: jax.Array) -> jax.Array: - """Compute log probabilities of prompt tokens from prefill logits""" - # TODO: Optimize memory usage by avoiding allocation of full vocab dimension. - logits_for_prompt = prefill_logits[:, :-1, :] - log_probs = jax.nn.log_softmax(logits_for_prompt, axis=-1) - prompt_tokens = input_ids[:, 1:] - prompt_logprobs = jnp.take_along_axis(log_probs, prompt_tokens[..., None], axis=-1).squeeze(-1) - return prompt_logprobs - - class GeneratorMixin: """Adds autoregressive generation with KV caching to causal language models.""" @@ -151,17 +141,23 @@ def _prefill_and_decode( positions = compute_positions(attention_mask) # Prefill: process full prompt - # Use skip_prompt_logits=True when we don't need prompt_logprobs to save memory outputs = model( input_ids, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - skip_prompt_logits=not prompt_logprobs, ) + # Compute logits for last position (needed for sampling first token) + last_logits = model.compute_logits(outputs.last_hidden_state[:, -1:, :], adapter_indices)[:, 0, :] + # Compute prompt logprobs if requested - prompt_logprobs_array = compute_prompt_logprobs(outputs.logits, input_ids) if prompt_logprobs else None + if prompt_logprobs: + prompt_logprobs_array = model.compute_logprobs( + outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:] + ) + else: + prompt_logprobs_array = None # Pad KV cache and attention mask kv_cache = outputs.kv_cache.pad_to_length(max_length) @@ -187,8 +183,7 @@ def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.A ) greedy = jnp.argmax(s.logits, axis=-1) next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None]) - log_probs = jax.nn.log_softmax(s.logits, axis=-1) - sampled_logprob = jnp.take_along_axis(log_probs, next_token, axis=-1) + sampled_logprob = model.logits_to_logprobs(s.logits, next_token[:, 0])[:, None] # Track first stop token position (-1 means not stopped yet) is_stop = jnp.any(next_token == stop_tokens, axis=1) @@ -204,12 +199,14 @@ def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.A kv_cache=s.kv_cache, adapter_indices=adapter_indices, ) + # Compute logits for the next token + next_logits = model.compute_logits(outputs.last_hidden_state, adapter_indices)[:, 0, :] next_state = DecodeState( kv_cache=outputs.kv_cache, rngs=rngs, attention_mask=next_attention_mask, last_positions=s.last_positions + 1, - logits=outputs.logits[:, -1, :], + logits=next_logits, stop_pos=stop_pos, ) return next_state, (next_token, sampled_logprob) @@ -219,7 +216,7 @@ def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.A rngs=rngs, attention_mask=decode_attention_mask, last_positions=positions[:, -1:], - logits=outputs.logits[:, -1, :], + logits=last_logits, stop_pos=jnp.full((input_ids.shape[0],), -1), ) From 6cbe1cbb071514732414952e9a4b659155739d3a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 19:40:55 -0800 Subject: [PATCH 023/133] inline logits processor --- skyrl-tx/tests/models/test_qwen3.py | 1 - skyrl-tx/tests/tinker/test_jax_backend.py | 8 +- skyrl-tx/tests/utils/test_generator.py | 19 ++-- skyrl-tx/tx/layers/logits_processor.py | 122 ---------------------- skyrl-tx/tx/models/base.py | 96 ++++++++++++++--- skyrl-tx/tx/tinker/backends/jax.py | 18 ++-- 6 files changed, 108 insertions(+), 156 deletions(-) delete mode 100644 skyrl-tx/tx/layers/logits_processor.py diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 8a3d5d2a7..cfa57bdd9 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -272,7 +272,6 @@ def test_qwen3_lora(): adapter_indices=adapter_indices, ) - # Compute logits using model.compute_logits logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) # Compare outputs with corresponding adapters diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 2b8d20e9e..9ba11352e 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -619,15 +619,15 @@ def _run_forward(self, backend: JaxBackend, inputs: tuple): ) return losses, logprobs - def test_fallback_on_train_unembed(self): - """Verify backend switches to non-chunked when train_unembed=True.""" + def test_train_unembed_enables_lora_on_lm_head(self): + """Verify backend enables LoRA on lm_head when train_unembed=True.""" backend = self._create_backend(loss_chunk_size=1024) - assert backend._use_chunked_loss is True + assert backend._has_train_unembed is False lora_config = LoraConfig(rank=8, alpha=16, seed=0, train_unembed=True) backend.create_model("model_with_unembed", lora_config) - assert backend._use_chunked_loss is False + assert backend._has_train_unembed is True @pytest.mark.parametrize( "chunk_size,expected", diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index e2b973e25..b7705813a 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -15,6 +15,17 @@ class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size + self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) + + @property + def lm_head(self): + """Identity lm_head - hidden_states are already logits.""" + return lambda hidden_states, adapter_indices=None: hidden_states + + @property + def lm_head_weight(self) -> jax.Array: + """Identity matrix for dummy model.""" + return self._lm_head_weight def __call__( self, @@ -40,14 +51,6 @@ def __call__( return CausalLMOutput(last_hidden_state=hidden_states, kv_cache=kv_cache) - def compute_logits(self, hidden_states, adapter_indices=None): - """In dummy model, hidden_states are already logits.""" - return hidden_states - - def compute_logprobs(self, hidden_states, target_ids, chunk_size=0, gradient_checkpointing=False): - """Compute logprobs from hidden_states (which are already logits in dummy model).""" - return self.logits_to_logprobs(hidden_states, target_ids) - def make_inputs(batch_size: int, prompt_length: int): input_ids = jnp.tile(jnp.arange(prompt_length, dtype=jnp.int32)[None, :], (batch_size, 1)) diff --git a/skyrl-tx/tx/layers/logits_processor.py b/skyrl-tx/tx/layers/logits_processor.py deleted file mode 100644 index 7555e9871..000000000 --- a/skyrl-tx/tx/layers/logits_processor.py +++ /dev/null @@ -1,122 +0,0 @@ -"""LogitsProcessor for computing logits and logprobs from hidden states.""" - -import jax -import jax.numpy as jnp - - -class LogitsProcessor: - """Utility for computing logits and logprobs from hidden states.""" - - @staticmethod - def compute_logits( - hidden_states: jax.Array, - lm_head, - adapter_indices: jax.Array | None = None, - ) -> jax.Array: - """Compute logits from hidden states. For sampling. - - Args: - hidden_states: Hidden states from the model backbone [B, T, H]. - lm_head: Language model head (LoRALinear or transposed embedding). - adapter_indices: Optional adapter indices for LoRA. - - Returns: - Logits [B, T, V]. - """ - return lm_head(hidden_states, adapter_indices) - - @staticmethod - def compute_logprobs( - hidden_states: jax.Array, - lm_head_weight: jax.Array, - target_ids: jax.Array, - chunk_size: int = 0, - gradient_checkpointing: bool = False, - ) -> jax.Array: - """Compute logprobs from hidden states. For training and prompt logprobs. - - Supports chunked computation to avoid materializing full [B*T, V] logits. - - Args: - hidden_states: Hidden states [B, T, H]. - lm_head_weight: LM head weight matrix [H, V]. - target_ids: Target token IDs [B, T]. - chunk_size: Chunk size for chunked computation (0 = non-chunked). - gradient_checkpointing: Whether to checkpoint each chunk. - - Returns: - Log probabilities for target tokens [B, T]. - """ - if chunk_size > 0: - return LogitsProcessor._compute_chunked_logprobs( - hidden_states, lm_head_weight, target_ids, chunk_size, gradient_checkpointing - ) - else: - logits = hidden_states @ lm_head_weight - return LogitsProcessor.logits_to_logprobs(logits, target_ids) - - @staticmethod - def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: - """Convert logits to logprobs. For decode logprobs when logits already computed. - - Args: - logits: Logits [B, T, V] or [B, V]. - target_ids: Target token IDs [B, T] or [B]. - - Returns: - Log probabilities for target tokens [B, T] or [B]. - """ - log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) - return (target_logits - log_sum_exp).squeeze(-1) - - @staticmethod - def _compute_chunked_logprobs( - hidden_states: jax.Array, - lm_head_weight: jax.Array, - target_ids: jax.Array, - chunk_size: int, - gradient_checkpointing: bool, - ) -> jax.Array: - """Compute log probabilities using chunked lm_head computation. - - This avoids materializing the full [B*T, V] logits tensor by computing - lm_head and log probabilities for each chunk sequentially. - """ - B, T, H = hidden_states.shape - total_tokens = B * T - - # Flatten batch and sequence dimensions - flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] - flat_target_ids = target_ids.reshape(-1) # [B*T] - - # Pad to multiple of chunk_size for clean slicing - num_chunks = (total_tokens + chunk_size - 1) // chunk_size - padded_size = num_chunks * chunk_size - pad_amount = padded_size - total_tokens - - if pad_amount > 0: - flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) - flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) - - # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] - chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) - chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) - - def compute_chunk_logprobs(args): - """Compute lm_head and log probabilities for a chunk of tokens.""" - chunk_hidden, chunk_targets = args - # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] - chunk_logits = chunk_hidden @ lm_head_weight - # Compute log probabilities - log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) - return (target_logits - log_sum_exp).squeeze(-1) - - if gradient_checkpointing: - compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) - - # Process chunks sequentially using lax.map (not vmap) to reduce memory - all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) - # Flatten and slice to original size, then reshape to [B, T] - return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py index 6d8652189..613efb6f6 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/models/base.py @@ -1,17 +1,27 @@ """Base class for causal language models.""" +from abc import abstractmethod + import jax +import jax.numpy as jnp -from tx.layers.logits_processor import LogitsProcessor +from tx.layers.lora import LoRALinear class CausalLMBase: - """Base class providing logits/logprobs computation for causal language models. + """Base class providing logits/logprobs computation for causal language models.""" + + @property + @abstractmethod + def lm_head(self) -> LoRALinear: + """Language model head. LoRALinear or transposed LoRAEmbed.""" + ... - Subclasses must set: - - lm_head: The language model head (callable) - - lm_head_weight: The lm_head weight matrix [H, V] - """ + @property + @abstractmethod + def lm_head_weight(self) -> jax.Array: + """LM head weight matrix [H, V] for efficient chunked computation.""" + ... def compute_logits( self, @@ -27,31 +37,38 @@ def compute_logits( Returns: Logits [B, T, V]. """ - return LogitsProcessor.compute_logits(hidden_states, self.lm_head, adapter_indices) + return self.lm_head(hidden_states, adapter_indices) def compute_logprobs( self, hidden_states: jax.Array, target_ids: jax.Array, + adapter_indices: jax.Array | None = None, chunk_size: int = 0, gradient_checkpointing: bool = False, ) -> jax.Array: """Compute logprobs from hidden states. For training and prompt logprobs. - Supports chunked computation to avoid materializing full [B*T, V] logits. - Args: hidden_states: Hidden states [B, T, H]. target_ids: Target token IDs [B, T]. + adapter_indices: Adapter indices for LoRA on lm_head. + Pass when train_unembed=True. Forces non-chunked path. chunk_size: Chunk size for chunked computation (0 = non-chunked). gradient_checkpointing: Whether to checkpoint each chunk. Returns: Log probabilities for target tokens [B, T]. """ - return LogitsProcessor.compute_logprobs( - hidden_states, self.lm_head_weight, target_ids, chunk_size, gradient_checkpointing - ) + # Chunked path doesn't support LoRA on lm_head + use_chunk = chunk_size > 0 and adapter_indices is None + if use_chunk: + return self._compute_chunked_logprobs( + hidden_states, target_ids, chunk_size, gradient_checkpointing + ) + else: + logits = self.compute_logits(hidden_states, adapter_indices) + return self.logits_to_logprobs(logits, target_ids) @staticmethod def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: @@ -64,4 +81,57 @@ def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: Returns: Log probabilities for target tokens [B, T] or [B]. """ - return LogitsProcessor.logits_to_logprobs(logits, target_ids) + log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + def _compute_chunked_logprobs( + self, + hidden_states: jax.Array, + target_ids: jax.Array, + chunk_size: int, + gradient_checkpointing: bool, + ) -> jax.Array: + """Compute log probabilities using chunked lm_head computation. + + This avoids materializing the full [B*T, V] logits tensor by computing + lm_head and log probabilities for each chunk sequentially. + """ + B, T, H = hidden_states.shape + total_tokens = B * T + lm_head_weight = self.lm_head_weight + + # Flatten batch and sequence dimensions + flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] + flat_target_ids = target_ids.reshape(-1) # [B*T] + + # Pad to multiple of chunk_size for clean slicing + num_chunks = (total_tokens + chunk_size - 1) // chunk_size + padded_size = num_chunks * chunk_size + pad_amount = padded_size - total_tokens + + if pad_amount > 0: + flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) + flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + + # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] + chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) + chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) + + def compute_chunk_logprobs(args): + """Compute lm_head and log probabilities for a chunk of tokens.""" + chunk_hidden, chunk_targets = args + # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] + chunk_logits = chunk_hidden @ lm_head_weight + # Compute log probabilities + log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + if gradient_checkpointing: + compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) + + # Process chunks sequentially using lax.map (not vmap) to reduce memory + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) + # Flatten and slice to original size, then reshape to [B, T] + return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index df35afebe..f4224cd26 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -205,10 +205,9 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ) # Use chunked cross-entropy by default for memory efficiency. - # Falls back to non-chunked when: - # - loss_chunk_size <= 0 (disabled via config) - # - any model uses train_unembed=True (chunked path doesn't apply LoRA to lm_head) self._use_chunked_loss = config.loss_chunk_size > 0 + # Track if any model uses train_unembed=True (requires LoRA on lm_head) + self._has_train_unembed = False logger.info(f"Chunked cross-entropy loss: {self._use_chunked_loss} (chunk_size={config.loss_chunk_size})") self._create_loss_and_grad_fn() @@ -240,6 +239,7 @@ def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" loss_chunk_size = self.config.loss_chunk_size if self._use_chunked_loss else 0 gradient_checkpointing = self.config.gradient_checkpointing + has_train_unembed = self._has_train_unembed def _forward_and_logprobs( graphdef: nnx.GraphDef, @@ -257,8 +257,10 @@ def _forward_and_logprobs( attention_mask=attention_mask, adapter_indices=adapter_indices, ) + # Pass adapter_indices when train_unembed=True to apply LoRA on lm_head + lm_head_adapter_indices = adapter_indices if has_train_unembed else None return model.compute_logprobs( - output.last_hidden_state, target_ids, loss_chunk_size, gradient_checkpointing + output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing ) if self.config.gradient_checkpointing: @@ -451,10 +453,10 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") - # Switch to non-chunked loss if train_unembed=True (chunked doesn't apply LoRA to lm_head) - if lora_config.train_unembed and self._use_chunked_loss: - logger.info("Switching to non-chunked loss mode (train_unembed=True requires LoRA on lm_head)") - self._use_chunked_loss = False + # Enable LoRA on lm_head path when train_unembed=True + if lora_config.train_unembed and not self._has_train_unembed: + logger.info("Enabling LoRA on lm_head (train_unembed=True)") + self._has_train_unembed = True self._create_loss_and_grad_fn() # Store model metadata From cd2fd4e5fa79651745a8d600e5ab5743b3b77b18 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 19:55:41 -0800 Subject: [PATCH 024/133] refactor: runtime train_unembed check with per-adapter mask - Replace _has_train_unembed flag with _train_unembed_mask array - Check at runtime if any adapter in batch needs LoRA on lm_head - Use jax.lax.cond to choose chunked vs non-chunked path - Handle adapter reuse correctly (reset mask on delete) - Remove unused _use_chunked_loss flag Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 86 ++++++++++++++++------- skyrl-tx/tx/tinker/backends/jax.py | 36 +++++----- 2 files changed, 78 insertions(+), 44 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 9ba11352e..baeabe438 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -619,29 +619,6 @@ def _run_forward(self, backend: JaxBackend, inputs: tuple): ) return losses, logprobs - def test_train_unembed_enables_lora_on_lm_head(self): - """Verify backend enables LoRA on lm_head when train_unembed=True.""" - backend = self._create_backend(loss_chunk_size=1024) - assert backend._has_train_unembed is False - - lora_config = LoraConfig(rank=8, alpha=16, seed=0, train_unembed=True) - backend.create_model("model_with_unembed", lora_config) - - assert backend._has_train_unembed is True - - @pytest.mark.parametrize( - "chunk_size,expected", - [ - (0, False), # Disabled - (-1, False), # Disabled - (1024, True), # Enabled - ], - ) - def test_use_chunked_loss_config(self, chunk_size, expected): - """Verify _use_chunked_loss is set correctly based on loss_chunk_size.""" - backend = self._create_backend(loss_chunk_size=chunk_size) - assert backend._use_chunked_loss is expected - @pytest.mark.parametrize( "batch_size,seq_len,chunk_size", [ @@ -659,8 +636,8 @@ def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): backend_chunked = self._create_backend(loss_chunk_size=chunk_size) backend_nonchunked = self._create_backend(loss_chunk_size=0) - assert backend_chunked._use_chunked_loss is True - assert backend_nonchunked._use_chunked_loss is False + assert backend_chunked.config.loss_chunk_size > 0 + assert backend_nonchunked.config.loss_chunk_size == 0 inputs = self._create_inputs(backend_chunked, batch_size, seq_len) losses_chunked, logprobs_chunked = self._run_forward(backend_chunked, inputs) @@ -680,3 +657,62 @@ def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): atol=1e-4, err_msg=f"Losses mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", ) + + def test_mixed_train_unembed_adapters(self): + """Test that chunked and non-chunked paths produce same results with mixed adapters.""" + backend_chunked = self._create_backend(loss_chunk_size=1024) + backend_nonchunked = self._create_backend(loss_chunk_size=0) + + # Create same models on both backends + for backend in [backend_chunked, backend_nonchunked]: + backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) + backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) + + normal_idx = backend_chunked.models["model_normal"].adapter_index + unembed_idx = backend_chunked.models["model_unembed"].adapter_index + + batch_size, seq_len = 2, 16 + vocab = backend_chunked.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + + def run_forward(backend, adapter_indices): + _, losses, logprobs = backend._forward( + backend.accumulated_grads, + backend.lora_params, + backend.non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + return losses, logprobs + + # Test with mixed adapters: one normal, one unembed + adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) + losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) + losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", + ) + np.testing.assert_allclose( + np.asarray(losses_chunked), + np.asarray(losses_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", + ) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index f4224cd26..4bc606451 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -204,11 +204,8 @@ def __init__(self, base_model: str, config: JaxBackendConfig): f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" ) - # Use chunked cross-entropy by default for memory efficiency. - self._use_chunked_loss = config.loss_chunk_size > 0 - # Track if any model uses train_unembed=True (requires LoRA on lm_head) - self._has_train_unembed = False - logger.info(f"Chunked cross-entropy loss: {self._use_chunked_loss} (chunk_size={config.loss_chunk_size})") + # Track which adapters use train_unembed=True (requires LoRA on lm_head) + self._train_unembed_mask = jnp.zeros(config.max_lora_adapters, dtype=jnp.bool_) self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -237,9 +234,8 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" - loss_chunk_size = self.config.loss_chunk_size if self._use_chunked_loss else 0 + loss_chunk_size = self.config.loss_chunk_size gradient_checkpointing = self.config.gradient_checkpointing - has_train_unembed = self._has_train_unembed def _forward_and_logprobs( graphdef: nnx.GraphDef, @@ -249,6 +245,7 @@ def _forward_and_logprobs( attention_mask: jax.Array, adapter_indices: jax.Array, target_ids: jax.Array, + train_unembed_mask: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" model = nnx.merge(graphdef, lora_params, non_lora_params) @@ -257,11 +254,13 @@ def _forward_and_logprobs( attention_mask=attention_mask, adapter_indices=adapter_indices, ) - # Pass adapter_indices when train_unembed=True to apply LoRA on lm_head - lm_head_adapter_indices = adapter_indices if has_train_unembed else None - return model.compute_logprobs( - output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing - ) + # Check at runtime if any adapter in batch needs LoRA on lm_head + needs_lm_head_lora = train_unembed_mask[adapter_indices].any() + def logprobs(lm_head_adapter_indices): + return model.compute_logprobs( + output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing + ) + return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) if self.config.gradient_checkpointing: # Wrap the forward + logprobs call to use jax.checkpoint for gradient checkpointing @@ -281,7 +280,8 @@ def loss_for_lora( advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: target_logprobs = _forward_and_logprobs( - self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices, target_ids + self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices, target_ids, + self._train_unembed_mask, ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): @@ -453,11 +453,8 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") - # Enable LoRA on lm_head path when train_unembed=True - if lora_config.train_unembed and not self._has_train_unembed: - logger.info("Enabling LoRA on lm_head (train_unembed=True)") - self._has_train_unembed = True - self._create_loss_and_grad_fn() + # Set train_unembed mask for this adapter + self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(lora_config.train_unembed) # Store model metadata self.models[model_id] = types.ModelMetadata( @@ -482,9 +479,10 @@ def delete_model(self, model_id: str) -> None: # Get adapter index before deleting metadata adapter_index = self.models[model_id].adapter_index - # Clear LoRA adapter weights + # Clear LoRA adapter weights and reset train_unembed mask with jax.set_mesh(self.mesh): clear_lora_adapter(self.model, adapter_index) + self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(False) # Delete optimizer del self.optimizers[model_id] From f9cb17718f258d2234d30a7c33835246b6058e02 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:13:29 -0800 Subject: [PATCH 025/133] refactor: explicit CausalLMBase.__init__ for lm_head - Replace abstract property with __init__(lm_head) in base class - Subclasses explicitly call CausalLMBase.__init__(self, lm_head) - Fix test to support multiple adapters for mixed train_unembed test Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 8 ++++---- skyrl-tx/tx/models/base.py | 7 ++----- skyrl-tx/tx/models/llama3.py | 7 ++++--- skyrl-tx/tx/models/qwen3.py | 7 ++++--- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index baeabe438..edf91b0db 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -561,10 +561,10 @@ def test_adapter_reuse_initializes_lora_adapter(): class TestChunkedCrossEntropyLoss: """Tests for chunked cross-entropy loss computation.""" - def _create_backend(self, loss_chunk_size: int) -> JaxBackend: + def _create_backend(self, loss_chunk_size: int, max_lora_adapters: int = 2) -> JaxBackend: """Create a backend with specified chunk size.""" config = JaxBackendConfig( - max_lora_adapters=2, + max_lora_adapters=max_lora_adapters, max_lora_rank=32, loss_chunk_size=loss_chunk_size, ) @@ -660,8 +660,8 @@ def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): def test_mixed_train_unembed_adapters(self): """Test that chunked and non-chunked paths produce same results with mixed adapters.""" - backend_chunked = self._create_backend(loss_chunk_size=1024) - backend_nonchunked = self._create_backend(loss_chunk_size=0) + backend_chunked = self._create_backend(loss_chunk_size=1024, max_lora_adapters=3) + backend_nonchunked = self._create_backend(loss_chunk_size=0, max_lora_adapters=3) # Create same models on both backends for backend in [backend_chunked, backend_nonchunked]: diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py index 613efb6f6..41cb4ce16 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/models/base.py @@ -11,11 +11,8 @@ class CausalLMBase: """Base class providing logits/logprobs computation for causal language models.""" - @property - @abstractmethod - def lm_head(self) -> LoRALinear: - """Language model head. LoRALinear or transposed LoRAEmbed.""" - ... + def __init__(self, lm_head: LoRALinear): + self.lm_head = lm_head @property @abstractmethod diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 3ede0d727..2db911b92 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -268,10 +268,10 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> self.config = config self.model = Llama3Model(config, dtype=dtype, rngs=rngs) - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens.T + if config.tie_word_embeddings: + lm_head = self.model.embed_tokens.T else: - self.lm_head = LoRALinear( + lm_head = LoRALinear( config.hidden_size, config.vocab_size, use_bias=False, @@ -282,6 +282,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) + CausalLMBase.__init__(self, lm_head) @staticmethod def is_lora_param(path: tuple, _value) -> bool: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 0c1706857..f720163b9 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -383,10 +383,10 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> self.config = config self.model = Qwen3Model(config, dtype=dtype, rngs=rngs) - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens.T + if config.tie_word_embeddings: + lm_head = self.model.embed_tokens.T else: - self.lm_head = LoRALinear( + lm_head = LoRALinear( config.hidden_size, config.vocab_size, use_bias=False, @@ -397,6 +397,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) + CausalLMBase.__init__(self, lm_head) @staticmethod def is_lora_param(path: tuple, _value) -> bool: From 929d96b7f92825066b8963c1c9aee86ec0898f2f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:13:40 -0800 Subject: [PATCH 026/133] remove doc --- skyrl-tx/docs/design/logits_processor.md | 199 ----------------------- 1 file changed, 199 deletions(-) delete mode 100644 skyrl-tx/docs/design/logits_processor.md diff --git a/skyrl-tx/docs/design/logits_processor.md b/skyrl-tx/docs/design/logits_processor.md deleted file mode 100644 index e82a37a3d..000000000 --- a/skyrl-tx/docs/design/logits_processor.md +++ /dev/null @@ -1,199 +0,0 @@ -# LogitsProcessor Design - -## Overview - -This document proposes a design for `LogitsProcessor` - a utility for computing logits and log probabilities from model hidden states. - -## Background - -In causal language models, the forward pass produces hidden states `[B, T, H]` which must be projected to vocabulary logits `[B, T, V]` via the `lm_head` layer. Different scenarios have different requirements: - -### Training - -Compute logprobs for all positions to calculate loss. - -``` -hidden_states [B, T, H] → logprobs [B, T] → loss -``` - -Full logits `[B, T, V]` are not needed - we only need logprobs of target tokens. This enables **chunked computation**: process tokens in chunks, compute logits and extract logprobs per chunk, avoiding full `[B*T, V]` materialization. - -### Inference: Prefill - -Process the prompt. Return logits for the last position (to start decoding). Optionally return logprobs of prompt tokens. - -``` -hidden_states [B, T, H] → logits [B, 1, V] (last position, for sampling) - → logprobs [B, T-1] (optional, for prompt logprobs) -``` - -For prompt logprobs, same as training - full logits not needed, can use chunked computation. - -### Inference: Decode - -Generate one token at a time. - -1. **Compute logits:** `hidden_states [B, 1, H] → logits [B, 1, V]` -2. **Apply sampling transforms:** temperature scaling, top_k filtering, top_p filtering on logits -3. **Sample:** draw next_token from the transformed distribution -4. **Extract logprob:** get log probability of the sampled token from original logits - -**Full logits required** because step 2 operates on the full vocabulary distribution. - -## Existing Designs - -### SGLang - -**Pattern:** LogitsProcessor as a model attribute, called inside `model.forward()`. - -**Key files:** -- [LogitsProcessor class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/logits_processor.py#L235) -- [LlamaForCausalLM.forward()](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L499) calls [logits_processor](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L522) - -```python -class LlamaForCausalLM(nn.Module): - def __init__(self, ...): - self.logits_processor = LogitsProcessor(config) - - def forward(self, input_ids, positions, forward_batch, ...) -> LogitsProcessorOutput: - hidden_states = self.model(input_ids, ...) - return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch, ...) -``` - -**Problems:** - -1. **Wrapper pattern:** `forward()` just returns `logits_processor(...)` output. No encapsulation benefit. - -2. **Inconsistent return types:** `forward()` returns [different types](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py#L520-L532) based on runtime conditions (LogitsProcessorOutput, PoolerOutput, or Tensor). - -3. **God object:** [LogitsProcessor.forward()](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/logits_processor.py#L379) is 500+ lines handling many modes through complex branching. - -### vLLM - -**Pattern:** LogitsProcessor as a model attribute, called via separate `compute_logits()` method. - -**Key files:** -- [LogitsProcessor class](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/logits_processor.py#L18) -- [LlamaForCausalLM.compute_logits()](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L640) -- [model_runner calls compute_logits()](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L3336) - -```python -class LlamaForCausalLM(nn.Module): - def __init__(self, ...): - self.logits_processor = LogitsProcessor(vocab_size, scale=logit_scale) - - def forward(self, input_ids, positions, ...) -> Tensor: - return self.model(input_ids, positions, ...) # returns hidden_states - - def compute_logits(self, hidden_states) -> Tensor: - return self.logits_processor(self.lm_head, hidden_states) -``` - -**Improvements over SGLang:** -- `forward()` has single responsibility (returns hidden_states) -- Logits computation is explicit via separate method - -**Remaining Problems:** - -1. **Still a wrapper:** `compute_logits()` just wraps `self.logits_processor(...)`. - -2. **Unnecessary model attribute:** `logits_processor` stores minimal state. Could be a static utility. - -3. **No logprobs support:** Only computes logits. Logprobs computation happens elsewhere. - -## Proposed Design - -### Principles - -1. **Standalone utility** - Not a model attribute -2. **Model returns hidden_states** - Single responsibility, consistent return type -3. **Caller decides what to compute** - Logits for sampling, logprobs for training -4. **Unified logprobs API** - Same method for training and prompt logprobs - -### Architecture - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ Caller │ -│ (JaxBackend for training, Generator for sampling) │ -└─────────────────────────────────────────────────────────────────┘ - │ │ - │ model(input_ids, ...) │ LogitsProcessor.*() - ▼ ▼ -┌───────────────────────────┐ ┌───────────────────────────────┐ -│ CausalLM Model │ │ LogitsProcessor │ -│ │ │ │ -│ forward() → hidden_states│ │ compute_logits() │ -│ lm_head property │ │ compute_logprobs() │ -└───────────────────────────┘ │ logits_to_logprobs() │ - └───────────────────────────────┘ -``` - -### API - -```python -class LogitsProcessor: - """Utility for computing logits and logprobs from hidden states.""" - - @staticmethod - def compute_logits(hidden_states, lm_head, adapter_indices=None) -> jax.Array: - """Compute logits from hidden states. For sampling.""" - - @staticmethod - def compute_logprobs(hidden_states, lm_head, target_ids, adapter_indices=None, - chunk_size=0, gradient_checkpointing=False) -> jax.Array: - """Compute logprobs from hidden states. For training and prompt logprobs. - - Supports chunked computation to avoid materializing full [B*T, V] logits. - """ - - @staticmethod - def logits_to_logprobs(logits, target_ids) -> jax.Array: - """Convert logits to logprobs. For decode logprobs when logits already computed.""" -``` - -### Usage - -**Training:** -```python -output = model(input_ids, attention_mask=attention_mask, ...) -logprobs = LogitsProcessor.compute_logprobs( - output.last_hidden_state, model.lm_head, target_ids, - chunk_size=1024, gradient_checkpointing=True -) -loss = compute_loss(logprobs, ...) -``` - -**Sampling (prompt logprobs):** -```python -output = model(input_ids, attention_mask=attention_mask, ...) -prompt_logprobs = LogitsProcessor.compute_logprobs( - output.last_hidden_state, model.lm_head, input_ids[:, 1:], - chunk_size=1024 -) -``` - -**Sampling (decode):** -```python -output = model(next_token, kv_cache=kv_cache, ...) -logits = LogitsProcessor.compute_logits(output.last_hidden_state, model.lm_head) -next_token = sample(logits, temperature, top_k, top_p) -logprob = LogitsProcessor.logits_to_logprobs(logits, next_token) -``` - -### Benefits - -1. **Separation of concerns** - Model produces hidden states, LogitsProcessor transforms them -2. **Consistent model interface** - forward() always returns hidden_states -3. **Unified logprobs** - Same API for training and prompt logprobs -4. **Reduced code duplication** - Currently, logprobs computation is duplicated in `generator.py` (`compute_prompt_logprobs`) and `jax.py` backend (chunked loss). This design consolidates both into `LogitsProcessor.compute_logprobs()` -5. **Testable** - Easy to unit test with mock inputs - -### Migration Path - -1. Update `LogitsProcessor` to standalone utility with three methods -2. Update model to return hidden_states only (remove `skip_logits`, `skip_prompt_logits` flags) -3. Update generator to use `LogitsProcessor.compute_logits()` and `compute_logprobs()` -4. Update backend to use `LogitsProcessor.compute_logprobs()` -5. Remove `logits_processor` attribute from model classes -6. Simplify `CausalLMOutput` (remove `logits`, `lm_head` fields) From b1254c69801e827fda7eb0b73c0e9065365a232c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:14:20 -0800 Subject: [PATCH 027/133] rename test_logits_processor to test_compute_logits Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 247856665..df7ee22bd 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -22,8 +22,8 @@ ], ids=["llama3", "qwen3"], ) -def test_logits_processor(model_name, config_cls, model_cls, mesh_axes): - """Test that LogitsProcessor computes correct logits and logprobs.""" +def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): + """Test that model.compute_logits computes correct logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) From 1ad161201fbbe626c225d9276e9829d465d4146e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:17:10 -0800 Subject: [PATCH 028/133] fix: DummyModel calls CausalLMBase.__init__ Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/utils/test_generator.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index b7705813a..20a92176c 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -16,11 +16,8 @@ class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) - - @property - def lm_head(self): - """Identity lm_head - hidden_states are already logits.""" - return lambda hidden_states, adapter_indices=None: hidden_states + # Identity lm_head - hidden_states are already logits + CausalLMBase.__init__(self, lambda hidden_states, adapter_indices=None: hidden_states) @property def lm_head_weight(self) -> jax.Array: From 9e396a3f899f46c92cdee4ba71b45ff6e9e15b9c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:26:15 -0800 Subject: [PATCH 029/133] refactor: remove ModelForCausalLM Protocol, use CausalLMBase Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/layers/lora.py | 10 +++++++--- skyrl-tx/tx/models/types.py | 6 ------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 4ee0741d0..dd156f8ea 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -1,12 +1,16 @@ +from typing import TYPE_CHECKING + from flax import nnx import jax from jax import numpy as jnp from tx.utils.models import filter_lora from tx.layers.util import Param, prepare_routing, ragged_dot -from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig +if TYPE_CHECKING: + from tx.models.base import CausalLMBase + class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. @@ -286,7 +290,7 @@ def __call__( return base_out + lora_output -def init_lora_adapter(model: ModelForCausalLM, adapter_index: int, lora_config: LoraConfig): +def init_lora_adapter(model: "CausalLMBase", adapter_index: int, lora_config: LoraConfig): """Initialize a LoRA adapter for training. Initializes the adapter: lora_A with he_uniform, lora_B with zeros, @@ -335,7 +339,7 @@ def init_adapter(path, value): nnx.update(model, updated_state) -def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): +def clear_lora_adapter(model: "CausalLMBase", adapter_index: int): """Clear/reset a LoRA adapter, freeing it for reuse. Sets rank=0, scaling=0, and zeros out lora_A and lora_B for the adapter. diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index be60f6ec9..d038d1a47 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -2,18 +2,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Protocol import jax -from transformers import PretrainedConfig from tx.utils.generator import KVCache -class ModelForCausalLM(Protocol): - config: PretrainedConfig - - @jax.tree_util.register_dataclass @dataclass class ModelOutput: From 345114436bff1ac1274881bf8d332fe1e0f9eccb Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:28:18 -0800 Subject: [PATCH 030/133] refactor: move config to CausalLMBase.__init__ Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/utils/test_generator.py | 2 +- skyrl-tx/tx/models/base.py | 4 +++- skyrl-tx/tx/models/llama3.py | 3 +-- skyrl-tx/tx/models/qwen3.py | 3 +-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 20a92176c..3383b8e89 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -17,7 +17,7 @@ def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) # Identity lm_head - hidden_states are already logits - CausalLMBase.__init__(self, lambda hidden_states, adapter_indices=None: hidden_states) + CausalLMBase.__init__(self, None, lambda hidden_states, adapter_indices=None: hidden_states) @property def lm_head_weight(self) -> jax.Array: diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py index 41cb4ce16..e48a64034 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/models/base.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +from transformers import PretrainedConfig from tx.layers.lora import LoRALinear @@ -11,7 +12,8 @@ class CausalLMBase: """Base class providing logits/logprobs computation for causal language models.""" - def __init__(self, lm_head: LoRALinear): + def __init__(self, config: PretrainedConfig, lm_head: LoRALinear): + self.config = config self.lm_head = lm_head @property diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 2db911b92..8aa6363ad 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -265,7 +265,6 @@ def __call__( class Llama3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - self.config = config self.model = Llama3Model(config, dtype=dtype, rngs=rngs) if config.tie_word_embeddings: @@ -282,7 +281,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - CausalLMBase.__init__(self, lm_head) + CausalLMBase.__init__(self, config, lm_head) @staticmethod def is_lora_param(path: tuple, _value) -> bool: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index f720163b9..a1fd95910 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -380,7 +380,6 @@ def __call__( class Qwen3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - self.config = config self.model = Qwen3Model(config, dtype=dtype, rngs=rngs) if config.tie_word_embeddings: @@ -397,7 +396,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - CausalLMBase.__init__(self, lm_head) + CausalLMBase.__init__(self, config, lm_head) @staticmethod def is_lora_param(path: tuple, _value) -> bool: From 4a63a2bad4c8bec05415dbeaccf4e1c4ce28553a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 21 Jan 2026 20:31:40 -0800 Subject: [PATCH 031/133] fix: lm_head type is Callable, not LoRALinear When tie_word_embeddings=True, lm_head is a lambda from LoRAEmbed.T Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py index e48a64034..9d82ece2d 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/models/base.py @@ -1,18 +1,21 @@ """Base class for causal language models.""" from abc import abstractmethod +from typing import Callable import jax import jax.numpy as jnp from transformers import PretrainedConfig -from tx.layers.lora import LoRALinear + +# lm_head: (hidden_states, adapter_indices) -> logits +LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] class CausalLMBase: """Base class providing logits/logprobs computation for causal language models.""" - def __init__(self, config: PretrainedConfig, lm_head: LoRALinear): + def __init__(self, config: PretrainedConfig, lm_head: LMHead): self.config = config self.lm_head = lm_head From 2789a48d86a1d90febca6de9ad8815e584814086 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 11:48:22 -0800 Subject: [PATCH 032/133] Revert: remove chunked logprobs (to be submitted in separate PR) This reverts the chunked logprobs feature while keeping the CausalLMBase refactoring. Changes removed: - _compute_chunked_logprobs method - lm_head_weight property - loss_chunk_size config - _train_unembed_mask runtime check - TestChunkedCrossEntropyLoss tests Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 160 ---------------------- skyrl-tx/tests/utils/test_generator.py | 7 - skyrl-tx/tx/models/base.py | 74 +--------- skyrl-tx/tx/models/llama3.py | 8 -- skyrl-tx/tx/models/qwen3.py | 8 -- skyrl-tx/tx/tinker/backends/jax.py | 24 +--- 6 files changed, 4 insertions(+), 277 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index edf91b0db..2edd9d82b 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -556,163 +556,3 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_B is zeros assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" - - -class TestChunkedCrossEntropyLoss: - """Tests for chunked cross-entropy loss computation.""" - - def _create_backend(self, loss_chunk_size: int, max_lora_adapters: int = 2) -> JaxBackend: - """Create a backend with specified chunk size.""" - config = JaxBackendConfig( - max_lora_adapters=max_lora_adapters, - max_lora_rank=32, - loss_chunk_size=loss_chunk_size, - ) - return JaxBackend(BASE_MODEL, config) - - def _create_inputs(self, backend: JaxBackend, batch_size: int, seq_len: int, adapter_idx: int = 0): - """Create test inputs for forward pass.""" - vocab = backend.model.config.vocab_size - input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - adapter_indices = jnp.full((batch_size,), adapter_idx, dtype=jnp.int32) - target_ids = (input_ids + 1) % vocab - loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) - loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) - sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - return ( - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - - def _run_forward(self, backend: JaxBackend, inputs: tuple): - """Run forward pass and return losses and logprobs.""" - ( - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) = inputs - _, losses, logprobs = backend._forward( - backend.accumulated_grads, - backend.lora_params, - backend.non_lora_params, - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - return losses, logprobs - - @pytest.mark.parametrize( - "batch_size,seq_len,chunk_size", - [ - (2, 16, 8), # Multiple batches - (1, 16, 16), # Exact multiple (1 chunk) - (1, 17, 16), # One extra token (worst case padding) - (1, 8, 16), # Fewer tokens than chunk size - (1, 32, 16), # Exact 2 chunks - (1, 1, 16), # Single token - (1, 31, 16), # Almost 2 chunks - ], - ) - def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): - """Verify chunked and non-chunked loss produce identical logprobs.""" - backend_chunked = self._create_backend(loss_chunk_size=chunk_size) - backend_nonchunked = self._create_backend(loss_chunk_size=0) - - assert backend_chunked.config.loss_chunk_size > 0 - assert backend_nonchunked.config.loss_chunk_size == 0 - - inputs = self._create_inputs(backend_chunked, batch_size, seq_len) - losses_chunked, logprobs_chunked = self._run_forward(backend_chunked, inputs) - losses_nonchunked, logprobs_nonchunked = self._run_forward(backend_nonchunked, inputs) - - np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg=f"Logprobs mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", - ) - np.testing.assert_allclose( - np.asarray(losses_chunked), - np.asarray(losses_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg=f"Losses mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", - ) - - def test_mixed_train_unembed_adapters(self): - """Test that chunked and non-chunked paths produce same results with mixed adapters.""" - backend_chunked = self._create_backend(loss_chunk_size=1024, max_lora_adapters=3) - backend_nonchunked = self._create_backend(loss_chunk_size=0, max_lora_adapters=3) - - # Create same models on both backends - for backend in [backend_chunked, backend_nonchunked]: - backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) - backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) - - normal_idx = backend_chunked.models["model_normal"].adapter_index - unembed_idx = backend_chunked.models["model_unembed"].adapter_index - - batch_size, seq_len = 2, 16 - vocab = backend_chunked.model.config.vocab_size - input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - target_ids = (input_ids + 1) % vocab - loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) - loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) - sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - - def run_forward(backend, adapter_indices): - _, losses, logprobs = backend._forward( - backend.accumulated_grads, - backend.lora_params, - backend.non_lora_params, - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - return losses, logprobs - - # Test with mixed adapters: one normal, one unembed - adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) - losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) - losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) - - np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", - ) - np.testing.assert_allclose( - np.asarray(losses_chunked), - np.asarray(losses_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", - ) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 3383b8e89..4862fc457 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,5 +1,4 @@ from flax import nnx -import jax import jax.numpy as jnp from tx.models.base import CausalLMBase from tx.models.types import CausalLMOutput @@ -15,15 +14,9 @@ class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size - self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) # Identity lm_head - hidden_states are already logits CausalLMBase.__init__(self, None, lambda hidden_states, adapter_indices=None: hidden_states) - @property - def lm_head_weight(self) -> jax.Array: - """Identity matrix for dummy model.""" - return self._lm_head_weight - def __call__( self, input_ids, diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/models/base.py index 9d82ece2d..f9d59aa31 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/models/base.py @@ -1,6 +1,5 @@ """Base class for causal language models.""" -from abc import abstractmethod from typing import Callable import jax @@ -19,12 +18,6 @@ def __init__(self, config: PretrainedConfig, lm_head: LMHead): self.config = config self.lm_head = lm_head - @property - @abstractmethod - def lm_head_weight(self) -> jax.Array: - """LM head weight matrix [H, V] for efficient chunked computation.""" - ... - def compute_logits( self, hidden_states: jax.Array, @@ -46,8 +39,6 @@ def compute_logprobs( hidden_states: jax.Array, target_ids: jax.Array, adapter_indices: jax.Array | None = None, - chunk_size: int = 0, - gradient_checkpointing: bool = False, ) -> jax.Array: """Compute logprobs from hidden states. For training and prompt logprobs. @@ -55,22 +46,12 @@ def compute_logprobs( hidden_states: Hidden states [B, T, H]. target_ids: Target token IDs [B, T]. adapter_indices: Adapter indices for LoRA on lm_head. - Pass when train_unembed=True. Forces non-chunked path. - chunk_size: Chunk size for chunked computation (0 = non-chunked). - gradient_checkpointing: Whether to checkpoint each chunk. Returns: Log probabilities for target tokens [B, T]. """ - # Chunked path doesn't support LoRA on lm_head - use_chunk = chunk_size > 0 and adapter_indices is None - if use_chunk: - return self._compute_chunked_logprobs( - hidden_states, target_ids, chunk_size, gradient_checkpointing - ) - else: - logits = self.compute_logits(hidden_states, adapter_indices) - return self.logits_to_logprobs(logits, target_ids) + logits = self.compute_logits(hidden_states, adapter_indices) + return self.logits_to_logprobs(logits, target_ids) @staticmethod def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: @@ -86,54 +67,3 @@ def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) - - def _compute_chunked_logprobs( - self, - hidden_states: jax.Array, - target_ids: jax.Array, - chunk_size: int, - gradient_checkpointing: bool, - ) -> jax.Array: - """Compute log probabilities using chunked lm_head computation. - - This avoids materializing the full [B*T, V] logits tensor by computing - lm_head and log probabilities for each chunk sequentially. - """ - B, T, H = hidden_states.shape - total_tokens = B * T - lm_head_weight = self.lm_head_weight - - # Flatten batch and sequence dimensions - flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] - flat_target_ids = target_ids.reshape(-1) # [B*T] - - # Pad to multiple of chunk_size for clean slicing - num_chunks = (total_tokens + chunk_size - 1) // chunk_size - padded_size = num_chunks * chunk_size - pad_amount = padded_size - total_tokens - - if pad_amount > 0: - flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) - flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) - - # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] - chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) - chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) - - def compute_chunk_logprobs(args): - """Compute lm_head and log probabilities for a chunk of tokens.""" - chunk_hidden, chunk_targets = args - # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] - chunk_logits = chunk_hidden @ lm_head_weight - # Compute log probabilities - log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) - return (target_logits - log_sum_exp).squeeze(-1) - - if gradient_checkpointing: - compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) - - # Process chunks sequentially using lax.map (not vmap) to reduce memory - all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) - # Flatten and slice to original size, then reshape to [B, T] - return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 8aa6363ad..238c81450 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -288,14 +288,6 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) - @property - def lm_head_weight(self) -> jax.Array: - """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" - if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding[...].T - else: - return self.lm_head.kernel[...] - def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index a1fd95910..72f8a7b33 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -403,14 +403,6 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) - @property - def lm_head_weight(self) -> jax.Array: - """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" - if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding[...].T - else: - return self.lm_head.kernel[...] - def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 4bc606451..686c8e4d7 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -83,10 +83,6 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=False, description="Whether to use gradient checkpointing (full recomputation strategy)", ) - loss_chunk_size: int = Field( - default=1024, - description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.", - ) # Multi-node configuration coordinator_address: str | None = Field( default=None, @@ -204,8 +200,6 @@ def __init__(self, base_model: str, config: JaxBackendConfig): f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" ) - # Track which adapters use train_unembed=True (requires LoRA on lm_head) - self._train_unembed_mask = jnp.zeros(config.max_lora_adapters, dtype=jnp.bool_) self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -234,8 +228,6 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" - loss_chunk_size = self.config.loss_chunk_size - gradient_checkpointing = self.config.gradient_checkpointing def _forward_and_logprobs( graphdef: nnx.GraphDef, @@ -245,7 +237,6 @@ def _forward_and_logprobs( attention_mask: jax.Array, adapter_indices: jax.Array, target_ids: jax.Array, - train_unembed_mask: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" model = nnx.merge(graphdef, lora_params, non_lora_params) @@ -254,13 +245,7 @@ def _forward_and_logprobs( attention_mask=attention_mask, adapter_indices=adapter_indices, ) - # Check at runtime if any adapter in batch needs LoRA on lm_head - needs_lm_head_lora = train_unembed_mask[adapter_indices].any() - def logprobs(lm_head_adapter_indices): - return model.compute_logprobs( - output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing - ) - return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) + return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) if self.config.gradient_checkpointing: # Wrap the forward + logprobs call to use jax.checkpoint for gradient checkpointing @@ -281,7 +266,6 @@ def loss_for_lora( ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: target_logprobs = _forward_and_logprobs( self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices, target_ids, - self._train_unembed_mask, ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): @@ -453,9 +437,6 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") - # Set train_unembed mask for this adapter - self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(lora_config.train_unembed) - # Store model metadata self.models[model_id] = types.ModelMetadata( adapter_index=adapter_index, @@ -479,10 +460,9 @@ def delete_model(self, model_id: str) -> None: # Get adapter index before deleting metadata adapter_index = self.models[model_id].adapter_index - # Clear LoRA adapter weights and reset train_unembed mask + # Clear LoRA adapter weights with jax.set_mesh(self.mesh): clear_lora_adapter(self.model, adapter_index) - self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(False) # Delete optimizer del self.optimizers[model_id] From e14911294fcc1240326293e0b8f27b7761b52770 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 11:55:53 -0800 Subject: [PATCH 033/133] refactor: split test_models_common into focused tests - test_compute_logits: compare with HuggingFace logits - test_compute_logprobs: verify equivalence with manual computation - Remove generation tests (belong in generator tests) Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 44 ++++----------------- 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index df7ee22bd..9e932f439 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -2,7 +2,6 @@ from flax import nnx import jax -import jax.numpy as jnp import numpy as np import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -10,7 +9,6 @@ from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM -from tx.tinker.types import SamplingParams from tx.utils.models import get_dtype, load_safetensors @@ -23,13 +21,12 @@ ids=["llama3", "qwen3"], ) def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): - """Test that model.compute_logits computes correct logits.""" + """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) - batch_size, seq_len = batch.input_ids.shape with tempfile.TemporaryDirectory() as tmp: hf_model.save_pretrained(tmp, safe_serialization=True) @@ -41,37 +38,12 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): model = model_cls(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) - # Get hidden states from model - outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - - # Compute full logits using model.compute_logits - full_logits = model.compute_logits(outputs.last_hidden_state) - assert full_logits.shape == (batch_size, seq_len, config.vocab_size) - - # Compute last token logits only - last_logits = model.compute_logits(outputs.last_hidden_state[:, -1:, :]) - assert last_logits.shape == (batch_size, 1, config.vocab_size) + # Get HF logits + hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) + hf_logits = hf_outputs.logits.detach().numpy() - # Last token logits should match - assert np.allclose(full_logits[:, -1:, :], last_logits, rtol=1e-5, atol=1e-5) - - # Test generation equivalence with and without prompt_logprobs - input_ids = jnp.array(batch.input_ids.numpy()) - attention_mask = jnp.array(batch.attention_mask.numpy()) - sampling_params = [SamplingParams(max_tokens=8, temperature=0.0, seed=42)] * batch_size - - result_with = model.generate(input_ids, attention_mask, sampling_params=sampling_params, prompt_logprobs=True) - result_without = model.generate( - input_ids, attention_mask, sampling_params=sampling_params, prompt_logprobs=False - ) + # Get our logits via compute_logits + outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) + our_logits = model.compute_logits(outputs.last_hidden_state) - for i in range(batch_size): - assert ( - result_with.generated_ids[i] == result_without.generated_ids[i] - ), f"Generated tokens should match for seq {i}" - assert ( - result_with.stop_reasons[i] == result_without.stop_reasons[i] - ), f"Stop reasons should match for seq {i}" - assert np.allclose( - result_with.logprobs[i], result_without.logprobs[i] - ), f"Logprobs should match for seq {i}" + np.testing.assert_allclose(our_logits, hf_logits, rtol=1e-4, atol=1e-4) From 9575da35203b5fd341b5ab57e0f0ea0897739c53 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 12:23:00 -0800 Subject: [PATCH 034/133] lint --- skyrl-tx/tx/tinker/backends/jax.py | 8 +++++++- skyrl-tx/tx/utils/generator.py | 4 +--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 686c8e4d7..dbb871a0d 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -265,7 +265,13 @@ def loss_for_lora( advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: target_logprobs = _forward_and_logprobs( - self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices, target_ids, + self.graphdef, + lora_params, + non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index cb83a5cbb..a6229160d 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -153,9 +153,7 @@ def _prefill_and_decode( # Compute prompt logprobs if requested if prompt_logprobs: - prompt_logprobs_array = model.compute_logprobs( - outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:] - ) + prompt_logprobs_array = model.compute_logprobs(outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:]) else: prompt_logprobs_array = None From 36a6961ade24838b21d395f0f8b90a09c18f8bf2 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 12:31:43 -0800 Subject: [PATCH 035/133] address comments --- skyrl-tx/tests/utils/test_generator.py | 4 +++- skyrl-tx/tx/layers/lora.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 4862fc457..270b4db3c 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + from flax import nnx import jax.numpy as jnp from tx.models.base import CausalLMBase @@ -15,7 +17,7 @@ class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size # Identity lm_head - hidden_states are already logits - CausalLMBase.__init__(self, None, lambda hidden_states, adapter_indices=None: hidden_states) + CausalLMBase.__init__(self, MagicMock(), lambda hidden_states, adapter_indices=None: hidden_states) def __call__( self, diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index dd156f8ea..aac62dcb1 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -302,6 +302,11 @@ def init_lora_adapter(model: "CausalLMBase", adapter_index: int, lora_config: Lo adapter_index: Index of the adapter to initialize lora_config: LoraConfig object containing rank, alpha, seed, and training flags """ + if lora_config.train_unembed and getattr(model.config, "tie_word_embeddings", False): + raise ValueError( + "train_unembed=True is incompatible with tie_word_embeddings=True. " + "Tied embeddings use embed_tokens.T which does not support LoRA." + ) rngs = nnx.Rngs(lora_config.seed) state = nnx.state(model) From f6ed3fb36efa6f4bdf6a6cd24b1d9ffb21448e4f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 12:55:52 -0800 Subject: [PATCH 036/133] fix: pass adapter_indices to compute_logprobs for prompt logprobs The prompt_logprobs computation was not passing adapter_indices to compute_logprobs, which would cause incorrect results when using LoRA adapters. Added test coverage for this case. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/utils/test_generator.py | 25 +++++++++++++++++++++++-- skyrl-tx/tx/utils/generator.py | 4 +++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 270b4db3c..99f50cb42 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -12,12 +12,20 @@ class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): """Dummy model for testing generator behavior. In this dummy model, hidden_states directly equal logits (identity transformation). + When adapter_indices is provided, it adds the adapter index to logits. """ def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size - # Identity lm_head - hidden_states are already logits - CausalLMBase.__init__(self, MagicMock(), lambda hidden_states, adapter_indices=None: hidden_states) + + def lm_head(hidden_states, adapter_indices=None): + # Scale logits by (1 + adapter_index) so different adapters give different log-softmax results + if adapter_indices is not None: + scale = (1 + adapter_indices[:, None, None]).astype(jnp.float32) + return hidden_states * scale + return hidden_states + + CausalLMBase.__init__(self, MagicMock(), lm_head) def __call__( self, @@ -141,6 +149,19 @@ def test_prompt_logprobs(): len(result_batch.prompt_logprobs[i]) == expected_length ), f"Sequence {i}: expected prompt_logprobs length {expected_length}" + # Test that adapter_indices affects prompt_logprobs (verifies adapter_indices is passed to compute_logprobs) + adapter_0 = jnp.array([0], dtype=jnp.int32) + adapter_1 = jnp.array([1], dtype=jnp.int32) + result_adapter_0 = model.generate( + input_ids, attention_mask, sampling_params=[sampling], adapter_indices=adapter_0, prompt_logprobs=True + ) + result_adapter_1 = model.generate( + input_ids, attention_mask, sampling_params=[sampling], adapter_indices=adapter_1, prompt_logprobs=True + ) + assert not jnp.allclose( + jnp.array(result_adapter_0.prompt_logprobs[0]), jnp.array(result_adapter_1.prompt_logprobs[0]) + ), "prompt_logprobs should differ when adapter_indices differ" + def test_top_k_filtering(): """Test apply_top_k_batch function directly.""" diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index a6229160d..520fefbc5 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -153,7 +153,9 @@ def _prefill_and_decode( # Compute prompt logprobs if requested if prompt_logprobs: - prompt_logprobs_array = model.compute_logprobs(outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:]) + prompt_logprobs_array = model.compute_logprobs( + outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:], adapter_indices + ) else: prompt_logprobs_array = None From d635429b8cd2d9724b58a852b199066b366ab245 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 14:20:47 -0800 Subject: [PATCH 037/133] use mixin --- skyrl-tx/tests/utils/test_generator.py | 14 ++++++++------ skyrl-tx/tx/layers/lora.py | 10 +++------- skyrl-tx/tx/models/llama3.py | 14 +++++++++----- skyrl-tx/tx/models/qwen3.py | 14 +++++++++----- skyrl-tx/tx/models/types.py | 6 ++++++ .../base.py => utils/logits_processor.py} | 17 +++++++++-------- 6 files changed, 44 insertions(+), 31 deletions(-) rename skyrl-tx/tx/{models/base.py => utils/logits_processor.py} (82%) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 99f50cb42..7b1752eaa 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,18 +1,16 @@ -from unittest.mock import MagicMock - from flax import nnx import jax.numpy as jnp -from tx.models.base import CausalLMBase from tx.models.types import CausalLMOutput from tx.tinker.types import SamplingParams from tx.utils.generator import GenerateOutput, GeneratorMixin, KVCache, apply_top_k_batch, apply_top_p_batch +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead -class DummyModel(GeneratorMixin, CausalLMBase, nnx.Module): +class DummyModel(GeneratorMixin, LogitsProcessorMixin, nnx.Module): """Dummy model for testing generator behavior. In this dummy model, hidden_states directly equal logits (identity transformation). - When adapter_indices is provided, it adds the adapter index to logits. + When adapter_indices is provided, it scales logits by (1 + adapter_index). """ def __init__(self, vocab_size: int = 16): @@ -25,7 +23,11 @@ def lm_head(hidden_states, adapter_indices=None): return hidden_states * scale return hidden_states - CausalLMBase.__init__(self, MagicMock(), lm_head) + self.lm_head = lm_head + + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + return self.lm_head def __call__( self, diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index aac62dcb1..648c470a5 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -1,16 +1,12 @@ -from typing import TYPE_CHECKING - from flax import nnx import jax from jax import numpy as jnp from tx.utils.models import filter_lora from tx.layers.util import Param, prepare_routing, ragged_dot +from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig -if TYPE_CHECKING: - from tx.models.base import CausalLMBase - class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. @@ -290,7 +286,7 @@ def __call__( return base_out + lora_output -def init_lora_adapter(model: "CausalLMBase", adapter_index: int, lora_config: LoraConfig): +def init_lora_adapter(model: ModelForCausalLM, adapter_index: int, lora_config: LoraConfig): """Initialize a LoRA adapter for training. Initializes the adapter: lora_A with he_uniform, lora_B with zeros, @@ -344,7 +340,7 @@ def init_adapter(path, value): nnx.update(model, updated_state) -def clear_lora_adapter(model: "CausalLMBase", adapter_index: int): +def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): """Clear/reset a LoRA adapter, freeing it for reuse. Sets rank=0, scaling=0, and zeros out lora_A and lora_B for the adapter. diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 238c81450..b7eb14d52 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -7,7 +7,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm -from tx.models.base import CausalLMBase +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache, compute_positions @@ -262,15 +262,16 @@ def __call__( ) -class Llama3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): +class Llama3ForCausalLM(nnx.Module, GeneratorMixin, LogitsProcessorMixin): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config self.model = Llama3Model(config, dtype=dtype, rngs=rngs) if config.tie_word_embeddings: - lm_head = self.model.embed_tokens.T + self.lm_head = self.model.embed_tokens.T else: - lm_head = LoRALinear( + self.lm_head = LoRALinear( config.hidden_size, config.vocab_size, use_bias=False, @@ -281,7 +282,10 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - CausalLMBase.__init__(self, config, lm_head) + + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + return self.lm_head @staticmethod def is_lora_param(path: tuple, _value) -> bool: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 72f8a7b33..fdf68ee48 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -6,7 +6,7 @@ from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope -from tx.models.base import CausalLMBase +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.configs import Qwen3Config from tx.layers.layernorm import RMSNorm from tx.models.types import CausalLMOutput, ModelOutput @@ -377,15 +377,16 @@ def __call__( ) -class Qwen3ForCausalLM(nnx.Module, GeneratorMixin, CausalLMBase): +class Qwen3ForCausalLM(nnx.Module, GeneratorMixin, LogitsProcessorMixin): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config self.model = Qwen3Model(config, dtype=dtype, rngs=rngs) if config.tie_word_embeddings: - lm_head = self.model.embed_tokens.T + self.lm_head = self.model.embed_tokens.T else: - lm_head = LoRALinear( + self.lm_head = LoRALinear( config.hidden_size, config.vocab_size, use_bias=False, @@ -396,7 +397,10 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> max_lora_rank=config.max_lora_rank, rngs=rngs, ) - CausalLMBase.__init__(self, config, lm_head) + + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + return self.lm_head @staticmethod def is_lora_param(path: tuple, _value) -> bool: diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index d038d1a47..be60f6ec9 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -2,12 +2,18 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Protocol import jax +from transformers import PretrainedConfig from tx.utils.generator import KVCache +class ModelForCausalLM(Protocol): + config: PretrainedConfig + + @jax.tree_util.register_dataclass @dataclass class ModelOutput: diff --git a/skyrl-tx/tx/models/base.py b/skyrl-tx/tx/utils/logits_processor.py similarity index 82% rename from skyrl-tx/tx/models/base.py rename to skyrl-tx/tx/utils/logits_processor.py index f9d59aa31..68ee87434 100644 --- a/skyrl-tx/tx/models/base.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -1,22 +1,23 @@ -"""Base class for causal language models.""" +"""Mixin for logits computation in causal language models.""" +from abc import abstractmethod from typing import Callable import jax import jax.numpy as jnp -from transformers import PretrainedConfig # lm_head: (hidden_states, adapter_indices) -> logits LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] -class CausalLMBase: - """Base class providing logits/logprobs computation for causal language models.""" +class LogitsProcessorMixin: + """Mixin providing logits/logprobs computation for causal language models.""" - def __init__(self, config: PretrainedConfig, lm_head: LMHead): - self.config = config - self.lm_head = lm_head + @abstractmethod + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + ... def compute_logits( self, @@ -32,7 +33,7 @@ def compute_logits( Returns: Logits [B, T, V]. """ - return self.lm_head(hidden_states, adapter_indices) + return self.get_lm_head()(hidden_states, adapter_indices) def compute_logprobs( self, From a81c27fc34c7e9e902b5992a1a1e8ced08be3047 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 12:56:59 -0800 Subject: [PATCH 038/133] feat: add chunked cross-entropy loss computation Adds memory-efficient chunked logprobs computation to avoid materializing full [B*T, V] logits tensor during training: - CausalLMBase._compute_chunked_logprobs: processes tokens in chunks - loss_chunk_size config in JaxBackend (default 1024) - Runtime check for train_unembed to use non-chunked path when needed - lm_head_weight abstract property for direct weight access Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 160 ++++++++++++++++++++++ skyrl-tx/tests/utils/test_generator.py | 7 + skyrl-tx/tx/models/llama3.py | 8 ++ skyrl-tx/tx/models/qwen3.py | 8 ++ skyrl-tx/tx/tinker/backends/jax.py | 24 +++- skyrl-tx/tx/utils/logits_processor.py | 73 +++++++++- 6 files changed, 276 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 2edd9d82b..edf91b0db 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -556,3 +556,163 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_B is zeros assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" + + +class TestChunkedCrossEntropyLoss: + """Tests for chunked cross-entropy loss computation.""" + + def _create_backend(self, loss_chunk_size: int, max_lora_adapters: int = 2) -> JaxBackend: + """Create a backend with specified chunk size.""" + config = JaxBackendConfig( + max_lora_adapters=max_lora_adapters, + max_lora_rank=32, + loss_chunk_size=loss_chunk_size, + ) + return JaxBackend(BASE_MODEL, config) + + def _create_inputs(self, backend: JaxBackend, batch_size: int, seq_len: int, adapter_idx: int = 0): + """Create test inputs for forward pass.""" + vocab = backend.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + adapter_indices = jnp.full((batch_size,), adapter_idx, dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + return ( + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + + def _run_forward(self, backend: JaxBackend, inputs: tuple): + """Run forward pass and return losses and logprobs.""" + ( + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) = inputs + _, losses, logprobs = backend._forward( + backend.accumulated_grads, + backend.lora_params, + backend.non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + return losses, logprobs + + @pytest.mark.parametrize( + "batch_size,seq_len,chunk_size", + [ + (2, 16, 8), # Multiple batches + (1, 16, 16), # Exact multiple (1 chunk) + (1, 17, 16), # One extra token (worst case padding) + (1, 8, 16), # Fewer tokens than chunk size + (1, 32, 16), # Exact 2 chunks + (1, 1, 16), # Single token + (1, 31, 16), # Almost 2 chunks + ], + ) + def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): + """Verify chunked and non-chunked loss produce identical logprobs.""" + backend_chunked = self._create_backend(loss_chunk_size=chunk_size) + backend_nonchunked = self._create_backend(loss_chunk_size=0) + + assert backend_chunked.config.loss_chunk_size > 0 + assert backend_nonchunked.config.loss_chunk_size == 0 + + inputs = self._create_inputs(backend_chunked, batch_size, seq_len) + losses_chunked, logprobs_chunked = self._run_forward(backend_chunked, inputs) + losses_nonchunked, logprobs_nonchunked = self._run_forward(backend_nonchunked, inputs) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Logprobs mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", + ) + np.testing.assert_allclose( + np.asarray(losses_chunked), + np.asarray(losses_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Losses mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", + ) + + def test_mixed_train_unembed_adapters(self): + """Test that chunked and non-chunked paths produce same results with mixed adapters.""" + backend_chunked = self._create_backend(loss_chunk_size=1024, max_lora_adapters=3) + backend_nonchunked = self._create_backend(loss_chunk_size=0, max_lora_adapters=3) + + # Create same models on both backends + for backend in [backend_chunked, backend_nonchunked]: + backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) + backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) + + normal_idx = backend_chunked.models["model_normal"].adapter_index + unembed_idx = backend_chunked.models["model_unembed"].adapter_index + + batch_size, seq_len = 2, 16 + vocab = backend_chunked.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + + def run_forward(backend, adapter_indices): + _, losses, logprobs = backend._forward( + backend.accumulated_grads, + backend.lora_params, + backend.non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + return losses, logprobs + + # Test with mixed adapters: one normal, one unembed + adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) + losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) + losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", + ) + np.testing.assert_allclose( + np.asarray(losses_chunked), + np.asarray(losses_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", + ) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 7b1752eaa..dc25459b8 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,4 +1,5 @@ from flax import nnx +import jax import jax.numpy as jnp from tx.models.types import CausalLMOutput from tx.tinker.types import SamplingParams @@ -15,6 +16,7 @@ class DummyModel(GeneratorMixin, LogitsProcessorMixin, nnx.Module): def __init__(self, vocab_size: int = 16): self.vocab_size = vocab_size + self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) def lm_head(hidden_states, adapter_indices=None): # Scale logits by (1 + adapter_index) so different adapters give different log-softmax results @@ -29,6 +31,11 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head + @property + def lm_head_weight(self) -> jax.Array: + """Identity matrix for dummy model.""" + return self._lm_head_weight + def __call__( self, input_ids, diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index b7eb14d52..125390038 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -292,6 +292,14 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) + @property + def lm_head_weight(self) -> jax.Array: + """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + if self.config.tie_word_embeddings: + return self.model.embed_tokens.embedding[...].T + else: + return self.lm_head.kernel[...] + def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index fdf68ee48..ec4226052 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -407,6 +407,14 @@ def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" return any(name in path for name in ("lora_A", "lora_B")) + @property + def lm_head_weight(self) -> jax.Array: + """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + if self.config.tie_word_embeddings: + return self.model.embed_tokens.embedding[...].T + else: + return self.lm_head.kernel[...] + def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index dbb871a0d..f44f7737b 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -83,6 +83,10 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=False, description="Whether to use gradient checkpointing (full recomputation strategy)", ) + loss_chunk_size: int = Field( + default=1024, + description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.", + ) # Multi-node configuration coordinator_address: str | None = Field( default=None, @@ -200,6 +204,8 @@ def __init__(self, base_model: str, config: JaxBackendConfig): f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" ) + # Track which adapters use train_unembed=True (requires LoRA on lm_head) + self._train_unembed_mask = jnp.zeros(config.max_lora_adapters, dtype=jnp.bool_) self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -228,6 +234,8 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" + loss_chunk_size = self.config.loss_chunk_size + gradient_checkpointing = self.config.gradient_checkpointing def _forward_and_logprobs( graphdef: nnx.GraphDef, @@ -237,6 +245,7 @@ def _forward_and_logprobs( attention_mask: jax.Array, adapter_indices: jax.Array, target_ids: jax.Array, + train_unembed_mask: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" model = nnx.merge(graphdef, lora_params, non_lora_params) @@ -245,7 +254,13 @@ def _forward_and_logprobs( attention_mask=attention_mask, adapter_indices=adapter_indices, ) - return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) + # Check at runtime if any adapter in batch needs LoRA on lm_head + needs_lm_head_lora = train_unembed_mask[adapter_indices].any() + def logprobs(lm_head_adapter_indices): + return model.compute_logprobs( + output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing + ) + return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) if self.config.gradient_checkpointing: # Wrap the forward + logprobs call to use jax.checkpoint for gradient checkpointing @@ -272,6 +287,7 @@ def loss_for_lora( attention_mask, adapter_indices, target_ids, + self._train_unembed_mask, ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): @@ -443,6 +459,9 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") + # Set train_unembed mask for this adapter + self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(lora_config.train_unembed) + # Store model metadata self.models[model_id] = types.ModelMetadata( adapter_index=adapter_index, @@ -466,9 +485,10 @@ def delete_model(self, model_id: str) -> None: # Get adapter index before deleting metadata adapter_index = self.models[model_id].adapter_index - # Clear LoRA adapter weights + # Clear LoRA adapter weights and reset train_unembed mask with jax.set_mesh(self.mesh): clear_lora_adapter(self.model, adapter_index) + self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(False) # Delete optimizer del self.optimizers[model_id] diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 68ee87434..bbd0feca5 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -19,6 +19,12 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" ... + @property + @abstractmethod + def lm_head_weight(self) -> jax.Array: + """LM head weight matrix [H, V] for efficient chunked computation.""" + ... + def compute_logits( self, hidden_states: jax.Array, @@ -40,6 +46,8 @@ def compute_logprobs( hidden_states: jax.Array, target_ids: jax.Array, adapter_indices: jax.Array | None = None, + chunk_size: int = 0, + gradient_checkpointing: bool = False, ) -> jax.Array: """Compute logprobs from hidden states. For training and prompt logprobs. @@ -47,12 +55,22 @@ def compute_logprobs( hidden_states: Hidden states [B, T, H]. target_ids: Target token IDs [B, T]. adapter_indices: Adapter indices for LoRA on lm_head. + Pass when train_unembed=True. Forces non-chunked path. + chunk_size: Chunk size for chunked computation (0 = non-chunked). + gradient_checkpointing: Whether to checkpoint each chunk. Returns: Log probabilities for target tokens [B, T]. """ - logits = self.compute_logits(hidden_states, adapter_indices) - return self.logits_to_logprobs(logits, target_ids) + # Chunked path doesn't support LoRA on lm_head + use_chunk = chunk_size > 0 and adapter_indices is None + if use_chunk: + return self._compute_chunked_logprobs( + hidden_states, target_ids, chunk_size, gradient_checkpointing + ) + else: + logits = self.compute_logits(hidden_states, adapter_indices) + return self.logits_to_logprobs(logits, target_ids) @staticmethod def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: @@ -68,3 +86,54 @@ def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) + + def _compute_chunked_logprobs( + self, + hidden_states: jax.Array, + target_ids: jax.Array, + chunk_size: int, + gradient_checkpointing: bool, + ) -> jax.Array: + """Compute log probabilities using chunked lm_head computation. + + This avoids materializing the full [B*T, V] logits tensor by computing + lm_head and log probabilities for each chunk sequentially. + """ + B, T, H = hidden_states.shape + total_tokens = B * T + lm_head_weight = self.lm_head_weight + + # Flatten batch and sequence dimensions + flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] + flat_target_ids = target_ids.reshape(-1) # [B*T] + + # Pad to multiple of chunk_size for clean slicing + num_chunks = (total_tokens + chunk_size - 1) // chunk_size + padded_size = num_chunks * chunk_size + pad_amount = padded_size - total_tokens + + if pad_amount > 0: + flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) + flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + + # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] + chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) + chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) + + def compute_chunk_logprobs(args): + """Compute lm_head and log probabilities for a chunk of tokens.""" + chunk_hidden, chunk_targets = args + # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] + chunk_logits = chunk_hidden @ lm_head_weight + # Compute log probabilities + log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + if gradient_checkpointing: + compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) + + # Process chunks sequentially using lax.map (not vmap) to reduce memory + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) + # Flatten and slice to original size, then reshape to [B, T] + return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) From 38175fe4c212ff8c927d07d668e82f4eeaea622e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 14:48:42 -0800 Subject: [PATCH 039/133] fix --- skyrl-tx/tests/utils/test_generator.py | 5 ++--- skyrl-tx/tx/models/llama3.py | 15 +++++++-------- skyrl-tx/tx/models/qwen3.py | 15 +++++++-------- skyrl-tx/tx/utils/logits_processor.py | 20 ++++++++------------ 4 files changed, 24 insertions(+), 31 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index dc25459b8..f0311a8ed 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -31,9 +31,8 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - @property - def lm_head_weight(self) -> jax.Array: - """Identity matrix for dummy model.""" + def get_lm_head_weight(self) -> jax.Array: + """Return identity matrix for dummy model.""" return self._lm_head_weight def __call__( diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 125390038..46adab27b 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -287,19 +287,18 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - @staticmethod - def is_lora_param(path: tuple, _value) -> bool: - """Return True if a parameter path corresponds to LoRA weights.""" - return any(name in path for name in ("lora_A", "lora_B")) - - @property - def lm_head_weight(self) -> jax.Array: - """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + def get_lm_head_weight(self) -> jax.Array: + """Return the lm_head weight [H, V] for chunked cross-entropy.""" if self.config.tie_word_embeddings: return self.model.embed_tokens.embedding[...].T else: return self.lm_head.kernel[...] + @staticmethod + def is_lora_param(path: tuple, _value) -> bool: + """Return True if a parameter path corresponds to LoRA weights.""" + return any(name in path for name in ("lora_A", "lora_B")) + def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index ec4226052..9614a0136 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -402,19 +402,18 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - @staticmethod - def is_lora_param(path: tuple, _value) -> bool: - """Return True if a parameter path corresponds to LoRA weights.""" - return any(name in path for name in ("lora_A", "lora_B")) - - @property - def lm_head_weight(self) -> jax.Array: - """Returns lm_head weight [H, V] for external matmul (e.g., chunked cross-entropy).""" + def get_lm_head_weight(self) -> jax.Array: + """Return the lm_head weight [H, V] for chunked cross-entropy.""" if self.config.tie_word_embeddings: return self.model.embed_tokens.embedding[...].T else: return self.lm_head.kernel[...] + @staticmethod + def is_lora_param(path: tuple, _value) -> bool: + """Return True if a parameter path corresponds to LoRA weights.""" + return any(name in path for name in ("lora_A", "lora_B")) + def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index bbd0feca5..8aa773c92 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -6,12 +6,13 @@ import jax import jax.numpy as jnp +from tx.models.types import ModelForCausalLM # lm_head: (hidden_states, adapter_indices) -> logits LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] -class LogitsProcessorMixin: +class LogitsProcessorMixin(ModelForCausalLM): """Mixin providing logits/logprobs computation for causal language models.""" @abstractmethod @@ -19,10 +20,9 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" ... - @property @abstractmethod - def lm_head_weight(self) -> jax.Array: - """LM head weight matrix [H, V] for efficient chunked computation.""" + def get_lm_head_weight(self) -> jax.Array: + """Return the lm_head weight matrix [H, V] for efficient chunked computation.""" ... def compute_logits( @@ -46,8 +46,6 @@ def compute_logprobs( hidden_states: jax.Array, target_ids: jax.Array, adapter_indices: jax.Array | None = None, - chunk_size: int = 0, - gradient_checkpointing: bool = False, ) -> jax.Array: """Compute logprobs from hidden states. For training and prompt logprobs. @@ -56,17 +54,16 @@ def compute_logprobs( target_ids: Target token IDs [B, T]. adapter_indices: Adapter indices for LoRA on lm_head. Pass when train_unembed=True. Forces non-chunked path. - chunk_size: Chunk size for chunked computation (0 = non-chunked). - gradient_checkpointing: Whether to checkpoint each chunk. Returns: Log probabilities for target tokens [B, T]. """ + chunk_size = self.config.loss_chunk_size # Chunked path doesn't support LoRA on lm_head use_chunk = chunk_size > 0 and adapter_indices is None if use_chunk: return self._compute_chunked_logprobs( - hidden_states, target_ids, chunk_size, gradient_checkpointing + hidden_states, target_ids, chunk_size ) else: logits = self.compute_logits(hidden_states, adapter_indices) @@ -92,7 +89,6 @@ def _compute_chunked_logprobs( hidden_states: jax.Array, target_ids: jax.Array, chunk_size: int, - gradient_checkpointing: bool, ) -> jax.Array: """Compute log probabilities using chunked lm_head computation. @@ -101,7 +97,7 @@ def _compute_chunked_logprobs( """ B, T, H = hidden_states.shape total_tokens = B * T - lm_head_weight = self.lm_head_weight + lm_head_weight = self.get_lm_head_weight() # Flatten batch and sequence dimensions flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] @@ -130,7 +126,7 @@ def compute_chunk_logprobs(args): target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) - if gradient_checkpointing: + if self.config.gradient_checkpointing: compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) # Process chunks sequentially using lax.map (not vmap) to reduce memory From 524168392d9f3690e36e514505a8d48235475ea6 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 15:18:47 -0800 Subject: [PATCH 040/133] fix --- skyrl-tx/tests/utils/test_generator.py | 3 +++ skyrl-tx/tx/models/configs.py | 10 +++++++++- skyrl-tx/tx/models/types.py | 9 +++++---- skyrl-tx/tx/tinker/backends/jax.py | 9 +++------ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index f0311a8ed..49d67197b 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + from flax import nnx import jax import jax.numpy as jnp @@ -15,6 +17,7 @@ class DummyModel(GeneratorMixin, LogitsProcessorMixin, nnx.Module): """ def __init__(self, vocab_size: int = 16): + self.config = MagicMock(loss_chunk_size=0, gradient_checkpointing=False) self.vocab_size = vocab_size self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index adc2b57ab..f7b8cc78d 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -14,12 +14,16 @@ class ModelConfig(PretrainedConfig): max_lora_adapters: Maximum number of concurrent LoRA adapters max_lora_rank: Maximum rank for LoRA adapters shard_attention_heads: Whether to shard attention across tensor parallel devices + loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking) + gradient_checkpointing: Whether to use gradient checkpointing for chunked loss """ - # Type hints for LoRA attributes + # Type hints for config attributes max_lora_adapters: int max_lora_rank: int shard_attention_heads: bool + loss_chunk_size: int + gradient_checkpointing: bool def __init__( self, @@ -28,6 +32,8 @@ def __init__( max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, + loss_chunk_size: int, + gradient_checkpointing: bool, ): # Copy all attributes from the base config super().__init__(**config.to_dict()) @@ -36,6 +42,8 @@ def __init__( self.max_lora_adapters = max_lora_adapters self.max_lora_rank = max_lora_rank self.shard_attention_heads = shard_attention_heads + self.loss_chunk_size = loss_chunk_size + self.gradient_checkpointing = gradient_checkpointing # Model-specific aliases for clarity and backwards compatibility diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index be60f6ec9..f0b7a6b21 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -2,16 +2,17 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Protocol import jax -from transformers import PretrainedConfig +from tx.models.configs import ModelConfig from tx.utils.generator import KVCache -class ModelForCausalLM(Protocol): - config: PretrainedConfig +class ModelForCausalLM: + """Base class for causal language models.""" + + config: ModelConfig @jax.tree_util.register_dataclass diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index f44f7737b..48c6892c7 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -167,6 +167,8 @@ def __init__(self, base_model: str, config: JaxBackendConfig): max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, shard_attention_heads=config.shard_attention_heads, + loss_chunk_size=config.loss_chunk_size, + gradient_checkpointing=config.gradient_checkpointing, ) model_class = get_model_class(self.model_config) @@ -234,9 +236,6 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" - loss_chunk_size = self.config.loss_chunk_size - gradient_checkpointing = self.config.gradient_checkpointing - def _forward_and_logprobs( graphdef: nnx.GraphDef, lora_params: nnx.State, @@ -257,9 +256,7 @@ def _forward_and_logprobs( # Check at runtime if any adapter in batch needs LoRA on lm_head needs_lm_head_lora = train_unembed_mask[adapter_indices].any() def logprobs(lm_head_adapter_indices): - return model.compute_logprobs( - output.last_hidden_state, target_ids, lm_head_adapter_indices, loss_chunk_size, gradient_checkpointing - ) + return model.compute_logprobs(output.last_hidden_state, target_ids, lm_head_adapter_indices) return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) if self.config.gradient_checkpointing: From ab68bd7882b57b070a0262cc4f52959929a3d4d6 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 15:30:44 -0800 Subject: [PATCH 041/133] refine tests --- skyrl-tx/tests/models/test_models_common.py | 90 +++++++--- skyrl-tx/tests/tinker/test_jax_backend.py | 186 +++++--------------- 2 files changed, 113 insertions(+), 163 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 9e932f439..cba022b94 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -2,6 +2,7 @@ from flax import nnx import jax +import jax.numpy as jnp import numpy as np import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -11,39 +12,86 @@ from tx.models.qwen3 import Qwen3ForCausalLM from tx.utils.models import get_dtype, load_safetensors +MODEL_PARAMS = [ + ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), + ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), +] +MODEL_IDS = ["llama3", "qwen3"] -@pytest.mark.parametrize( - "model_name,config_cls,model_cls,mesh_axes", - [ - ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), - ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), - ], - ids=["llama3", "qwen3"], -) -def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): - """Test that model.compute_logits matches HuggingFace logits.""" + +def make_model(model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0, gradient_checkpointing=False): + """Create a model with the given config.""" tokenizer = AutoTokenizer.from_pretrained(model_name) hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) - inputs = ["The capital of France is", "Hello world"] - batch = tokenizer(inputs, return_tensors="pt", padding=True) - with tempfile.TemporaryDirectory() as tmp: hf_model.save_pretrained(tmp, safe_serialization=True) base_config = AutoConfig.from_pretrained(model_name) - config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) + config = config_cls( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + loss_chunk_size=loss_chunk_size, + gradient_checkpointing=gradient_checkpointing, + ) mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): model = model_cls(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) - # Get HF logits - hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) - hf_logits = hf_outputs.logits.detach().numpy() + return model, tokenizer, hf_model + + +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) +def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): + """Test that model.compute_logits matches HuggingFace logits.""" + model, tokenizer, hf_model = make_model(model_name, config_cls, model_cls, mesh_axes) + + inputs = ["The capital of France is", "Hello world"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + + # Get HF logits + hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) + hf_logits = hf_outputs.logits.detach().numpy() + + # Get our logits via compute_logits + outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) + our_logits = model.compute_logits(outputs.last_hidden_state) + + np.testing.assert_allclose(our_logits, hf_logits, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) +@pytest.mark.parametrize("chunk_size", [8, 16, 32]) +def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): + """Test that chunked and non-chunked compute_logprobs produce identical results.""" + model_chunked, tokenizer, _ = make_model( + model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size + ) + model_nonchunked, _, _ = make_model( + model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0 + ) + + inputs = ["The capital of France is", "Hello world"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + input_ids = jnp.array(batch.input_ids.numpy()) + attention_mask = jnp.array(batch.attention_mask.numpy()) + target_ids = jnp.roll(input_ids, -1, axis=1) + + # Get hidden states + outputs_chunked = model_chunked(input_ids, attention_mask=attention_mask) + outputs_nonchunked = model_nonchunked(input_ids, attention_mask=attention_mask) - # Get our logits via compute_logits - outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - our_logits = model.compute_logits(outputs.last_hidden_state) + # Compute logprobs with both methods + logprobs_chunked = model_chunked.compute_logprobs(outputs_chunked.last_hidden_state, target_ids) + logprobs_nonchunked = model_nonchunked.compute_logprobs(outputs_nonchunked.last_hidden_state, target_ids) - np.testing.assert_allclose(our_logits, hf_logits, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Chunked vs non-chunked logprobs mismatch for chunk_size={chunk_size}", + ) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index edf91b0db..cc60043f3 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -558,52 +558,32 @@ def test_adapter_reuse_initializes_lora_adapter(): assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" -class TestChunkedCrossEntropyLoss: - """Tests for chunked cross-entropy loss computation.""" - - def _create_backend(self, loss_chunk_size: int, max_lora_adapters: int = 2) -> JaxBackend: - """Create a backend with specified chunk size.""" - config = JaxBackendConfig( - max_lora_adapters=max_lora_adapters, - max_lora_rank=32, - loss_chunk_size=loss_chunk_size, - ) - return JaxBackend(BASE_MODEL, config) - - def _create_inputs(self, backend: JaxBackend, batch_size: int, seq_len: int, adapter_idx: int = 0): - """Create test inputs for forward pass.""" - vocab = backend.model.config.vocab_size - input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - adapter_indices = jnp.full((batch_size,), adapter_idx, dtype=jnp.int32) - target_ids = (input_ids + 1) % vocab - loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) - loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) - sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - return ( - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - - def _run_forward(self, backend: JaxBackend, inputs: tuple): - """Run forward pass and return losses and logprobs.""" - ( - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) = inputs +def test_mixed_train_unembed_adapters(): + """Test that backend correctly routes to chunked/non-chunked path based on train_unembed.""" + config_chunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=1024) + config_nonchunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=0) + backend_chunked = JaxBackend(BASE_MODEL, config_chunked) + backend_nonchunked = JaxBackend(BASE_MODEL, config_nonchunked) + + # Create same models on both backends + for backend in [backend_chunked, backend_nonchunked]: + backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) + backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) + + normal_idx = backend_chunked.models["model_normal"].adapter_index + unembed_idx = backend_chunked.models["model_unembed"].adapter_index + + batch_size, seq_len = 2, 16 + vocab = backend_chunked.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + + def run_forward(backend, adapter_indices): _, losses, logprobs = backend._forward( backend.accumulated_grads, backend.lora_params, @@ -619,100 +599,22 @@ def _run_forward(self, backend: JaxBackend, inputs: tuple): ) return losses, logprobs - @pytest.mark.parametrize( - "batch_size,seq_len,chunk_size", - [ - (2, 16, 8), # Multiple batches - (1, 16, 16), # Exact multiple (1 chunk) - (1, 17, 16), # One extra token (worst case padding) - (1, 8, 16), # Fewer tokens than chunk size - (1, 32, 16), # Exact 2 chunks - (1, 1, 16), # Single token - (1, 31, 16), # Almost 2 chunks - ], + # Test with mixed adapters: one normal, one unembed + adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) + losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) + losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", + ) + np.testing.assert_allclose( + np.asarray(losses_chunked), + np.asarray(losses_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", ) - def test_chunked_vs_nonchunked_logprobs(self, batch_size, seq_len, chunk_size): - """Verify chunked and non-chunked loss produce identical logprobs.""" - backend_chunked = self._create_backend(loss_chunk_size=chunk_size) - backend_nonchunked = self._create_backend(loss_chunk_size=0) - - assert backend_chunked.config.loss_chunk_size > 0 - assert backend_nonchunked.config.loss_chunk_size == 0 - - inputs = self._create_inputs(backend_chunked, batch_size, seq_len) - losses_chunked, logprobs_chunked = self._run_forward(backend_chunked, inputs) - losses_nonchunked, logprobs_nonchunked = self._run_forward(backend_nonchunked, inputs) - - np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg=f"Logprobs mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", - ) - np.testing.assert_allclose( - np.asarray(losses_chunked), - np.asarray(losses_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg=f"Losses mismatch for batch_size={batch_size}, seq_len={seq_len}, chunk_size={chunk_size}", - ) - - def test_mixed_train_unembed_adapters(self): - """Test that chunked and non-chunked paths produce same results with mixed adapters.""" - backend_chunked = self._create_backend(loss_chunk_size=1024, max_lora_adapters=3) - backend_nonchunked = self._create_backend(loss_chunk_size=0, max_lora_adapters=3) - - # Create same models on both backends - for backend in [backend_chunked, backend_nonchunked]: - backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) - backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) - - normal_idx = backend_chunked.models["model_normal"].adapter_index - unembed_idx = backend_chunked.models["model_unembed"].adapter_index - - batch_size, seq_len = 2, 16 - vocab = backend_chunked.model.config.vocab_size - input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - target_ids = (input_ids + 1) % vocab - loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) - loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) - sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - - def run_forward(backend, adapter_indices): - _, losses, logprobs = backend._forward( - backend.accumulated_grads, - backend.lora_params, - backend.non_lora_params, - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - return losses, logprobs - - # Test with mixed adapters: one normal, one unembed - adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) - losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) - losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) - - np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", - ) - np.testing.assert_allclose( - np.asarray(losses_chunked), - np.asarray(losses_nonchunked), - rtol=1e-4, - atol=1e-4, - err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", - ) From 8b5b02db3db1255a49364b65d1454615829e2b91 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 15:39:47 -0800 Subject: [PATCH 042/133] address comments --- skyrl-tx/tx/tinker/backends/jax.py | 8 ++++---- skyrl-tx/tx/utils/generator.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index dbb871a0d..7c4353a3b 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -229,7 +229,7 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" - def _forward_and_logprobs( + def _model_forward( graphdef: nnx.GraphDef, lora_params: nnx.State, non_lora_params: nnx.State, @@ -248,9 +248,9 @@ def _forward_and_logprobs( return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) if self.config.gradient_checkpointing: - # Wrap the forward + logprobs call to use jax.checkpoint for gradient checkpointing + # Wrap the model forward call to use jax.checkpoint for gradient checkpointing # policy=None corresponds to full activation recomputation - _forward_and_logprobs = jax.checkpoint(_forward_and_logprobs, policy=None) + _model_forward = jax.checkpoint(_model_forward, policy=None) def loss_for_lora( lora_params: nnx.State, @@ -264,7 +264,7 @@ def loss_for_lora( sampling_logprobs: jax.Array, advantages: jax.Array, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - target_logprobs = _forward_and_logprobs( + target_logprobs = _model_forward( self.graphdef, lora_params, non_lora_params, diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 520fefbc5..7d1864ca2 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -148,15 +148,15 @@ def _prefill_and_decode( adapter_indices=adapter_indices, ) - # Compute logits for last position (needed for sampling first token) - last_logits = model.compute_logits(outputs.last_hidden_state[:, -1:, :], adapter_indices)[:, 0, :] - - # Compute prompt logprobs if requested + # Compute logits for sampling and optionally for prompt logprobs if prompt_logprobs: - prompt_logprobs_array = model.compute_logprobs( - outputs.last_hidden_state[:, :-1, :], input_ids[:, 1:], adapter_indices - ) + # Compute all logits for prompt logprobs and sampling the first token + all_logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + last_logits = all_logits[:, -1, :] + prompt_logprobs_array = model.logits_to_logprobs(all_logits[:, :-1, :], input_ids[:, 1:]) else: + # Only compute logits for the last position for sampling + last_logits = model.compute_logits(outputs.last_hidden_state[:, -1:, :], adapter_indices)[:, 0, :] prompt_logprobs_array = None # Pad KV cache and attention mask From 10ff606f4febc31457d8afa48493b073b635d528 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 15:57:39 -0800 Subject: [PATCH 043/133] fix: use float32 and per-model tolerances in test_compute_logits - Force float32 for our model to match HF for accurate comparison - Use per-model tolerances: 3e-2 for llama3, 5e-4 for qwen3 (llama3 has larger numerical differences, see test_llama3.py) Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 9e932f439..5d5b989d3 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -13,17 +13,21 @@ @pytest.mark.parametrize( - "model_name,config_cls,model_cls,mesh_axes", + "model_name,config_cls,model_cls,mesh_axes,tol", [ - ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), - ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), + # llama3 has larger numerical differences (see test_llama3.py which uses 5e-2 for hidden states) + ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp"), 3e-2), + ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp"), 5e-4), ], ids=["llama3", "qwen3"], ) -def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): +def test_compute_logits(model_name, config_cls, model_cls, mesh_axes, tol): """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) - hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + # Load HF model in float32 for the comparison (our model will also use float32) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True + ) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) @@ -35,7 +39,9 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): - model = model_cls(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + import jax.numpy as jnp + # Use float32 to match HF model for accurate comparison + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) # Get HF logits @@ -44,6 +50,6 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): # Get our logits via compute_logits outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - our_logits = model.compute_logits(outputs.last_hidden_state) + our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) - np.testing.assert_allclose(our_logits, hf_logits, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(our_logits, hf_logits, rtol=tol, atol=tol) From 0781e2050c759d35c1f117d3a2ec4676dde0f024 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 15:57:39 -0800 Subject: [PATCH 044/133] fix: use float32 and per-model tolerances in test_compute_logits - Force float32 for our model to match HF for accurate comparison - Use per-model tolerances: 3e-2 for llama3, 5e-4 for qwen3 (llama3 has larger numerical differences, see test_llama3.py) Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 9e932f439..27099fc3a 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -23,7 +23,10 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) - hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + # Load HF model in float32 for the comparison (our model will also use float32) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True + ) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) @@ -35,7 +38,9 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): - model = model_cls(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + import jax.numpy as jnp + # Use float32 to match HF model for accurate comparison + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) # Get HF logits @@ -44,6 +49,7 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): # Get our logits via compute_logits outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) - our_logits = model.compute_logits(outputs.last_hidden_state) + our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) - np.testing.assert_allclose(our_logits, hf_logits, rtol=1e-4, atol=1e-4) + # Use loose tolerance due to numerical differences (see test_llama3.py which uses 5e-2) + np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2) From ff949dff8eaa3a9f8d4370caeefa7d8a9b6befab Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 16:15:08 -0800 Subject: [PATCH 045/133] remove comment --- skyrl-tx/tests/models/test_models_common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 27099fc3a..bc73b720d 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -9,7 +9,7 @@ from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM -from tx.utils.models import get_dtype, load_safetensors +from tx.utils.models import load_safetensors @pytest.mark.parametrize( @@ -51,5 +51,4 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) - # Use loose tolerance due to numerical differences (see test_llama3.py which uses 5e-2) np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2) From 1bde686de4f172a08c084936ef336c7ce0fcf12a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 16:19:13 -0800 Subject: [PATCH 046/133] remove comment --- skyrl-tx/tests/models/test_models_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 127710a65..2bbf179be 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -64,7 +64,6 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) - # Use loose tolerance due to numerical differences (see test_llama3.py which uses 5e-2) np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2) From 42ef8f03e2060716d0237845056d8c99ed340674 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 16:22:53 -0800 Subject: [PATCH 047/133] lint --- skyrl-tx/tests/models/test_models_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index bc73b720d..dda0994dc 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -24,9 +24,7 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) # Load HF model in float32 for the comparison (our model will also use float32) - hf_model = AutoModelForCausalLM.from_pretrained( - model_name, attn_implementation="eager", use_safetensors=True - ) + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) @@ -39,6 +37,7 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): import jax.numpy as jnp + # Use float32 to match HF model for accurate comparison model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) From 8831bf20ba8930a5d3ebc07c54d493c646622854 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 16:23:20 -0800 Subject: [PATCH 048/133] lint --- skyrl-tx/tests/models/test_models_common.py | 12 +++--------- skyrl-tx/tx/tinker/backends/jax.py | 3 +++ skyrl-tx/tx/utils/logits_processor.py | 4 +--- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 2bbf179be..6ea506de8 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -23,9 +23,7 @@ def make_model(model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size= """Create a model with the given config.""" tokenizer = AutoTokenizer.from_pretrained(model_name) # Load HF model in float32 for the comparison (our model will also use float32) - hf_model = AutoModelForCausalLM.from_pretrained( - model_name, attn_implementation="eager", use_safetensors=True - ) + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) with tempfile.TemporaryDirectory() as tmp: hf_model.save_pretrained(tmp, safe_serialization=True) @@ -71,12 +69,8 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): @pytest.mark.parametrize("chunk_size", [8, 16, 32]) def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): """Test that chunked and non-chunked compute_logprobs produce identical results.""" - model_chunked, tokenizer, _ = make_model( - model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size - ) - model_nonchunked, _, _ = make_model( - model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0 - ) + model_chunked, tokenizer, _ = make_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size) + model_nonchunked, _, _ = make_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 86b92b3d1..6cdba72a9 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -236,6 +236,7 @@ def _jit_timing_context(self, seq_len: int, mode: str): def _create_loss_and_grad_fn(self): """Compile and cache the loss function to avoid re-jitting on every call.""" + def _model_forward( graphdef: nnx.GraphDef, lora_params: nnx.State, @@ -255,8 +256,10 @@ def _model_forward( ) # Check at runtime if any adapter in batch needs LoRA on lm_head needs_lm_head_lora = train_unembed_mask[adapter_indices].any() + def logprobs(lm_head_adapter_indices): return model.compute_logprobs(output.last_hidden_state, target_ids, lm_head_adapter_indices) + return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) if self.config.gradient_checkpointing: diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 8aa773c92..60a9b225f 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -62,9 +62,7 @@ def compute_logprobs( # Chunked path doesn't support LoRA on lm_head use_chunk = chunk_size > 0 and adapter_indices is None if use_chunk: - return self._compute_chunked_logprobs( - hidden_states, target_ids, chunk_size - ) + return self._compute_chunked_logprobs(hidden_states, target_ids, chunk_size) else: logits = self.compute_logits(hidden_states, adapter_indices) return self.logits_to_logprobs(logits, target_ids) From 07b7be769463c0e82793f1eb78f1bd4e1880680e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 16:38:04 -0800 Subject: [PATCH 049/133] empty From d55e04cd29b0fd97d8fa10d71068f34a386ad1a4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 17:24:44 -0800 Subject: [PATCH 050/133] refactor: use lm_head() in chunked path to support LoRA - Remove get_lm_head_weight() abstract method (no longer needed) - Chunked path now uses lm_head() directly instead of raw matmul - Expand adapter_indices from [B] to [B*T] for per-token handling - Remove restriction that disabled chunking with adapter_indices Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/utils/test_generator.py | 6 ---- skyrl-tx/tx/models/llama3.py | 7 ---- skyrl-tx/tx/models/qwen3.py | 7 ---- skyrl-tx/tx/utils/logits_processor.py | 48 ++++++++++++++++++-------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 49d67197b..2679b69f6 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock from flax import nnx -import jax import jax.numpy as jnp from tx.models.types import CausalLMOutput from tx.tinker.types import SamplingParams @@ -19,7 +18,6 @@ class DummyModel(GeneratorMixin, LogitsProcessorMixin, nnx.Module): def __init__(self, vocab_size: int = 16): self.config = MagicMock(loss_chunk_size=0, gradient_checkpointing=False) self.vocab_size = vocab_size - self._lm_head_weight = jnp.eye(vocab_size, dtype=jnp.float32) def lm_head(hidden_states, adapter_indices=None): # Scale logits by (1 + adapter_index) so different adapters give different log-softmax results @@ -34,10 +32,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def get_lm_head_weight(self) -> jax.Array: - """Return identity matrix for dummy model.""" - return self._lm_head_weight - def __call__( self, input_ids, diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 46adab27b..b7eb14d52 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -287,13 +287,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def get_lm_head_weight(self) -> jax.Array: - """Return the lm_head weight [H, V] for chunked cross-entropy.""" - if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding[...].T - else: - return self.lm_head.kernel[...] - @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 9614a0136..fdf68ee48 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -402,13 +402,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def get_lm_head_weight(self) -> jax.Array: - """Return the lm_head weight [H, V] for chunked cross-entropy.""" - if self.config.tie_word_embeddings: - return self.model.embed_tokens.embedding[...].T - else: - return self.lm_head.kernel[...] - @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 60a9b225f..3a62f1a3e 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -20,10 +20,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" ... - @abstractmethod - def get_lm_head_weight(self) -> jax.Array: - """Return the lm_head weight matrix [H, V] for efficient chunked computation.""" - ... def compute_logits( self, @@ -53,16 +49,13 @@ def compute_logprobs( hidden_states: Hidden states [B, T, H]. target_ids: Target token IDs [B, T]. adapter_indices: Adapter indices for LoRA on lm_head. - Pass when train_unembed=True. Forces non-chunked path. Returns: Log probabilities for target tokens [B, T]. """ chunk_size = self.config.loss_chunk_size - # Chunked path doesn't support LoRA on lm_head - use_chunk = chunk_size > 0 and adapter_indices is None - if use_chunk: - return self._compute_chunked_logprobs(hidden_states, target_ids, chunk_size) + if chunk_size > 0: + return self._compute_chunked_logprobs(hidden_states, target_ids, chunk_size, adapter_indices) else: logits = self.compute_logits(hidden_states, adapter_indices) return self.logits_to_logprobs(logits, target_ids) @@ -87,6 +80,7 @@ def _compute_chunked_logprobs( hidden_states: jax.Array, target_ids: jax.Array, chunk_size: int, + adapter_indices: jax.Array | None, ) -> jax.Array: """Compute log probabilities using chunked lm_head computation. @@ -95,12 +89,18 @@ def _compute_chunked_logprobs( """ B, T, H = hidden_states.shape total_tokens = B * T - lm_head_weight = self.get_lm_head_weight() + lm_head = self.get_lm_head() # Flatten batch and sequence dimensions flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] flat_target_ids = target_ids.reshape(-1) # [B*T] + # Expand adapter_indices from [B] to [B*T] by repeating each T times + if adapter_indices is not None: + flat_adapter_indices = jnp.repeat(adapter_indices, T) # [B*T] + else: + flat_adapter_indices = None + # Pad to multiple of chunk_size for clean slicing num_chunks = (total_tokens + chunk_size - 1) // chunk_size padded_size = num_chunks * chunk_size @@ -109,16 +109,22 @@ def _compute_chunked_logprobs( if pad_amount > 0: flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + if flat_adapter_indices is not None: + flat_adapter_indices = jnp.pad(flat_adapter_indices, (0, pad_amount)) # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) + if flat_adapter_indices is not None: + chunked_adapter_indices = flat_adapter_indices.reshape(num_chunks, chunk_size) + else: + chunked_adapter_indices = None def compute_chunk_logprobs(args): """Compute lm_head and log probabilities for a chunk of tokens.""" - chunk_hidden, chunk_targets = args - # Compute logits for this chunk only: [chunk_size, H] @ [H, V] = [chunk_size, V] - chunk_logits = chunk_hidden @ lm_head_weight + chunk_hidden, chunk_targets, chunk_adapters = args + # Compute logits for this chunk: [chunk_size, H] -> [chunk_size, V] + chunk_logits = lm_head(chunk_hidden, chunk_adapters) # Compute log probabilities log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) @@ -128,6 +134,20 @@ def compute_chunk_logprobs(args): compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) # Process chunks sequentially using lax.map (not vmap) to reduce memory - all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets)) + if chunked_adapter_indices is not None: + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets, chunked_adapter_indices)) + else: + # Create dummy array for lax.map (needs consistent structure) + dummy_adapters = jnp.zeros((num_chunks, chunk_size), dtype=jnp.int32) + def compute_chunk_logprobs_no_adapter(args): + chunk_hidden, chunk_targets, _ = args + chunk_logits = lm_head(chunk_hidden, None) + log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + if self.config.gradient_checkpointing: + compute_chunk_logprobs_no_adapter = jax.checkpoint(compute_chunk_logprobs_no_adapter, policy=None) + all_logprobs = jax.lax.map(compute_chunk_logprobs_no_adapter, (chunked_hidden, chunked_targets, dummy_adapters)) + # Flatten and slice to original size, then reshape to [B, T] return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) From 006d4128a77019344d379b177bde399ab0671c2c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 17:40:36 -0800 Subject: [PATCH 051/133] cleanup: remove _train_unembed_mask and simplify chunked lm_head - Remove _train_unembed_mask tracking from JaxBackend - Simplify _model_forward to always pass adapter_indices to compute_logprobs - Fix chunked path to reshape hidden states to [chunk_size, 1, H] for LoRA compatibility Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 2 +- skyrl-tx/tx/tinker/backends/jax.py | 18 ++---------------- skyrl-tx/tx/utils/logits_processor.py | 10 +++++++--- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index cc60043f3..74787df9f 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -559,7 +559,7 @@ def test_adapter_reuse_initializes_lora_adapter(): def test_mixed_train_unembed_adapters(): - """Test that backend correctly routes to chunked/non-chunked path based on train_unembed.""" + """Test that chunked and non-chunked paths produce same results with train_unembed adapters.""" config_chunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=1024) config_nonchunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=0) backend_chunked = JaxBackend(BASE_MODEL, config_chunked) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 6cdba72a9..a0a7a6dd6 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -206,8 +206,6 @@ def __init__(self, base_model: str, config: JaxBackendConfig): f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" ) - # Track which adapters use train_unembed=True (requires LoRA on lm_head) - self._train_unembed_mask = jnp.zeros(config.max_lora_adapters, dtype=jnp.bool_) self._create_loss_and_grad_fn() def _micro_batch_size(self, total: int) -> int: @@ -245,7 +243,6 @@ def _model_forward( attention_mask: jax.Array, adapter_indices: jax.Array, target_ids: jax.Array, - train_unembed_mask: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" model = nnx.merge(graphdef, lora_params, non_lora_params) @@ -254,13 +251,7 @@ def _model_forward( attention_mask=attention_mask, adapter_indices=adapter_indices, ) - # Check at runtime if any adapter in batch needs LoRA on lm_head - needs_lm_head_lora = train_unembed_mask[adapter_indices].any() - - def logprobs(lm_head_adapter_indices): - return model.compute_logprobs(output.last_hidden_state, target_ids, lm_head_adapter_indices) - - return jax.lax.cond(needs_lm_head_lora, lambda: logprobs(adapter_indices), lambda: logprobs(None)) + return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) if self.config.gradient_checkpointing: # Wrap the model forward call to use jax.checkpoint for gradient checkpointing @@ -287,7 +278,6 @@ def loss_for_lora( attention_mask, adapter_indices, target_ids, - self._train_unembed_mask, ) def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): @@ -459,9 +449,6 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") - # Set train_unembed mask for this adapter - self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(lora_config.train_unembed) - # Store model metadata self.models[model_id] = types.ModelMetadata( adapter_index=adapter_index, @@ -485,10 +472,9 @@ def delete_model(self, model_id: str) -> None: # Get adapter index before deleting metadata adapter_index = self.models[model_id].adapter_index - # Clear LoRA adapter weights and reset train_unembed mask + # Clear LoRA adapter weights with jax.set_mesh(self.mesh): clear_lora_adapter(self.model, adapter_index) - self._train_unembed_mask = self._train_unembed_mask.at[adapter_index].set(False) # Delete optimizer del self.optimizers[model_id] diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 3a62f1a3e..cbea3e9fb 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -123,8 +123,11 @@ def _compute_chunked_logprobs( def compute_chunk_logprobs(args): """Compute lm_head and log probabilities for a chunk of tokens.""" chunk_hidden, chunk_targets, chunk_adapters = args - # Compute logits for this chunk: [chunk_size, H] -> [chunk_size, V] - chunk_logits = lm_head(chunk_hidden, chunk_adapters) + # Reshape to [chunk_size, 1, H] for lm_head (batch=chunk_size, seq=1) + # This allows LoRA to work with per-token adapter indices + chunk_hidden_3d = chunk_hidden[:, None, :] + # Compute logits: [chunk_size, 1, H] -> [chunk_size, 1, V] -> [chunk_size, V] + chunk_logits = lm_head(chunk_hidden_3d, chunk_adapters)[:, 0, :] # Compute log probabilities log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) @@ -141,7 +144,8 @@ def compute_chunk_logprobs(args): dummy_adapters = jnp.zeros((num_chunks, chunk_size), dtype=jnp.int32) def compute_chunk_logprobs_no_adapter(args): chunk_hidden, chunk_targets, _ = args - chunk_logits = lm_head(chunk_hidden, None) + chunk_hidden_3d = chunk_hidden[:, None, :] + chunk_logits = lm_head(chunk_hidden_3d, None)[:, 0, :] log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) return (target_logits - log_sum_exp).squeeze(-1) From b2f8eba1de7040fb1835edc59ec1e4534b361e72 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 17:57:52 -0800 Subject: [PATCH 052/133] refactor: compute adapter indices on-the-fly in chunked path Instead of allocating [B*T] array via jnp.repeat, compute adapter indices per-chunk using only a [chunk_size] buffer. This reduces memory overhead significantly for long sequences. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/utils/logits_processor.py | 46 ++++++++------------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index cbea3e9fb..066aaf9f1 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -95,12 +95,6 @@ def _compute_chunked_logprobs( flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] flat_target_ids = target_ids.reshape(-1) # [B*T] - # Expand adapter_indices from [B] to [B*T] by repeating each T times - if adapter_indices is not None: - flat_adapter_indices = jnp.repeat(adapter_indices, T) # [B*T] - else: - flat_adapter_indices = None - # Pad to multiple of chunk_size for clean slicing num_chunks = (total_tokens + chunk_size - 1) // chunk_size padded_size = num_chunks * chunk_size @@ -109,22 +103,26 @@ def _compute_chunked_logprobs( if pad_amount > 0: flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) - if flat_adapter_indices is not None: - flat_adapter_indices = jnp.pad(flat_adapter_indices, (0, pad_amount)) # Reshape into chunks: [num_chunks, chunk_size, H] and [num_chunks, chunk_size] chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) - if flat_adapter_indices is not None: - chunked_adapter_indices = flat_adapter_indices.reshape(num_chunks, chunk_size) - else: - chunked_adapter_indices = None + + # Precompute position offsets for adapter index lookup (reused buffer of chunk_size) + position_offsets = jnp.arange(chunk_size) + # Pad adapter_indices to avoid out-of-bounds when chunk spans past B + if adapter_indices is None: + adapter_indices = jnp.zeros(B, dtype=jnp.int32) + padded_adapter_indices = jnp.pad(adapter_indices, (0, 1)) # [B+1] for safe indexing def compute_chunk_logprobs(args): """Compute lm_head and log probabilities for a chunk of tokens.""" - chunk_hidden, chunk_targets, chunk_adapters = args + chunk_idx, chunk_hidden, chunk_targets = args + # Compute adapter indices on-the-fly from chunk position + flat_positions = chunk_idx * chunk_size + position_offsets + batch_indices = flat_positions // T + chunk_adapters = padded_adapter_indices[batch_indices] # [chunk_size] # Reshape to [chunk_size, 1, H] for lm_head (batch=chunk_size, seq=1) - # This allows LoRA to work with per-token adapter indices chunk_hidden_3d = chunk_hidden[:, None, :] # Compute logits: [chunk_size, 1, H] -> [chunk_size, 1, V] -> [chunk_size, V] chunk_logits = lm_head(chunk_hidden_3d, chunk_adapters)[:, 0, :] @@ -136,22 +134,6 @@ def compute_chunk_logprobs(args): if self.config.gradient_checkpointing: compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) - # Process chunks sequentially using lax.map (not vmap) to reduce memory - if chunked_adapter_indices is not None: - all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets, chunked_adapter_indices)) - else: - # Create dummy array for lax.map (needs consistent structure) - dummy_adapters = jnp.zeros((num_chunks, chunk_size), dtype=jnp.int32) - def compute_chunk_logprobs_no_adapter(args): - chunk_hidden, chunk_targets, _ = args - chunk_hidden_3d = chunk_hidden[:, None, :] - chunk_logits = lm_head(chunk_hidden_3d, None)[:, 0, :] - log_sum_exp = jax.nn.logsumexp(chunk_logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(chunk_logits, chunk_targets[..., None], axis=-1) - return (target_logits - log_sum_exp).squeeze(-1) - if self.config.gradient_checkpointing: - compute_chunk_logprobs_no_adapter = jax.checkpoint(compute_chunk_logprobs_no_adapter, policy=None) - all_logprobs = jax.lax.map(compute_chunk_logprobs_no_adapter, (chunked_hidden, chunked_targets, dummy_adapters)) - - # Flatten and slice to original size, then reshape to [B, T] + chunk_indices = jnp.arange(num_chunks) + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunk_indices, chunked_hidden, chunked_targets)) return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) From a82cd53c7783a1781302ceb65cc113e59411a58a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 18:04:53 -0800 Subject: [PATCH 053/133] fix: load one model at a time in test_compute_logits to avoid OOM Load HF model, get logits, save weights, delete HF model, then load our model. This avoids having both models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index dda0994dc..df5ed3667 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -2,6 +2,7 @@ from flax import nnx import jax +import jax.numpy as jnp import numpy as np import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -23,29 +24,28 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) - # Load HF model in float32 for the comparison (our model will also use float32) - hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) with tempfile.TemporaryDirectory() as tmp: + # Load HF model, get logits, save weights, then delete to free memory + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True + ) + hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) + hf_logits = hf_outputs.logits.detach().numpy() hf_model.save_pretrained(tmp, safe_serialization=True) + del hf_model, hf_outputs + # Load our model from saved weights base_config = AutoConfig.from_pretrained(model_name) config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): - import jax.numpy as jnp - - # Use float32 to match HF model for accurate comparison model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp, config, model) - # Get HF logits - hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) - hf_logits = hf_outputs.logits.detach().numpy() - # Get our logits via compute_logits outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy()) our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) From 345d5c15db8219b2c257d7b10bac971593853d86 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 18:12:44 -0800 Subject: [PATCH 054/133] lint --- skyrl-tx/tests/models/test_models_common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index df5ed3667..ff78e6a39 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -30,9 +30,7 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): with tempfile.TemporaryDirectory() as tmp: # Load HF model, get logits, save weights, then delete to free memory - hf_model = AutoModelForCausalLM.from_pretrained( - model_name, attn_implementation="eager", use_safetensors=True - ) + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask) hf_logits = hf_outputs.logits.detach().numpy() hf_model.save_pretrained(tmp, safe_serialization=True) From 2f78babdb4c7fecca3ee5d23c69ed73a8f7613cf Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 22 Jan 2026 18:18:22 -0800 Subject: [PATCH 055/133] fix: add missing config args and restore test_chunked_logprobs - Add loss_chunk_size and gradient_checkpointing to config in tests - Restore test_chunked_logprobs test that was lost during merge Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 77 ++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index ff78e6a39..37059af31 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -12,6 +12,12 @@ from tx.models.qwen3 import Qwen3ForCausalLM from tx.utils.models import load_safetensors +MODEL_PARAMS = [ + ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), + ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), +] +MODEL_IDS = ["llama3", "qwen3"] + @pytest.mark.parametrize( "model_name,config_cls,model_cls,mesh_axes", @@ -38,7 +44,14 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): # Load our model from saved weights base_config = AutoConfig.from_pretrained(model_name) - config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) + config = config_cls( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + loss_chunk_size=0, + gradient_checkpointing=False, + ) mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) @@ -49,3 +62,65 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state)) np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2) + + +def make_model(model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): + """Create a model with the given config.""" + tokenizer = AutoTokenizer.from_pretrained(model_name) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True + ) + + with tempfile.TemporaryDirectory() as tmp: + hf_model.save_pretrained(tmp, safe_serialization=True) + del hf_model + + base_config = AutoConfig.from_pretrained(model_name) + config = config_cls( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + loss_chunk_size=loss_chunk_size, + gradient_checkpointing=False, + ) + mesh = jax.make_mesh((1, 1), mesh_axes) + with jax.set_mesh(mesh): + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_safetensors(tmp, config, model) + + return model, tokenizer + + +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) +@pytest.mark.parametrize("chunk_size", [8, 16, 32]) +def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): + """Test that chunked and non-chunked compute_logprobs produce identical results.""" + model_chunked, tokenizer = make_model( + model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size + ) + model_nonchunked, _ = make_model( + model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0 + ) + + inputs = ["The capital of France is", "Hello world"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + input_ids = jnp.array(batch.input_ids.numpy()) + attention_mask = jnp.array(batch.attention_mask.numpy()) + target_ids = jnp.roll(input_ids, -1, axis=1) + + # Get hidden states + outputs_chunked = model_chunked(input_ids, attention_mask=attention_mask) + outputs_nonchunked = model_nonchunked(input_ids, attention_mask=attention_mask) + + # Compute logprobs with both methods + logprobs_chunked = model_chunked.compute_logprobs(outputs_chunked.last_hidden_state, target_ids) + logprobs_nonchunked = model_nonchunked.compute_logprobs(outputs_nonchunked.last_hidden_state, target_ids) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-4, + atol=1e-4, + err_msg=f"Chunked vs non-chunked logprobs mismatch for chunk_size={chunk_size}", + ) From e0cb768bd42cdc45a7e28a9198e716ad38dac85e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 08:17:39 -0800 Subject: [PATCH 056/133] test: load one model at a time in test_chunked_logprobs Restructure test to avoid OOM by loading and deleting models sequentially instead of having two models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 75 ++++++++++----------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 4068a1e86..6d7d608a1 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -64,62 +64,55 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2) -def make_model(model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): - """Create a model with the given config.""" - tokenizer = AutoTokenizer.from_pretrained(model_name) - - with tempfile.TemporaryDirectory() as tmp: - # Load HF model, save weights, then delete to free memory - hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) - hf_model.save_pretrained(tmp, safe_serialization=True) - del hf_model - - # Load our model from saved weights - base_config = AutoConfig.from_pretrained(model_name) - config = config_cls( - base_config, - max_lora_adapters=1, - max_lora_rank=1, - shard_attention_heads=True, - loss_chunk_size=loss_chunk_size, - gradient_checkpointing=False, - ) - mesh = jax.make_mesh((1, 1), mesh_axes) - with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) - - return model, tokenizer +def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): + """Load model from pre-saved weights directory.""" + base_config = AutoConfig.from_pretrained(model_name) + config = config_cls( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + loss_chunk_size=loss_chunk_size, + gradient_checkpointing=False, + ) + mesh = jax.make_mesh((1, 1), mesh_axes) + with jax.set_mesh(mesh): + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_safetensors(tmp_dir, config, model) + return model @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) @pytest.mark.parametrize("chunk_size", [8, 16, 32]) def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): """Test that chunked and non-chunked compute_logprobs produce identical results.""" - model_chunked, tokenizer = make_model( - model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size - ) - model_nonchunked, _ = make_model( - model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0 - ) - + tokenizer = AutoTokenizer.from_pretrained(model_name) inputs = ["The capital of France is", "Hello world"] batch = tokenizer(inputs, return_tensors="pt", padding=True) input_ids = jnp.array(batch.input_ids.numpy()) attention_mask = jnp.array(batch.attention_mask.numpy()) target_ids = jnp.roll(input_ids, -1, axis=1) - # Get hidden states - outputs_chunked = model_chunked(input_ids, attention_mask=attention_mask) - outputs_nonchunked = model_nonchunked(input_ids, attention_mask=attention_mask) + with tempfile.TemporaryDirectory() as tmp: + # Save HF weights once + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + hf_model.save_pretrained(tmp, safe_serialization=True) + del hf_model + + # Load non-chunked model, compute logprobs, then delete + model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0) + outputs = model(input_ids, attention_mask=attention_mask) + logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids)) + del model, outputs - # Compute logprobs with both methods - logprobs_chunked = model_chunked.compute_logprobs(outputs_chunked.last_hidden_state, target_ids) - logprobs_nonchunked = model_nonchunked.compute_logprobs(outputs_nonchunked.last_hidden_state, target_ids) + # Load chunked model, compute logprobs + model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size) + outputs = model(input_ids, attention_mask=attention_mask) + logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids)) np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), + logprobs_chunked, + logprobs_nonchunked, rtol=1e-4, atol=1e-4, err_msg=f"Chunked vs non-chunked logprobs mismatch for chunk_size={chunk_size}", From 9d9079540fbf603e4eeef6cc170a4b4d257a93e4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 08:19:56 -0800 Subject: [PATCH 057/133] test: load one backend at a time in test_mixed_train_unembed_adapters Restructure test to avoid OOM by creating and deleting backends sequentially instead of having two in memory simultaneously. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 57 ++++++++++++----------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 74787df9f..c5242737b 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -560,30 +560,29 @@ def test_adapter_reuse_initializes_lora_adapter(): def test_mixed_train_unembed_adapters(): """Test that chunked and non-chunked paths produce same results with train_unembed adapters.""" - config_chunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=1024) - config_nonchunked = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=0) - backend_chunked = JaxBackend(BASE_MODEL, config_chunked) - backend_nonchunked = JaxBackend(BASE_MODEL, config_nonchunked) - # Create same models on both backends - for backend in [backend_chunked, backend_nonchunked]: + def create_backend_and_models(loss_chunk_size): + config = JaxBackendConfig(max_lora_adapters=3, max_lora_rank=32, loss_chunk_size=loss_chunk_size) + backend = JaxBackend(BASE_MODEL, config) backend.create_model("model_normal", LoraConfig(rank=8, alpha=16, seed=0, train_unembed=False)) backend.create_model("model_unembed", LoraConfig(rank=8, alpha=16, seed=1, train_unembed=True)) + return backend - normal_idx = backend_chunked.models["model_normal"].adapter_index - unembed_idx = backend_chunked.models["model_unembed"].adapter_index + def run_forward(backend): + normal_idx = backend.models["model_normal"].adapter_index + unembed_idx = backend.models["model_unembed"].adapter_index - batch_size, seq_len = 2, 16 - vocab = backend_chunked.model.config.vocab_size - input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - target_ids = (input_ids + 1) % vocab - loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) - loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) - sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) - advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + batch_size, seq_len = 2, 16 + vocab = backend.model.config.vocab_size + input_ids = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape(batch_size, seq_len) % vocab + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + target_ids = (input_ids + 1) % vocab + loss_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32) + loss_fn_types = jnp.zeros((batch_size,), dtype=jnp.int32) + sampling_logprobs = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + advantages = jnp.zeros((batch_size, seq_len), dtype=jnp.float32) + adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) - def run_forward(backend, adapter_indices): _, losses, logprobs = backend._forward( backend.accumulated_grads, backend.lora_params, @@ -597,23 +596,27 @@ def run_forward(backend, adapter_indices): sampling_logprobs, advantages, ) - return losses, logprobs + return np.asarray(losses), np.asarray(logprobs) + + # Run non-chunked backend first, then delete + backend = create_backend_and_models(loss_chunk_size=0) + losses_nonchunked, logprobs_nonchunked = run_forward(backend) + del backend - # Test with mixed adapters: one normal, one unembed - adapter_indices = jnp.array([normal_idx, unembed_idx], dtype=jnp.int32) - losses_chunked, logprobs_chunked = run_forward(backend_chunked, adapter_indices) - losses_nonchunked, logprobs_nonchunked = run_forward(backend_nonchunked, adapter_indices) + # Run chunked backend + backend = create_backend_and_models(loss_chunk_size=1024) + losses_chunked, logprobs_chunked = run_forward(backend) np.testing.assert_allclose( - np.asarray(logprobs_chunked), - np.asarray(logprobs_nonchunked), + logprobs_chunked, + logprobs_nonchunked, rtol=1e-4, atol=1e-4, err_msg="Chunked vs non-chunked logprobs mismatch with mixed train_unembed adapters", ) np.testing.assert_allclose( - np.asarray(losses_chunked), - np.asarray(losses_nonchunked), + losses_chunked, + losses_nonchunked, rtol=1e-4, atol=1e-4, err_msg="Chunked vs non-chunked losses mismatch with mixed train_unembed adapters", From d5a213340698858243ef1b85fc6b47ca1aab29c1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 08:23:06 -0800 Subject: [PATCH 058/133] inherit --- skyrl-tx/tx/utils/logits_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 71e2409f1..fb2ea95c8 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -5,13 +5,14 @@ import jax import jax.numpy as jnp +from tx.models.types import ModelForCausalLM # lm_head: (hidden_states, adapter_indices) -> logits LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] -class LogitsProcessorMixin: +class LogitsProcessorMixin(ModelForCausalLM): """Mixin providing logits/logprobs computation for causal language models.""" @abstractmethod From 4e39b49365a0169a4cfdf8447b2a68503e25432c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 08:39:52 -0800 Subject: [PATCH 059/133] test: add unit tests for chunked logprobs edge cases Test coverage for: - Chunk boundary cases (padding, exact division, larger than total) - Adapter indices handling (None, per-batch, same for all) - Gradient checkpointing flag Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/utils/test_logits_processor.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 skyrl-tx/tests/utils/test_logits_processor.py diff --git a/skyrl-tx/tests/utils/test_logits_processor.py b/skyrl-tx/tests/utils/test_logits_processor.py new file mode 100644 index 000000000..a2f6c253b --- /dev/null +++ b/skyrl-tx/tests/utils/test_logits_processor.py @@ -0,0 +1,118 @@ +"""Unit tests for LogitsProcessorMixin chunked logprobs computation.""" + +from unittest.mock import MagicMock + +from flax import nnx +import jax.numpy as jnp +import numpy as np +import pytest + +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead + + +class DummyLogitsModel(LogitsProcessorMixin, nnx.Module): + """Minimal model for testing logits processor. + + Uses identity lm_head: logits = hidden_states (requires H == V). + When adapter_indices is provided, scales by (1 + adapter_index). + """ + + def __init__(self, vocab_size: int = 16, loss_chunk_size: int = 0): + self.config = MagicMock(loss_chunk_size=loss_chunk_size, gradient_checkpointing=False) + self.vocab_size = vocab_size + + def get_lm_head(self) -> LMHead: + def lm_head(hidden_states, adapter_indices=None): + if adapter_indices is not None: + scale = (1 + adapter_indices[:, None, None]).astype(jnp.float32) + return hidden_states * scale + return hidden_states + + return lm_head + + +def assert_chunked_matches_nonchunked( + hidden_states: jnp.ndarray, + target_ids: jnp.ndarray, + chunk_size: int, + adapter_indices: jnp.ndarray | None = None, + vocab_size: int = 16, +): + """Assert chunked and non-chunked paths produce identical results.""" + model_chunked = DummyLogitsModel(vocab_size=vocab_size, loss_chunk_size=chunk_size) + model_nonchunked = DummyLogitsModel(vocab_size=vocab_size, loss_chunk_size=0) + + logprobs_chunked = model_chunked.compute_logprobs(hidden_states, target_ids, adapter_indices) + logprobs_nonchunked = model_nonchunked.compute_logprobs(hidden_states, target_ids, adapter_indices) + + B, T = target_ids.shape + assert logprobs_chunked.shape == (B, T) + assert logprobs_nonchunked.shape == (B, T) + + np.testing.assert_allclose( + np.asarray(logprobs_chunked), + np.asarray(logprobs_nonchunked), + rtol=1e-5, + atol=1e-5, + ) + + +class TestChunkedLogprobs: + """Tests for chunked vs non-chunked logprobs computation.""" + + @pytest.mark.parametrize("B,T,chunk_size", [ + (2, 4, 3), # chunk doesn't divide evenly, needs padding + (2, 4, 8), # chunk equals B*T exactly + (2, 4, 16), # chunk larger than B*T + (1, 8, 3), # single batch element + (4, 1, 2), # single token per sequence + (1, 1, 1), # minimal case + ]) + def test_chunk_boundary_cases(self, B, T, chunk_size): + """Test various chunk size vs total token relationships.""" + V = 16 # vocab_size = hidden_size for identity lm_head + hidden_states = jnp.arange(B * T * V, dtype=jnp.float32).reshape(B, T, V) / (B * T * V) + target_ids = jnp.arange(B * T, dtype=jnp.int32).reshape(B, T) % V + + assert_chunked_matches_nonchunked(hidden_states, target_ids, chunk_size, vocab_size=V) + + @pytest.mark.parametrize("B,T,chunk_size,adapter_indices", [ + (2, 4, 3, None), # no adapters + (2, 4, 3, "arange"), # different adapter per batch, chunk spans boundary + (3, 4, 5, "arange"), # chunk spans multiple batches + (4, 2, 3, "zeros"), # all same adapter + ]) + def test_adapter_indices_handling(self, B, T, chunk_size, adapter_indices): + """Test adapter indices are correctly mapped across chunk boundaries.""" + V = 16 + hidden_states = jnp.arange(B * T * V, dtype=jnp.float32).reshape(B, T, V) / (B * T * V) + target_ids = jnp.arange(B * T, dtype=jnp.int32).reshape(B, T) % V + + if adapter_indices == "arange": + adapter_indices = jnp.arange(B, dtype=jnp.int32) + elif adapter_indices == "zeros": + adapter_indices = jnp.zeros(B, dtype=jnp.int32) + + assert_chunked_matches_nonchunked(hidden_states, target_ids, chunk_size, adapter_indices, vocab_size=V) + + def test_gradient_checkpointing_flag(self): + """Gradient checkpointing should not affect forward pass results.""" + B, T, V, chunk_size = 2, 4, 16, 3 + hidden_states = jnp.arange(B * T * V, dtype=jnp.float32).reshape(B, T, V) / (B * T * V) + target_ids = jnp.arange(B * T, dtype=jnp.int32).reshape(B, T) % V + + model_no_ckpt = DummyLogitsModel(vocab_size=V, loss_chunk_size=chunk_size) + model_no_ckpt.config.gradient_checkpointing = False + + model_ckpt = DummyLogitsModel(vocab_size=V, loss_chunk_size=chunk_size) + model_ckpt.config.gradient_checkpointing = True + + logprobs_no_ckpt = model_no_ckpt.compute_logprobs(hidden_states, target_ids) + logprobs_ckpt = model_ckpt.compute_logprobs(hidden_states, target_ids) + + np.testing.assert_allclose( + np.asarray(logprobs_no_ckpt), + np.asarray(logprobs_ckpt), + rtol=1e-5, + atol=1e-5, + ) From 0925010ed6ed0d503efcdd6929932f8123f15644 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 08:51:41 -0800 Subject: [PATCH 060/133] lint --- skyrl-tx/tests/utils/test_logits_processor.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/skyrl-tx/tests/utils/test_logits_processor.py b/skyrl-tx/tests/utils/test_logits_processor.py index a2f6c253b..404206be9 100644 --- a/skyrl-tx/tests/utils/test_logits_processor.py +++ b/skyrl-tx/tests/utils/test_logits_processor.py @@ -60,14 +60,17 @@ def assert_chunked_matches_nonchunked( class TestChunkedLogprobs: """Tests for chunked vs non-chunked logprobs computation.""" - @pytest.mark.parametrize("B,T,chunk_size", [ - (2, 4, 3), # chunk doesn't divide evenly, needs padding - (2, 4, 8), # chunk equals B*T exactly - (2, 4, 16), # chunk larger than B*T - (1, 8, 3), # single batch element - (4, 1, 2), # single token per sequence - (1, 1, 1), # minimal case - ]) + @pytest.mark.parametrize( + "B,T,chunk_size", + [ + (2, 4, 3), # chunk doesn't divide evenly, needs padding + (2, 4, 8), # chunk equals B*T exactly + (2, 4, 16), # chunk larger than B*T + (1, 8, 3), # single batch element + (4, 1, 2), # single token per sequence + (1, 1, 1), # minimal case + ], + ) def test_chunk_boundary_cases(self, B, T, chunk_size): """Test various chunk size vs total token relationships.""" V = 16 # vocab_size = hidden_size for identity lm_head @@ -76,12 +79,15 @@ def test_chunk_boundary_cases(self, B, T, chunk_size): assert_chunked_matches_nonchunked(hidden_states, target_ids, chunk_size, vocab_size=V) - @pytest.mark.parametrize("B,T,chunk_size,adapter_indices", [ - (2, 4, 3, None), # no adapters - (2, 4, 3, "arange"), # different adapter per batch, chunk spans boundary - (3, 4, 5, "arange"), # chunk spans multiple batches - (4, 2, 3, "zeros"), # all same adapter - ]) + @pytest.mark.parametrize( + "B,T,chunk_size,adapter_indices", + [ + (2, 4, 3, None), # no adapters + (2, 4, 3, "arange"), # different adapter per batch, chunk spans boundary + (3, 4, 5, "arange"), # chunk spans multiple batches + (4, 2, 3, "zeros"), # all same adapter + ], + ) def test_adapter_indices_handling(self, B, T, chunk_size, adapter_indices): """Test adapter indices are correctly mapped across chunk boundaries.""" V = 16 From fa93a014be08faeaeaa00720cf412f0bd09fe10a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 09:16:17 -0800 Subject: [PATCH 061/133] default values --- skyrl-tx/tx/models/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index f7b8cc78d..8a3ce3ae4 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -32,8 +32,8 @@ def __init__( max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, - loss_chunk_size: int, - gradient_checkpointing: bool, + loss_chunk_size: int = 0, + gradient_checkpointing: bool = False, ): # Copy all attributes from the base config super().__init__(**config.to_dict()) From 445a4c84d0043f846e1ee77cf102bc89d453335f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 11:42:07 -0800 Subject: [PATCH 062/133] empty From 1eca13760aae6a3c3241ca90e41ef32ceff3db01 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 12:33:19 -0800 Subject: [PATCH 063/133] minor cleanup --- skyrl-tx/tests/models/test_models_common.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index b28755875..c23d4d590 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -19,17 +19,12 @@ LLAMA3_MODEL = "unsloth/Llama-3.2-1B" MODEL_PARAMS = [ - ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), - ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), + (LLAMA3_MODEL, Llama3Config, Llama3ForCausalLM, ("dp", "tp")), + (QWEN3_MODEL, Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), ] MODEL_IDS = ["llama3", "qwen3"] -# ============================================================================= -# Gradient Checkpointing Tests -# ============================================================================= - - def create_qwen3_model(): """Create Qwen3 model for testing.""" base_config = PretrainedConfig.from_pretrained(QWEN3_MODEL) @@ -108,11 +103,6 @@ def test_is_training_false_uses_standard_path(self, create_model): assert len(out.kv_cache.keys) == config.num_hidden_layers -# ============================================================================= -# Chunked Logprobs Tests -# ============================================================================= - - @pytest.mark.parametrize( "model_name,config_cls,model_cls,mesh_axes", [ From 0ef5ea39ce855bc23aed55c67b198aa0d053c88a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 23 Jan 2026 13:02:25 -0800 Subject: [PATCH 064/133] refactor: extract forward layer utilities to reduce duplication Move _forward_layers_checkpointed and _forward_layers from Llama3Model and Qwen3Model into shared utility functions in tx/models/utils.py. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 58 +++++------- skyrl-tx/tx/models/llama3.py | 92 ++----------------- skyrl-tx/tx/models/qwen3.py | 96 ++------------------ skyrl-tx/tx/models/utils.py | 98 +++++++++++++++++++++ 4 files changed, 130 insertions(+), 214 deletions(-) create mode 100644 skyrl-tx/tx/models/utils.py diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index c23d4d590..d371dc8c1 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np import pytest -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from tx.models.configs import Llama3Config, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM @@ -15,42 +15,29 @@ from tx.utils.models import load_safetensors -QWEN3_MODEL = "Qwen/Qwen3-0.6B" -LLAMA3_MODEL = "unsloth/Llama-3.2-1B" - MODEL_PARAMS = [ - (LLAMA3_MODEL, Llama3Config, Llama3ForCausalLM, ("dp", "tp")), - (QWEN3_MODEL, Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), + ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), + ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), ] MODEL_IDS = ["llama3", "qwen3"] -def create_qwen3_model(): - """Create Qwen3 model for testing.""" - base_config = PretrainedConfig.from_pretrained(QWEN3_MODEL) - config = Qwen3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) - mesh = jax.make_mesh((1, 1), ("fsdp", "tp")) - with jax.set_mesh(mesh): - model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) - return model, config - - -def create_llama3_model(): - """Create Llama3 model for testing.""" - base_config = AutoConfig.from_pretrained(LLAMA3_MODEL) - config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) - mesh = jax.make_mesh((1, 1), ("dp", "tp")) +def create_model(model_name, config_cls, model_cls, mesh_axes): + """Create model with random weights for testing.""" + base_config = AutoConfig.from_pretrained(model_name) + config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) + mesh = jax.make_mesh((1, 1), mesh_axes) with jax.set_mesh(mesh): - model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) return model, config -@pytest.mark.parametrize("create_model", [create_qwen3_model, create_llama3_model], ids=["qwen3", "llama3"]) +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) class TestGradientCheckpointing: - def test_output_matches_non_checkpointed(self, create_model): + def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls, mesh_axes): """Forward pass should produce identical outputs with/without checkpointing.""" - model, config = create_model() + model, config = create_model(model_name, config_cls, model_cls, mesh_axes) batch_size, seq_len = 2, 8 input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) @@ -59,16 +46,18 @@ def test_output_matches_non_checkpointed(self, create_model): # Run without checkpointing config.gradient_checkpointing = False out_no_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + logits_no_ckpt = model.compute_logits(out_no_ckpt.last_hidden_state) # Run with checkpointing config.gradient_checkpointing = True out_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + logits_ckpt = model.compute_logits(out_ckpt.last_hidden_state) - np.testing.assert_allclose(out_no_ckpt.logits, out_ckpt.logits, rtol=1e-4, atol=1e-6) + np.testing.assert_allclose(logits_no_ckpt, logits_ckpt, rtol=1e-4, atol=1e-6) - def test_hidden_states_length_matches(self, create_model): + def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, mesh_axes): """Both paths should return same number of hidden states.""" - model, config = create_model() + model, config = create_model(model_name, config_cls, model_cls, mesh_axes) batch_size, seq_len = 2, 8 input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) @@ -88,9 +77,9 @@ def test_hidden_states_length_matches(self, create_model): hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" ) - def test_is_training_false_uses_standard_path(self, create_model): + def test_is_training_false_uses_standard_path(self, model_name, config_cls, model_cls, mesh_axes): """is_training=False should use standard path with KV cache support.""" - model, config = create_model() + model, config = create_model(model_name, config_cls, model_cls, mesh_axes) config.gradient_checkpointing = True batch_size, seq_len = 2, 8 @@ -103,14 +92,7 @@ def test_is_training_false_uses_standard_path(self, create_model): assert len(out.kv_cache.keys) == config.num_hidden_layers -@pytest.mark.parametrize( - "model_name,config_cls,model_cls,mesh_axes", - [ - ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")), - ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")), - ], - ids=["llama3", "qwen3"], -) +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 41abce6fa..269991529 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -7,6 +7,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm +from tx.models.utils import forward_layers, forward_layers_checkpointed from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache, compute_positions @@ -232,14 +233,14 @@ def __call__( ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] # Checkpointing: use scan so XLA compiles ONE loop body and reuses # buffers during recomputation. Without checkpointing, activations are # stored anyway, so scan's buffer reuse doesn't help and its weight # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: - hidden_states, all_hidden_states = self._forward_layers_checkpointed( + hidden_states, all_hidden_states = forward_layers_checkpointed( + self.layers, hidden_states, attention_mask=attention_mask, positions=positions, @@ -249,14 +250,14 @@ def __call__( updated_keys, updated_values = [], [] new_cache_position = input_ids.shape[1] else: - hidden_states, updated_keys, updated_values = self._forward_layers( + hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + self.layers, hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - all_hidden_states=all_hidden_states, ) new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] @@ -270,89 +271,6 @@ def __call__( hidden_states=all_hidden_states if output_hidden_states else None, ) - def _forward_layers_checkpointed( - self, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - output_hidden_states: bool, - ) -> tuple[jax.Array, list[jax.Array]]: - """Forward pass with gradient checkpointing using scan. - - Uses scan so XLA compiles ONE loop body and reuses buffers during - backward recomputation. With a Python loop, XLA unrolls N separate - checkpoint regions and can't optimize buffer reuse across them. - - Tradeoff: requires stacking all layer weights once per forward pass. - This is acceptable because checkpointing already trades compute for memory. - - TODO(haochen): Load weights directly into stacked format to avoid 2x memory. - Currently we have both self.layers (original) and stacked copy during forward. - """ - num_layers = len(self.layers) - if num_layers == 0: - return hidden_states, [] - - # Stack layer weights for dynamic indexing in scan - layer_graphdef, _ = nnx.split(self.layers[0]) - stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) - - def body_fn(hs, i): - layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) - layer = nnx.merge(layer_graphdef, layer_weights) - hs, _ = layer( - hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None - ) - return hs, hs if output_hidden_states else None - - body_fn = jax.checkpoint(body_fn) - final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) - - if output_hidden_states: - # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since - # it gets normed and appended in __call__ (matching non-checkpointed path). - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] - else: - all_hidden_states = [] - - return final_hs, all_hidden_states - - def _forward_layers( - self, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - kv_cache: KVCache | None, - output_hidden_states: bool, - all_hidden_states: list[jax.Array], - ) -> tuple[jax.Array, list[jax.Array], list[jax.Array]]: - """Standard forward pass through decoder layers. - - Used for inference (with KV cache) and training without checkpointing. - """ - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=layer_kv, - ) - updated_keys.append(k) - updated_values.append(v) - - return hidden_states, updated_keys, updated_values - class Llama3ForCausalLM(nnx.Module, GeneratorMixin, LogitsProcessorMixin): diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 0a32228e8..562eadcb7 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -6,11 +6,12 @@ from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope -from tx.utils.logits_processor import LogitsProcessorMixin, LMHead -from tx.models.configs import Qwen3Config from tx.layers.layernorm import RMSNorm +from tx.models.configs import Qwen3Config from tx.models.types import CausalLMOutput, ModelOutput +from tx.models.utils import forward_layers, forward_layers_checkpointed from tx.utils.generator import GeneratorMixin, KVCache, compute_positions +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead class Qwen3Attention(nnx.Module): @@ -347,14 +348,14 @@ def __call__( ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] # Checkpointing: use scan so XLA compiles ONE loop body and reuses # buffers during recomputation. Without checkpointing, activations are # stored anyway, so scan's buffer reuse doesn't help and its weight # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: - hidden_states, all_hidden_states = self._forward_layers_checkpointed( + hidden_states, all_hidden_states = forward_layers_checkpointed( + self.layers, hidden_states, attention_mask=attention_mask, positions=positions, @@ -364,14 +365,14 @@ def __call__( updated_keys, updated_values = [], [] new_cache_position = input_ids.shape[1] else: - hidden_states, updated_keys, updated_values = self._forward_layers( + hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + self.layers, hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - all_hidden_states=all_hidden_states, ) new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] @@ -385,89 +386,6 @@ def __call__( hidden_states=all_hidden_states if output_hidden_states else None, ) - def _forward_layers_checkpointed( - self, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - output_hidden_states: bool, - ) -> tuple[jax.Array, list[jax.Array]]: - """Forward pass with gradient checkpointing using scan. - - Uses scan so XLA compiles ONE loop body and reuses buffers during - backward recomputation. With a Python loop, XLA unrolls N separate - checkpoint regions and can't optimize buffer reuse across them. - - Tradeoff: requires stacking all layer weights once per forward pass. - This is acceptable because checkpointing already trades compute for memory. - - TODO(haochen): Load weights directly into stacked format to avoid 2x memory. - Currently we have both self.layers (original) and stacked copy during forward. - """ - num_layers = len(self.layers) - if num_layers == 0: - return hidden_states, [] - - # Stack layer weights for dynamic indexing in scan - layer_graphdef, _ = nnx.split(self.layers[0]) - stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in self.layers]) - - def body_fn(hs, i): - layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) - layer = nnx.merge(layer_graphdef, layer_weights) - hs, _ = layer( - hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None - ) - return hs, hs if output_hidden_states else None - - body_fn = jax.checkpoint(body_fn) - final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) - - if output_hidden_states: - # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since - # it gets normed and appended in __call__ (matching non-checkpointed path). - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] - else: - all_hidden_states = [] - - return final_hs, all_hidden_states - - def _forward_layers( - self, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - kv_cache: KVCache | None, - output_hidden_states: bool, - all_hidden_states: list[jax.Array], - ) -> tuple[jax.Array, list[jax.Array], list[jax.Array]]: - """Standard forward pass through decoder layers. - - Used for inference (with KV cache) and training without checkpointing. - """ - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=layer_kv, - ) - updated_keys.append(k) - updated_values.append(v) - - return hidden_states, updated_keys, updated_values - class Qwen3ForCausalLM(nnx.Module, GeneratorMixin, LogitsProcessorMixin): diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py new file mode 100644 index 000000000..e1b8df15c --- /dev/null +++ b/skyrl-tx/tx/models/utils.py @@ -0,0 +1,98 @@ +"""Utility functions for model forward passes.""" + +from flax import nnx +import jax +from jax import numpy as jnp + +from tx.utils.generator import KVCache + + +def forward_layers_checkpointed( + layers: nnx.List, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + output_hidden_states: bool, +) -> tuple[jax.Array, list[jax.Array]]: + """Forward pass with gradient checkpointing using scan. + + Uses scan so XLA compiles ONE loop body and reuses buffers during + backward recomputation. With a Python loop, XLA unrolls N separate + checkpoint regions and can't optimize buffer reuse across them. + + Tradeoff: requires stacking all layer weights once per forward pass. + This is acceptable because checkpointing already trades compute for memory. + + TODO(haochen): Load weights directly into stacked format to avoid 2x memory. + Currently we have both self.layers (original) and stacked copy during forward. + """ + num_layers = len(layers) + if num_layers == 0: + return hidden_states, [] + + # Stack layer weights for dynamic indexing in scan + layer_graphdef, _ = nnx.split(layers[0]) + stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in layers]) + + def body_fn(hs, i): + layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) + layer = nnx.merge(layer_graphdef, layer_weights) + hs, _ = layer( + hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None + ) + return hs, hs if output_hidden_states else None + + body_fn = jax.checkpoint(body_fn) + final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) + + if output_hidden_states: + # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since + # it gets normed and appended in __call__ (matching non-checkpointed path). + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] + else: + all_hidden_states = [] + + return final_hs, all_hidden_states + + +def forward_layers( + layers: nnx.List, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, +) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: + """Standard forward pass through decoder layers. + + Used for inference (with KV cache) and training without checkpointing. + + Returns: + hidden_states: Final hidden states after all layers + all_hidden_states: List of hidden states from each layer (if output_hidden_states) + updated_keys: List of updated key caches + updated_values: List of updated value caches + """ + all_hidden_states: list[jax.Array] = [] + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position) + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + updated_keys.append(k) + updated_values.append(v) + + return hidden_states, all_hidden_states, updated_keys, updated_values From 572a6974f2a48ddf70f0059df1a9a8584c9804f3 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 10:31:50 -0800 Subject: [PATCH 065/133] fix: remove unused new_cache_position variable KVCache.update() handles cache position internally, so this variable is no longer needed after the KVCache API refactor. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/llama3.py | 2 -- skyrl-tx/tx/models/qwen3.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 6244cc504..93ed35021 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -240,7 +240,6 @@ def __call__( output_hidden_states=output_hidden_states, ) updated_keys, updated_values = [], [] - new_cache_position = input_ids.shape[1] else: hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( self.layers, @@ -251,7 +250,6 @@ def __call__( kv_cache=kv_cache, output_hidden_states=output_hidden_states, ) - new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] hidden_states = self.norm(hidden_states) if output_hidden_states: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 1edf2d111..9a0f66d6d 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -355,7 +355,6 @@ def __call__( output_hidden_states=output_hidden_states, ) updated_keys, updated_values = [], [] - new_cache_position = input_ids.shape[1] else: hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( self.layers, @@ -366,7 +365,6 @@ def __call__( kv_cache=kv_cache, output_hidden_states=output_hidden_states, ) - new_cache_position = kv_cache.cache_position + 1 if kv_cache else input_ids.shape[1] hidden_states = self.norm(hidden_states) if output_hidden_states: From 2c5b3a7fd087ad5428684cae7401a5b42af06024 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 10:34:40 -0800 Subject: [PATCH 066/133] remove comments --- skyrl-tx/tx/models/llama3.py | 4 ---- skyrl-tx/tx/models/qwen3.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 93ed35021..62bd4814a 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -226,10 +226,6 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - # Checkpointing: use scan so XLA compiles ONE loop body and reuses - # buffers during recomputation. Without checkpointing, activations are - # stored anyway, so scan's buffer reuse doesn't help and its weight - # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: hidden_states, all_hidden_states = forward_layers_checkpointed( self.layers, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 9a0f66d6d..9a4c505a7 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -341,10 +341,6 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - # Checkpointing: use scan so XLA compiles ONE loop body and reuses - # buffers during recomputation. Without checkpointing, activations are - # stored anyway, so scan's buffer reuse doesn't help and its weight - # stacking overhead makes it worse. if is_training and self.config.gradient_checkpointing: hidden_states, all_hidden_states = forward_layers_checkpointed( self.layers, From 246c2af2d6d1755e5e7c5362e0fce22ce746b41a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 10:39:30 -0800 Subject: [PATCH 067/133] fix --- skyrl-tx/tx/tinker/backends/jax.py | 4 ---- skyrl-tx/tx/utils/logits_processor.py | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 7a8d05803..e067028af 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -87,10 +87,6 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=1024, description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.", ) - loss_chunk_size: int = Field( - default=1024, - description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.", - ) # Multi-node configuration coordinator_address: str | None = Field( default=None, diff --git a/skyrl-tx/tx/utils/logits_processor.py b/skyrl-tx/tx/utils/logits_processor.py index 620c30f08..4cc9e1613 100644 --- a/skyrl-tx/tx/utils/logits_processor.py +++ b/skyrl-tx/tx/utils/logits_processor.py @@ -5,14 +5,13 @@ import jax import jax.numpy as jnp -from tx.models.types import ModelForCausalLM # lm_head: (hidden_states, adapter_indices) -> logits LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] -class LogitsProcessorMixin(ModelForCausalLM): +class LogitsProcessorMixin: """Mixin providing logits/logprobs computation for causal language models.""" @abstractmethod From 159dc82ff4e08f83fa564860ca8b6eeb936847f2 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 11:00:04 -0800 Subject: [PATCH 068/133] remove comment --- skyrl-tx/tests/models/test_models_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 6e9233a53..e71b13179 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -1,5 +1,3 @@ -"""Common tests for models.""" - import tempfile from flax import nnx From 58527c72d8f83a936f212ed7b6a3d651164edb34 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 12:07:34 -0800 Subject: [PATCH 069/133] unify forward_layers --- skyrl-tx/tx/models/llama3.py | 33 ++++++++------------ skyrl-tx/tx/models/qwen3.py | 33 ++++++++------------ skyrl-tx/tx/models/utils.py | 59 ++++++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 54 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 62bd4814a..fb28c5c21 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -8,7 +8,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm -from tx.models.utils import forward_layers, forward_layers_checkpointed +from tx.models.utils import forward_layers from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -226,26 +226,17 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - if is_training and self.config.gradient_checkpointing: - hidden_states, all_hidden_states = forward_layers_checkpointed( - self.layers, - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - output_hidden_states=output_hidden_states, - ) - updated_keys, updated_values = [], [] - else: - hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( - self.layers, - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache, - output_hidden_states=output_hidden_states, - ) + hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + self.layers, + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + is_training=is_training, + gradient_checkpointing=self.config.gradient_checkpointing, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 9a4c505a7..4e9e4c2f6 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -10,7 +10,7 @@ from tx.layers.layernorm import RMSNorm from tx.models.configs import Qwen3Config from tx.models.types import CausalLMOutput, ModelOutput -from tx.models.utils import forward_layers, forward_layers_checkpointed +from tx.models.utils import forward_layers from tx.utils.generator import GeneratorMixin, KVCache from tx.utils.logits_processor import LogitsProcessorMixin, LMHead @@ -341,26 +341,17 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - if is_training and self.config.gradient_checkpointing: - hidden_states, all_hidden_states = forward_layers_checkpointed( - self.layers, - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - output_hidden_states=output_hidden_states, - ) - updated_keys, updated_values = [], [] - else: - hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( - self.layers, - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache, - output_hidden_states=output_hidden_states, - ) + hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + self.layers, + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + is_training=is_training, + gradient_checkpointing=self.config.gradient_checkpointing, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 8a40670b9..0bdfb5e38 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -7,7 +7,7 @@ from tx.utils.generator import KVCache -def forward_layers_checkpointed( +def _forward_layers_checkpointed( layers: nnx.List, hidden_states: jax.Array, *, @@ -57,7 +57,7 @@ def body_fn(hs, i): return final_hs, all_hidden_states -def forward_layers( +def _forward_layers_standard( layers: nnx.List, hidden_states: jax.Array, *, @@ -67,16 +67,7 @@ def forward_layers( kv_cache: KVCache | None, output_hidden_states: bool, ) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: - """Standard forward pass through decoder layers. - - Used for inference (with KV cache) and training without checkpointing. - - Returns: - hidden_states: Final hidden states after all layers - all_hidden_states: List of hidden states from each layer (if output_hidden_states) - updated_keys: List of updated key caches - updated_values: List of updated value caches - """ + """Standard forward pass through decoder layers.""" all_hidden_states: list[jax.Array] = [] updated_keys, updated_values = [], [] @@ -96,3 +87,47 @@ def forward_layers( updated_values.append(v) return hidden_states, all_hidden_states, updated_keys, updated_values + + +def forward_layers( + layers: nnx.List, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + is_training: bool, + gradient_checkpointing: bool, +) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: + """Forward pass through decoder layers with optional gradient checkpointing. + + Chooses between checkpointed (scan-based) and standard (loop-based) paths. + + Returns: + hidden_states: Final hidden states after all layers + all_hidden_states: List of hidden states from each layer (if output_hidden_states) + updated_keys: List of updated key caches (empty if checkpointing) + updated_values: List of updated value caches (empty if checkpointing) + """ + if is_training and gradient_checkpointing: + hidden_states, all_hidden_states = _forward_layers_checkpointed( + layers, + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + output_hidden_states=output_hidden_states, + ) + return hidden_states, all_hidden_states, [], [] + else: + return _forward_layers_standard( + layers, + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + ) From 53316f72d5d8a53eeb2490903af60be0c086ba56 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 12:25:43 -0800 Subject: [PATCH 070/133] model.train() --- skyrl-tx/tests/models/test_models_common.py | 17 ++++++++++------- skyrl-tx/tx/models/llama3.py | 12 ++++++++---- skyrl-tx/tx/models/qwen3.py | 12 ++++++++---- skyrl-tx/tx/models/utils.py | 4 ++-- skyrl-tx/tx/tinker/backends/jax.py | 3 +-- 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index e71b13179..c6c8220c0 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -60,12 +60,13 @@ def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls # Run without checkpointing config.gradient_checkpointing = False - out_no_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + model.train() + out_no_ckpt = model(input_ids, attention_mask=attention_mask) logits_no_ckpt = model.compute_logits(out_no_ckpt.last_hidden_state) # Run with checkpointing config.gradient_checkpointing = True - out_ckpt = model(input_ids, attention_mask=attention_mask, is_training=True) + out_ckpt = model(input_ids, attention_mask=attention_mask) logits_ckpt = model.compute_logits(out_ckpt.last_hidden_state) np.testing.assert_allclose(logits_no_ckpt, logits_ckpt, rtol=1e-4, atol=1e-6) @@ -79,10 +80,11 @@ def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, m attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) config.gradient_checkpointing = False - out_no_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True) + model.train() + out_no_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) config.gradient_checkpointing = True - out_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True, is_training=True) + out_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) assert len(out_ckpt.hidden_states) == config.num_hidden_layers + 1 @@ -92,8 +94,8 @@ def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, m hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" ) - def test_is_training_false_uses_standard_path(self, model_name, config_cls, model_cls, mesh_axes): - """is_training=False should use standard path with KV cache support.""" + def test_eval_mode_uses_standard_path(self, model_name, config_cls, model_cls, mesh_axes): + """eval() mode should use standard path with KV cache support.""" model, config = create_model(model_name, config_cls, model_cls, mesh_axes) config.gradient_checkpointing = True @@ -101,7 +103,8 @@ def test_is_training_false_uses_standard_path(self, model_name, config_cls, mode input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - out = model(input_ids, attention_mask=attention_mask, is_training=False) + model.eval() + out = model(input_ids, attention_mask=attention_mask) # KV cache should be populated (checkpointed path returns empty) assert len(out.kv_cache.keys) == config.num_hidden_layers diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index fb28c5c21..654f31bf5 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -190,6 +190,7 @@ def __call__( class Llama3Model(nnx.Module): + training: bool = False def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config @@ -218,7 +219,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -234,7 +234,7 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - is_training=is_training, + training=self.training, gradient_checkpointing=self.config.gradient_checkpointing, ) @@ -274,6 +274,12 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head + def train(self, **attributes): + return super().train(training=True, **attributes) + + def eval(self, **attributes): + return super().eval(training=False, **attributes) + @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" @@ -288,7 +294,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -300,7 +305,6 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, - is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 4e9e4c2f6..217fb13b6 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -305,6 +305,7 @@ def __call__( class Qwen3Model(nnx.Module): + training: bool = False def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config @@ -333,7 +334,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -349,7 +349,7 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - is_training=is_training, + training=self.training, gradient_checkpointing=self.config.gradient_checkpointing, ) @@ -389,6 +389,12 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head + def train(self, **attributes): + return super().train(training=True, **attributes) + + def eval(self, **attributes): + return super().eval(training=False, **attributes) + @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" @@ -403,7 +409,6 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, - is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -415,7 +420,6 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, - is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 0bdfb5e38..8bc522025 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -98,7 +98,7 @@ def forward_layers( adapter_indices: jax.Array | None, kv_cache: KVCache | None, output_hidden_states: bool, - is_training: bool, + training: bool, gradient_checkpointing: bool, ) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: """Forward pass through decoder layers with optional gradient checkpointing. @@ -111,7 +111,7 @@ def forward_layers( updated_keys: List of updated key caches (empty if checkpointing) updated_values: List of updated value caches (empty if checkpointing) """ - if is_training and gradient_checkpointing: + if training and gradient_checkpointing: hidden_states, all_hidden_states = _forward_layers_checkpointed( layers, hidden_states, diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index e067028af..85a4428a4 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -245,12 +245,11 @@ def _model_forward( target_ids: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" - model = nnx.merge(graphdef, lora_params, non_lora_params) + model = nnx.merge(graphdef, lora_params, non_lora_params).train() output = model( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, - is_training=True, ) return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) From 113bd92a0935b96ca7778ef686531a6c0c798775 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 10:49:01 -0800 Subject: [PATCH 071/133] stack weights --- skyrl-tx/tx/models/llama3.py | 17 ++- skyrl-tx/tx/models/qwen3.py | 17 ++- skyrl-tx/tx/models/utils.py | 223 +++++++++++++++++-------------- skyrl-tx/tx/utils/generator.py | 110 ++++++++++------ skyrl-tx/tx/utils/models.py | 233 +++++++++++++++++++++++++++------ 5 files changed, 412 insertions(+), 188 deletions(-) diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 613010275..01ed8ee69 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -8,7 +8,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm -from tx.models.utils import forward_layers +from tx.models.utils import create_stacked_layers, forward_layers from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -194,6 +194,7 @@ class Llama3Model(nnx.Module): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -205,9 +206,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)), rngs=rngs, ) - self.layers = nnx.List( - [Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)] - ) + + def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer: + return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) + + self.layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -226,15 +229,15 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + hidden_states, all_hidden_states, new_kv_cache = forward_layers( self.layers, hidden_states, + self.num_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - training=self.training, gradient_checkpointing=self.config.gradient_checkpointing, ) @@ -244,7 +247,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 642a2a566..03914e668 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -10,7 +10,7 @@ from tx.layers.layernorm import RMSNorm from tx.models.configs import Qwen3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput -from tx.models.utils import forward_layers +from tx.models.utils import create_stacked_layers, forward_layers from tx.utils.generator import GeneratorMixin, KVCache from tx.utils.logits_processor import LogitsProcessorMixin, LMHead @@ -309,6 +309,7 @@ class Qwen3Model(nnx.Module): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -320,9 +321,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)), rngs=rngs, ) - self.layers = nnx.List( - [Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)] - ) + + def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer: + return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) + + self.layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -341,15 +344,15 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - hidden_states, all_hidden_states, updated_keys, updated_values = forward_layers( + hidden_states, all_hidden_states, new_kv_cache = forward_layers( self.layers, hidden_states, + self.num_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=kv_cache, output_hidden_states=output_hidden_states, - training=self.training, gradient_checkpointing=self.config.gradient_checkpointing, ) @@ -359,7 +362,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 8bc522025..2a852e756 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -1,4 +1,15 @@ -"""Utility functions for model forward passes.""" +"""Utility functions for model forward passes with stacked decoder layers. + +This module provides a unified forward_layers function that works for both training +(with gradient checkpointing) and inference. The key insight is that jax.checkpoint +is a no-op when not computing gradients, so we can use the same scan-based code path. + +Prerequisites: +- Layers must be created with nnx.vmap (stacked weights) +- KVCache must use stacked format: (num_layers, batch, seq, heads, dim) +""" + +from typing import TypeVar from flax import nnx import jax @@ -6,128 +17,146 @@ from tx.utils.generator import KVCache +T = TypeVar("T", bound=nnx.Module) -def _forward_layers_checkpointed( - layers: nnx.List, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - output_hidden_states: bool, -) -> tuple[jax.Array, list[jax.Array]]: - """Forward pass with gradient checkpointing using scan. - - Uses scan so XLA compiles ONE loop body and reuses buffers during - backward recomputation. With a Python loop, XLA unrolls N separate - checkpoint regions and can't optimize buffer reuse across them. - Tradeoff: requires stacking all layer weights once per forward pass. - This is acceptable because checkpointing already trades compute for memory. +def create_stacked_layers( + create_layer_fn: callable, + num_layers: int, + rngs: nnx.Rngs, +) -> nnx.Module: + """Create stacked decoder layers using nnx.vmap. - TODO(haochen): Load weights directly into stacked format to avoid 2x memory. - Currently we have both self.layers (original) and stacked copy during forward. - """ - num_layers = len(layers) - if num_layers == 0: - return hidden_states, [] + This creates a single module object where all parameters have shape (num_layers, ...). + This enables efficient scanning over layers without runtime stacking. - # Stack layer weights for dynamic indexing in scan - layer_graphdef, _ = nnx.split(layers[0]) - stacked_weights = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(layer) for layer in layers]) + Args: + create_layer_fn: Function that takes rngs and returns a single layer module. + num_layers: Number of layers to create. + rngs: Random number generators for initialization. - def body_fn(hs, i): - layer_weights = jax.tree.map(lambda x: x[i], stacked_weights) - layer = nnx.merge(layer_graphdef, layer_weights) - hs, _ = layer( - hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, kv_cache=None - ) - return hs, hs if output_hidden_states else None + Returns: + A single module with stacked parameters. - body_fn = jax.checkpoint(body_fn) - final_hs, all_hs = jax.lax.scan(body_fn, hidden_states, jnp.arange(num_layers)) + Example: + >>> def create_layer(rngs): + ... return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) + >>> layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) + >>> # layers.self_attn.q_proj.kernel.shape == (num_layers, hidden, head_dim*num_heads) + """ - if output_hidden_states: - # all_hs is [num_layers, batch, seq, hidden]. Exclude last layer output since - # it gets normed and appended in __call__ (matching non-checkpointed path). - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] - else: - all_hidden_states = [] + @nnx.split_rngs(splits=num_layers) + @nnx.vmap(in_axes=(0,), out_axes=0) + def vmapped_create(rngs: nnx.Rngs): + return create_layer_fn(rngs) - return final_hs, all_hidden_states + return vmapped_create(rngs) -def _forward_layers_standard( - layers: nnx.List, +def forward_layers( + layers: nnx.Module, hidden_states: jax.Array, + num_layers: int, *, attention_mask: jax.Array, positions: jax.Array, adapter_indices: jax.Array | None, kv_cache: KVCache | None, output_hidden_states: bool, -) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: - """Standard forward pass through decoder layers.""" - all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] + gradient_checkpointing: bool, +) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Unified forward pass through stacked decoder layers. + + Uses jax.lax.scan for both training and inference. When gradient_checkpointing=True, + wraps the body function with jax.checkpoint. This is a no-op during inference + (when not computing gradients), so we can use a single code path. + + Args: + layers: Stacked decoder layers (created with create_stacked_layers/nnx.vmap). + hidden_states: Input hidden states of shape (batch, seq, hidden). + num_layers: Number of decoder layers. + attention_mask: Attention mask of shape (batch, seq). + positions: Position indices of shape (batch, seq). + adapter_indices: Optional LoRA adapter indices of shape (batch,). + kv_cache: Optional KV cache with stacked keys/values. + output_hidden_states: Whether to return intermediate hidden states. + gradient_checkpointing: Whether to use gradient checkpointing. - for layer_idx, layer in enumerate(layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) + Returns: + Tuple of: + - Final hidden states of shape (batch, seq, hidden) + - List of intermediate hidden states (if output_hidden_states=True) + - Updated KV cache (if kv_cache was provided) + """ + if num_layers == 0: + return hidden_states, [], kv_cache - layer_kv = kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]) - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=layer_kv, - ) - updated_keys.append(k) - updated_values.append(v) + # Split layers into graph definition and stacked state + layer_graphdef, layer_state = nnx.split(layers) - return hidden_states, all_hidden_states, updated_keys, updated_values + # Prepare stacked KV cache + stacked_kv: tuple[jax.Array, jax.Array] | None = None + if kv_cache is not None: + stacked_kv = (kv_cache.keys, kv_cache.values) + def body_fn(carry, layer_idx): + hs, kv = carry -def forward_layers( - layers: nnx.List, - hidden_states: jax.Array, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - kv_cache: KVCache | None, - output_hidden_states: bool, - training: bool, - gradient_checkpointing: bool, -) -> tuple[jax.Array, list[jax.Array], list[jax.Array], list[jax.Array]]: - """Forward pass through decoder layers with optional gradient checkpointing. + # Extract this layer's weights by indexing into stacked state + layer_weights = jax.tree.map(lambda x: x[layer_idx], layer_state) + layer = nnx.merge(layer_graphdef, layer_weights) - Chooses between checkpointed (scan-based) and standard (loop-based) paths. + # Get this layer's KV cache slice + layer_kv = None + if kv is not None: + layer_kv = (kv[0][layer_idx], kv[1][layer_idx]) - Returns: - hidden_states: Final hidden states after all layers - all_hidden_states: List of hidden states from each layer (if output_hidden_states) - updated_keys: List of updated key caches (empty if checkpointing) - updated_values: List of updated value caches (empty if checkpointing) - """ - if training and gradient_checkpointing: - hidden_states, all_hidden_states = _forward_layers_checkpointed( - layers, - hidden_states, + # Forward through layer + new_hs, (k, v) = layer( + hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - output_hidden_states=output_hidden_states, + kv_cache=layer_kv, ) - return hidden_states, all_hidden_states, [], [] - else: - return _forward_layers_standard( - layers, - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache, - output_hidden_states=output_hidden_states, + + # Update stacked KV cache + new_kv = kv + if kv is not None: + new_kv = ( + kv[0].at[layer_idx].set(k), + kv[1].at[layer_idx].set(v), + ) + + # Return updated carry and output for this iteration + output = hs if output_hidden_states else None + return (new_hs, new_kv), output + + # Apply gradient checkpointing if requested + if gradient_checkpointing: + body_fn = jax.checkpoint(body_fn) + + # Scan over layer indices + (final_hs, final_kv), all_hs = jax.lax.scan( + body_fn, + (hidden_states, stacked_kv), + jnp.arange(num_layers), + ) + + # Collect hidden states if requested + all_hidden_states: list[jax.Array] = [] + if output_hidden_states: + # all_hs has shape (num_layers, batch, seq, hidden) + # We want [input, layer0_out, layer1_out, ...] excluding final (it gets normed) + all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] + + # Reconstruct KVCache if it was provided + new_kv_cache = None + if kv_cache is not None and final_kv is not None: + new_kv_cache = KVCache( + keys=final_kv[0], + values=final_kv[1], + cache_position=kv_cache.cache_position, ) + + return final_hs, all_hidden_states, new_kv_cache diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index f461a5613..6afd261ab 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -1,4 +1,4 @@ -"""Generator mixin for autoregressive text generation with KV caching.""" +"""Generator mixin for autoregressive text generation with stacked KV caching.""" from __future__ import annotations from dataclasses import dataclass @@ -14,49 +14,61 @@ @jax.tree_util.register_dataclass @dataclass class KVCache: - """Key-value cache for all layers, each entry in the list corresponds to one layer.""" + """Key-value cache for all layers in stacked format. - keys: list[jax.Array] - values: list[jax.Array] - cache_position: jax.Array # Per-sequence positions of shape [B] for left-aligned decoding + Attributes: + keys: Stacked key cache of shape (num_layers, batch, seq, num_kv_heads, head_dim). + values: Stacked value cache of shape (num_layers, batch, seq, num_kv_heads, head_dim). + cache_position: Per-sequence positions of shape (batch,) for left-aligned decoding. + """ + + keys: jax.Array # (num_layers, batch, seq, num_kv_heads, head_dim) + values: jax.Array # (num_layers, batch, seq, num_kv_heads, head_dim) + cache_position: jax.Array # (batch,) @staticmethod - def update( - kv_cache: KVCache | None, - keys: list[jax.Array], - values: list[jax.Array], + def from_layer_outputs( + keys: jax.Array, + values: jax.Array, positions: jax.Array, attention_mask: jax.Array, ) -> KVCache: - """Create an updated KVCache with computed cache positions for left-aligned decoding. + """Create KVCache from stacked layer outputs after prefill. Args: - kv_cache: Existing KVCache (None during prefill). - keys: List of key arrays per layer. - values: List of value arrays per layer. - positions: Position indices with shape [B, seq_len]. - attention_mask: Attention mask with shape [B, seq_len]. + keys: Stacked keys of shape (num_layers, batch, seq, num_kv_heads, head_dim). + values: Stacked values of shape (num_layers, batch, seq, num_kv_heads, head_dim). + positions: Position indices of shape (batch, seq). + attention_mask: Attention mask of shape (batch, seq). Returns: New KVCache with computed cache_position. """ - if kv_cache is not None: - # Decode: next position is current position + 1 - cache_position = positions[:, 0] + 1 - else: - # Prefill: next position is the sequence length (number of real tokens) - cache_position = attention_mask.sum(axis=1) + # Prefill: next position is the sequence length (number of real tokens) + cache_position = attention_mask.sum(axis=1).astype(jnp.int32) return KVCache(keys=keys, values=values, cache_position=cache_position) @staticmethod - def update_layer(kv_cache, k, v, positions): - """Update a single layer's KV cache at the given positions (for left-aligned decoding). + def update_layer( + kv_cache: tuple[jax.Array, jax.Array], + k: jax.Array, + v: jax.Array, + positions: jax.Array, + ) -> tuple[jax.Array, jax.Array]: + """Update a single layer's KV cache at the given positions. + + This is called from within the scan body to update a single layer's cache. + The layer index is handled by the caller (indexing into stacked cache). Args: - kv_cache: Tuple of (k_cache, v_cache) arrays for this layer. - k: New key values with shape [B, seq_len, num_heads, head_dim]. - v: New value values with shape [B, seq_len, num_heads, head_dim]. - positions: Position indices with shape [B, seq_len]. + kv_cache: Tuple of (k_cache, v_cache) for this layer. + Each has shape (batch, seq, num_kv_heads, head_dim). + k: New key values of shape (batch, seq_len, num_kv_heads, head_dim). + v: New value values of shape (batch, seq_len, num_kv_heads, head_dim). + positions: Position indices of shape (batch, seq_len). + + Returns: + Updated (k_cache, v_cache) tuple with new values at positions. """ k_cache, v_cache = kv_cache @@ -68,23 +80,42 @@ def update_at_pos(cache_slice, new_val_slice, pos): return k, v def pad_to_length(self, max_length: int) -> KVCache: - """Pad KV cache to a specified maximum length. + """Pad KV cache to a specified maximum sequence length. Args: - max_length: Target length to pad the cache to. + max_length: Target sequence length to pad to. Returns: New KVCache with padded keys and values. """ - # k and v have shape [B, T, num_heads, head_dim] - cache_pad_length = max_length - self.keys[0].shape[1] - pad_spec = ((0, 0), (0, cache_pad_length), (0, 0), (0, 0)) + current_length = self.keys.shape[2] # (num_layers, batch, seq, heads, dim) + if current_length >= max_length: + return self + + pad_length = max_length - current_length + # Pad only the sequence dimension (axis 2) + pad_spec = ((0, 0), (0, 0), (0, pad_length), (0, 0), (0, 0)) return KVCache( - keys=[jnp.pad(k, pad_spec) for k in self.keys], - values=[jnp.pad(v, pad_spec) for v in self.values], + keys=jnp.pad(self.keys, pad_spec), + values=jnp.pad(self.values, pad_spec), cache_position=self.cache_position, ) + @property + def num_layers(self) -> int: + """Number of layers in the cache.""" + return self.keys.shape[0] + + @property + def batch_size(self) -> int: + """Batch size.""" + return self.keys.shape[1] + + @property + def seq_len(self) -> int: + """Current sequence length.""" + return self.keys.shape[2] + @jax.tree_util.register_dataclass @dataclass @@ -197,11 +228,16 @@ def _prefill_and_decode( last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :] prompt_logprobs_array = None - # Pad KV cache and attention mask + # Pad KV cache to max_length kv_cache = outputs.kv_cache.pad_to_length(max_length) - # Pad KV cache and attention mask to max_length - kv_cache = kv_cache.pad_to_length(max_length) + # Update cache_position after prefill + kv_cache = KVCache( + keys=kv_cache.keys, + values=kv_cache.values, + cache_position=attention_mask.sum(axis=1).astype(jnp.int32), + ) + decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1]))) def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]: diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index faf1a9634..84ba493e5 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -1,3 +1,5 @@ +"""Weight loading and saving utilities for stacked layer models.""" + from __future__ import annotations from enum import Enum @@ -72,29 +74,68 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") -def get_param_key(path: tuple, prefix: str = "") -> str: - "Get the safetensors key for a given model path." - if path[-1] in {"embedding", "kernel"}: - path = (*path[:-1], "weight") - elif path[-1] in {"lora_A", "lora_B"}: - path = (*path, "weight") - return prefix + ".".join(map(str, path)) - - -def get_expert_key(path: tuple, expert_idx: int) -> str: - "Get the safetensors key for an expert weight model path." - path = tuple(s if s != "experts" else f"experts.{expert_idx}" for s in path) - return ".".join(map(str, path)) +def _is_layer_param(path: tuple) -> bool: + """Check if a parameter path corresponds to a stacked decoder layer weight.""" + path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] + # Layer params have 'layers' in their path but not as part of another word + return "layers" in path_strs + + +def _get_hf_key_for_layer(path: tuple, layer_idx: int) -> str: + """Convert a stacked layer param path to a per-layer HuggingFace key.""" + parts = [] + for p in path: + key = p.key if hasattr(p, "key") else str(p) + if key == "layers": + parts.append(f"layers.{layer_idx}") + elif key in ("kernel", "embedding"): + parts.append("weight") + elif key in ("lora_A", "lora_B"): + parts.append(key) + parts.append("weight") + else: + parts.append(key) + return ".".join(parts) + + +def _get_hf_key(path: tuple) -> str: + """Convert a non-layer param path to a HuggingFace key.""" + parts = [] + for p in path: + key = p.key if hasattr(p, "key") else str(p) + if key in ("kernel", "embedding"): + parts.append("weight") + elif key in ("lora_A", "lora_B"): + parts.append(key) + parts.append("weight") + else: + parts.append(key) + return ".".join(parts) def load_safetensors( checkpoint_dir: str | os.PathLike, config: PretrainedConfig, model: nnx.Module, + num_layers: int, skip_lora: bool = True, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: + """Load safetensors weights into a model with stacked layers. + + For layer parameters, loads individual layer weights and stacks them. + For non-layer parameters, loads directly. + + Args: + checkpoint_dir: Directory containing safetensors files. + config: Model configuration. + model: Model with stacked layer weights (created with create_stacked_layers). + num_layers: Number of decoder layers. + skip_lora: Whether to skip LoRA parameters. + prefix: Prefix to remove from tensor keys. + filter_fn: Optional filter for which parameters to load. + """ tensors = {} for file in Path(checkpoint_dir).glob("*.safetensors"): tensors.update(safetensors.numpy.load_file(file)) @@ -102,22 +143,78 @@ def load_safetensors( model_params = nnx.to_flat_state(nnx.state(model)) updates = [] + for path, param in model_params: if filter_fn is not None and not filter_fn(path): continue - key = get_param_key(path) + + path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] + # Skip LoRA parameters if requested - if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): + if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): continue - if "experts" in path: - tensors[key] = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.num_experts)], axis=0) + + if _is_layer_param(path): + # Stack layer weights from individual layer tensors + layer_tensors = [] + for layer_idx in range(num_layers): + key = _get_hf_key_for_layer(path, layer_idx) + + # Handle expert weights (MoE) - HF stores each expert separately + # Our model has shape (num_experts, in, out), HF has experts.{idx}.*.weight + if ".experts." in key and hasattr(config, "num_experts"): + num_experts = config.num_experts + expert_tensors = [] + for expert_idx in range(num_experts): + # Insert expert index: experts.gate_proj -> experts.0.gate_proj + expert_key = key.replace(".experts.", f".experts.{expert_idx}.") + if expert_key in tensors: + expert_tensors.append(tensors[expert_key].T) + if expert_tensors: + tensor = np.stack(expert_tensors, axis=0) + else: + raise KeyError(f"Expert weights not found for {key}") + else: + tensor = tensors[key] + # Transpose linear weights (HF uses [out, in], we use [in, out]) + if "embed_tokens" not in key: + tensor = tensor.T + + # Reshape attention projections if needed + if any(proj in key for proj in ("q_proj", "k_proj", "v_proj", "o_proj")): + # param.shape[1:] gives the target shape without the layer axis + target_shape = param.shape[1:] + tensor = tensor.reshape(target_shape) + + layer_tensors.append(tensor) + + stacked_tensor = np.stack(layer_tensors, axis=0) else: - tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T - if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: - tensors[key] = tensors[key].reshape(param.shape) - assert param.shape == tensors[key].shape, f"shape mismatch for {key}" - sharded_tensor = jax.device_put(tensors[key].astype(param.dtype), param.sharding) + # Non-layer parameter - load directly + key = _get_hf_key(path) + + if ".experts." in key and hasattr(config, "num_experts"): + num_experts = config.num_experts + expert_tensors = [] + for expert_idx in range(num_experts): + expert_key = key.replace(".experts.", f".experts.{expert_idx}.") + if expert_key in tensors: + expert_tensors.append(tensors[expert_key].T) + if expert_tensors: + stacked_tensor = np.stack(expert_tensors, axis=0) + else: + raise KeyError(f"Expert weights not found for {key}") + else: + stacked_tensor = tensors[key] + if "embed_tokens" not in key: + stacked_tensor = stacked_tensor.T + + assert param.shape == stacked_tensor.shape, ( + f"Shape mismatch for {path}: expected {param.shape}, got {stacked_tensor.shape}" + ) + sharded_tensor = jax.device_put(stacked_tensor.astype(param.dtype), param.sharding) updates.append((path, sharded_tensor)) + nnx.update(model, nnx.from_flat_state(updates)) @@ -125,31 +222,69 @@ def save_safetensors( config: PretrainedConfig, model: nnx.Module, filename: Path, + num_layers: int, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: + """Save model weights to safetensors, unstacking layer weights for HF compatibility. + + Args: + config: Model configuration. + model: Model with stacked layer weights. + filename: Output safetensors file path. + num_layers: Number of decoder layers. + prefix: Prefix to add to tensor keys. + filter_fn: Optional filter for which parameters to save. + """ model_params = nnx.to_flat_state(nnx.state(model)) tensors = {} + for path, param in model_params: - if "rngs" in path: + path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] + if "rngs" in path_keys: continue if filter_fn is not None and not filter_fn(path): continue - key = get_param_key(path, prefix=prefix) - if "experts" in path: - for i in range(config.num_experts): - tensors[get_expert_key(path, i)] = param[i, :, :].T - continue - if "q_proj" in path or "k_proj" in path or "v_proj" in path: - param = param.reshape(param.shape[0], -1) - elif "o_proj" in path: - param = param.reshape(-1, param.shape[-1]) - tensors[key] = param if "embed_tokens" in path else param.T + + if _is_layer_param(path): + # Unstack and save as individual layer weights + for layer_idx in range(num_layers): + key = prefix + _get_hf_key_for_layer(path, layer_idx) + layer_param = param[layer_idx] + + # Handle expert weights (MoE) - save each expert separately for HF compatibility + if ".experts." in key and hasattr(config, "num_experts"): + for expert_idx in range(config.num_experts): + expert_key = key.replace(".experts.", f".experts.{expert_idx}.") + tensors[expert_key] = layer_param[expert_idx].T + else: + # Reshape attention projections back to 2D + if "q_proj" in key or "k_proj" in key or "v_proj" in key: + layer_param = layer_param.reshape(layer_param.shape[0], -1) + elif "o_proj" in key: + layer_param = layer_param.reshape(-1, layer_param.shape[-1]) + + # Transpose back to HF format + if "embed_tokens" not in key: + layer_param = layer_param.T + tensors[key] = layer_param + else: + # Non-layer parameter - save directly + key = prefix + _get_hf_key(path) + + if ".experts." in key and hasattr(config, "num_experts"): + for expert_idx in range(config.num_experts): + expert_key = key.replace(".experts.", f".experts.{expert_idx}.") + tensors[expert_key] = param[expert_idx].T + else: + tensor = param + if "embed_tokens" not in key: + tensor = tensor.T + tensors[key] = tensor # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: from jax.experimental import multihost_utils - tensors = {k: multihost_utils.process_allgather(v, tiled=True) for k, v in tensors.items()} if jax.process_index() == 0: @@ -186,6 +321,7 @@ def load_lora_checkpoint( temp_dir, model.config, adapter_lora_params, + model.model.num_layers, skip_lora=False, prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), @@ -221,6 +357,7 @@ def save_lora_checkpoint( model.config, adapter_lora_params, temp_dir / "adapter_model.safetensors", + model.model.num_layers, prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), ) @@ -248,11 +385,21 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: def extract_state(path: tuple, p: jnp.ndarray): if path[-2].key not in {"lora_A", "lora_B"}: return p - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + # For stacked layers, LoRA params have shape (num_layers, num_adapters, ...) + # We extract adapter_index from the adapter dimension + assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" if path[-2].key == "lora_A": - return p[adapter_index, ..., :, :rank] + # Shape: (L, A, in, R) or (A, in, R) -> extract [..., :, :rank] + if p.ndim == 4: # Stacked: (L, A, in, R) + return p[:, adapter_index, :, :rank] + else: # Non-stacked: (A, in, R) + return p[adapter_index, :, :rank] if path[-2].key == "lora_B": - return p[adapter_index, ..., :rank, :] + # Shape: (L, A, R, out) or (A, R, out) -> extract [..., :rank, :] + if p.ndim == 4: # Stacked: (L, A, R, out) + return p[:, adapter_index, :rank, :] + else: # Non-stacked: (A, R, out) + return p[adapter_index, :rank, :] return jax.tree.map_with_path(extract_state, lora_params) @@ -267,11 +414,17 @@ def insert_adapter_state( def insert_state(path: tuple, p: jax.Array, new: jax.Array): if path[-2].key not in {"lora_A", "lora_B"}: return new - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" if path[-2].key == "lora_A": - return p.at[adapter_index, ..., :, :rank].set(new) + if p.ndim == 4: # Stacked: (L, A, in, R) + return p.at[:, adapter_index, :, :rank].set(new) + else: # Non-stacked: (A, in, R) + return p.at[adapter_index, :, :rank].set(new) elif path[-2].key == "lora_B": - return p.at[adapter_index, ..., :rank, :].set(new) + if p.ndim == 4: # Stacked: (L, A, R, out) + return p.at[:, adapter_index, :rank, :].set(new) + else: # Non-stacked: (A, R, out) + return p.at[adapter_index, :rank, :].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From 6ebf1b99550cdf9693e705c6672dac93889623b7 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 13:08:00 -0800 Subject: [PATCH 072/133] remove duplication --- skyrl-tx/tests/models/test_llama3.py | 2 +- .../tests/models/test_llama3_lora_training.py | 2 +- skyrl-tx/tests/models/test_models_common.py | 26 +++++------- skyrl-tx/tests/models/test_qwen3.py | 4 +- skyrl-tx/tests/models/test_qwen3_generate.py | 4 +- .../tests/models/test_qwen3_lora_training.py | 2 +- skyrl-tx/tx/layers/lora.py | 40 ++++++++++++++----- skyrl-tx/tx/models/utils.py | 30 +++++++++----- 8 files changed, 67 insertions(+), 43 deletions(-) diff --git a/skyrl-tx/tests/models/test_llama3.py b/skyrl-tx/tests/models/test_llama3.py index fa195567f..7913839c5 100644 --- a/skyrl-tx/tests/models/test_llama3.py +++ b/skyrl-tx/tests/models/test_llama3.py @@ -42,7 +42,7 @@ def test_llama3(tp: int): mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) + load_safetensors(tmp, config, model, config.num_hidden_layers) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index af91d373e..ed9e9f266 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -21,7 +21,7 @@ def test_lora_training(): mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model) + load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) # Set different ranks for each adapter (0: rank 16, 1: rank 8) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 289e97556..6ce875eeb 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -19,31 +19,23 @@ MODEL_IDS = ["llama3", "qwen3"] -def create_model(model_name, config_cls, model_cls, mesh_axes): +def create_model(model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0, gradient_checkpointing=None, seed=42): """Create model with random weights for testing.""" base_config = AutoConfig.from_pretrained(model_name) - config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) - mesh = jax.make_mesh((1, 1), mesh_axes) + config_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, loss_chunk_size=loss_chunk_size) + if gradient_checkpointing is not None: + config_kwargs["gradient_checkpointing"] = gradient_checkpointing + config = config_cls(base_config, **config_kwargs) + mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed)) return model, config def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): """Load model from pre-saved weights directory.""" - base_config = AutoConfig.from_pretrained(model_name) - config = config_cls( - base_config, - max_lora_adapters=1, - max_lora_rank=1, - shard_attention_heads=True, - loss_chunk_size=loss_chunk_size, - gradient_checkpointing=False, - ) - mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2) - with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp_dir, config, model) + model, config = create_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, seed=0) + load_safetensors(tmp_dir, config, model, config.num_hidden_layers) return model diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 55a779c9e..587e650a5 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -43,7 +43,7 @@ def test_qwen3(tp: int): mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) + load_safetensors(tmp, config, model, config.num_hidden_layers) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) assert outputs.hidden_states is not None @@ -218,7 +218,7 @@ def test_qwen3_lora(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(base_tmp, config, model) + load_safetensors(base_tmp, config, model, config.num_hidden_layers) # Get outputs from all HF models hf_outputs_list = [] diff --git a/skyrl-tx/tests/models/test_qwen3_generate.py b/skyrl-tx/tests/models/test_qwen3_generate.py index 8b950d535..7579d823d 100644 --- a/skyrl-tx/tests/models/test_qwen3_generate.py +++ b/skyrl-tx/tests/models/test_qwen3_generate.py @@ -49,7 +49,7 @@ def test_qwen3_generate(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) + load_safetensors(tmp, config, model, config.num_hidden_layers) sampling_params = [ types.SamplingParams(max_tokens=10, temperature=0.0, seed=42), @@ -149,7 +149,7 @@ def test_qwen3_generate_speed(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.bfloat16, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) + load_safetensors(tmp, config, model, config.num_hidden_layers) sampling_params = [types.SamplingParams(max_tokens=50, temperature=0.0, seed=42) for i in range(len(inputs))] # Warmup diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index 85f5f3bda..f0dd0aa80 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -21,7 +21,7 @@ def test_lora_training(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model) + load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) # Set different ranks for each adapter (0: rank 16, 1: rank 8) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 911fff721..3ad54505c 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -1,6 +1,7 @@ from flax import nnx import jax from jax import numpy as jnp +from jax.core import Tracer from tx.utils.models import filter_lora from tx.layers.util import Param, prepare_routing, ragged_dot @@ -8,6 +9,25 @@ from tx.tinker.types import LoraConfig +def _get_sharding_spec(arr: jax.Array): + """Get sharding spec from an array, handling both concrete and traced arrays. + + Inside nnx.vmap, arrays become tracers and .sharding is not directly accessible. + Use jax.typeof() to get sharding info from traced arrays. + """ + if isinstance(arr, Tracer): + # For traced arrays, use jax.typeof to get the abstract value with sharding + aval = jax.typeof(arr) + if hasattr(aval, "sharding") and aval.sharding is not None: + return aval.sharding.spec + return None + else: + # For concrete arrays, access sharding directly + if arr.sharding is not None: + return arr.sharding.spec + return None + + class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. This mixin adds LoRA parameters (lora_A, lora_B) and methods to apply @@ -125,10 +145,10 @@ def __init__( embedding_init=embedding_init, rngs=rngs, ) - assert ( - self.embedding[...].sharding is not None - ), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init" - sharding = self.embedding[...].sharding.spec + sharding = _get_sharding_spec(self.embedding[...]) + assert sharding is not None, ( + "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init" + ) self.init_lora( max_lora_adapters=max_lora_adapters, @@ -183,10 +203,10 @@ def __init__( bias_init=bias_init, rngs=rngs, ) - assert ( - self.kernel[...].sharding is not None - ), "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init" - sharding = self.kernel[...].sharding.spec + sharding = _get_sharding_spec(self.kernel[...]) + assert sharding is not None, ( + "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init" + ) self.init_lora( max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank, @@ -224,8 +244,8 @@ def __init__( self.weight = Param(num_experts, in_features, out_features, dtype=dtype, kernel_init=kernel_init, rngs=rngs) - assert self.weight[...].sharding is not None, "LoRAExpert layer needs sharding" - sharding = self.weight[...].sharding.spec + sharding = _get_sharding_spec(self.weight[...]) + assert sharding is not None, "LoRAExpert layer needs sharding" self.init_lora( max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank, diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 2a852e756..e6a29114c 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -120,7 +120,7 @@ def body_fn(carry, layer_idx): kv_cache=layer_kv, ) - # Update stacked KV cache + # Update stacked KV cache if provided new_kv = kv if kv is not None: new_kv = ( @@ -128,16 +128,18 @@ def body_fn(carry, layer_idx): kv[1].at[layer_idx].set(v), ) - # Return updated carry and output for this iteration - output = hs if output_hidden_states else None - return (new_hs, new_kv), output + # Return updated carry and outputs for this iteration + # Always output (k, v) so we can build cache during prefill + # Output the layer OUTPUT (new_hs), not input, for hidden_states collection + hs_output = new_hs if output_hidden_states else None + return (new_hs, new_kv), (hs_output, k, v) # Apply gradient checkpointing if requested if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) # Scan over layer indices - (final_hs, final_kv), all_hs = jax.lax.scan( + (final_hs, final_kv), (all_hs, all_keys, all_values) = jax.lax.scan( body_fn, (hidden_states, stacked_kv), jnp.arange(num_layers), @@ -146,17 +148,27 @@ def body_fn(carry, layer_idx): # Collect hidden states if requested all_hidden_states: list[jax.Array] = [] if output_hidden_states: - # all_hs has shape (num_layers, batch, seq, hidden) - # We want [input, layer0_out, layer1_out, ...] excluding final (it gets normed) + # all_hs has shape (num_layers, batch, seq, hidden) containing output of each layer + # We want [embed, layer0_out, layer1_out, ..., layer(N-2)_out] + # The model will append the normed layer(N-1)_out after calling this function all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] - # Reconstruct KVCache if it was provided - new_kv_cache = None + # Reconstruct KVCache if kv_cache is not None and final_kv is not None: + # Decode mode: use updated cache from carry new_kv_cache = KVCache( keys=final_kv[0], values=final_kv[1], cache_position=kv_cache.cache_position, ) + else: + # Prefill mode: build cache from collected K/V outputs + # all_keys/all_values have shape (num_layers, batch, seq, heads, dim) + new_kv_cache = KVCache.from_layer_outputs( + keys=all_keys, + values=all_values, + positions=positions, + attention_mask=attention_mask, + ) return final_hs, all_hidden_states, new_kv_cache From dbe5114a175e536535480b938095688066733a21 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 13:18:40 -0800 Subject: [PATCH 073/133] remove duplication --- skyrl-tx/tests/models/test_models_common.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 289e97556..7f953594e 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -19,30 +19,25 @@ MODEL_IDS = ["llama3", "qwen3"] -def create_model(model_name, config_cls, model_cls, mesh_axes): +def create_model(model_name, config_cls, model_cls, mesh_axes, *, mesh_axis_types=None, **config_kwargs): """Create model with random weights for testing.""" base_config = AutoConfig.from_pretrained(model_name) - config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True) - mesh = jax.make_mesh((1, 1), mesh_axes) + config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, **config_kwargs) + mesh_kwargs = {"axis_types": mesh_axis_types} if mesh_axis_types else {} + mesh = jax.make_mesh((1, 1), mesh_axes, **mesh_kwargs) with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(42)) + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) return model, config def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): """Load model from pre-saved weights directory.""" - base_config = AutoConfig.from_pretrained(model_name) - config = config_cls( - base_config, - max_lora_adapters=1, - max_lora_rank=1, - shard_attention_heads=True, + model, config = create_model( + model_name, config_cls, model_cls, mesh_axes, + mesh_axis_types=(jax.sharding.AxisType.Auto,) * 2, loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, ) - mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2) - with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp_dir, config, model) return model From 15b4086e3c0b3432d1bb77cbabfb46adcef7bc84 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 13:28:14 -0800 Subject: [PATCH 074/133] load model twice --- skyrl-tx/tests/models/test_models_common.py | 59 +++++++++------------ 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 7f953594e..f96c8845e 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -45,49 +45,42 @@ def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_ch @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) class TestGradientCheckpointing: - def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls, mesh_axes): - """Forward pass should produce identical outputs with/without checkpointing.""" - model, config = create_model(model_name, config_cls, model_cls, mesh_axes) - + def _forward(self, model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing, **forward_kwargs): + """Create model, run forward pass, and return (model, config, out).""" batch_size, seq_len = 2, 8 + model, config = create_model(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing) input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - - # Run without checkpointing - config.gradient_checkpointing = False model.train() - out_no_ckpt = model(input_ids, attention_mask=attention_mask) - logits_no_ckpt = model.compute_logits(out_no_ckpt.last_hidden_state) + out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) + return model, config, out - # Run with checkpointing - config.gradient_checkpointing = True - out_ckpt = model(input_ids, attention_mask=attention_mask) - logits_ckpt = model.compute_logits(out_ckpt.last_hidden_state) + def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls, mesh_axes): + """Forward pass should produce identical outputs with/without checkpointing.""" + model, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False) + logits_no_ckpt = model.compute_logits(out.last_hidden_state) + del model, out + + model, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True) + logits_ckpt = model.compute_logits(out.last_hidden_state) + del model, out np.testing.assert_allclose(logits_no_ckpt, logits_ckpt, rtol=1e-4, atol=1e-6) def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, mesh_axes): """Both paths should return same number of hidden states.""" - model, config = create_model(model_name, config_cls, model_cls, mesh_axes) - - batch_size, seq_len = 2, 8 - input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - - config.gradient_checkpointing = False - model.train() - out_no_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) - - config.gradient_checkpointing = True - out_ckpt = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) - - assert len(out_no_ckpt.hidden_states) == len(out_ckpt.hidden_states) - assert len(out_ckpt.hidden_states) == config.num_hidden_layers + 1 - - for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(out_no_ckpt.hidden_states, out_ckpt.hidden_states)): - np.testing.assert_allclose( - hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" - ) + _, config, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False, output_hidden_states=True) + hidden_states_no_ckpt = out.hidden_states + num_hidden_layers = config.num_hidden_layers + del out + + _, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True, output_hidden_states=True) + hidden_states_ckpt = out.hidden_states + del out + + assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1 + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): + np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") def test_eval_mode_uses_standard_path(self, model_name, config_cls, model_cls, mesh_axes): """eval() mode should use standard path with KV cache support.""" From a3adadd742ef26cdfc6607d6735138b88a641776 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 13:32:55 -0800 Subject: [PATCH 075/133] type hints --- skyrl-tx/tests/models/test_models_common.py | 73 ++++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index f96c8845e..c8d46af3e 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -1,4 +1,5 @@ import tempfile +from typing import Any from flax import nnx import jax @@ -7,9 +8,10 @@ import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from tx.models.configs import Llama3Config, Qwen3Config +from tx.models.configs import Llama3Config, ModelConfig, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM +from tx.models.types import CausalLMOutput, ModelForCausalLM from tx.utils.models import load_safetensors MODEL_PARAMS = [ @@ -19,7 +21,15 @@ MODEL_IDS = ["llama3", "qwen3"] -def create_model(model_name, config_cls, model_cls, mesh_axes, *, mesh_axis_types=None, **config_kwargs): +def create_model( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + *, + mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, + **config_kwargs: Any, +) -> tuple[ModelForCausalLM, ModelConfig]: """Create model with random weights for testing.""" base_config = AutoConfig.from_pretrained(model_name) config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, **config_kwargs) @@ -30,7 +40,15 @@ def create_model(model_name, config_cls, model_cls, mesh_axes, *, mesh_axis_type return model, config -def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): +def load_model( + tmp_dir: str, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + *, + loss_chunk_size: int = 0, +) -> ModelForCausalLM: """Load model from pre-saved weights directory.""" model, config = create_model( model_name, config_cls, model_cls, mesh_axes, @@ -45,7 +63,15 @@ def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_ch @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) class TestGradientCheckpointing: - def _forward(self, model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing, **forward_kwargs): + def _forward( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + gradient_checkpointing: bool, + **forward_kwargs: Any, + ) -> tuple[ModelForCausalLM, ModelConfig, CausalLMOutput]: """Create model, run forward pass, and return (model, config, out).""" batch_size, seq_len = 2, 8 model, config = create_model(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing) @@ -55,7 +81,13 @@ def _forward(self, model_name, config_cls, model_cls, mesh_axes, gradient_checkp out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) return model, config, out - def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls, mesh_axes): + def test_output_matches_non_checkpointed( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: """Forward pass should produce identical outputs with/without checkpointing.""" model, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False) logits_no_ckpt = model.compute_logits(out.last_hidden_state) @@ -67,7 +99,13 @@ def test_output_matches_non_checkpointed(self, model_name, config_cls, model_cls np.testing.assert_allclose(logits_no_ckpt, logits_ckpt, rtol=1e-4, atol=1e-6) - def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, mesh_axes): + def test_hidden_states_length_matches( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: """Both paths should return same number of hidden states.""" _, config, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False, output_hidden_states=True) hidden_states_no_ckpt = out.hidden_states @@ -82,7 +120,13 @@ def test_hidden_states_length_matches(self, model_name, config_cls, model_cls, m for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") - def test_eval_mode_uses_standard_path(self, model_name, config_cls, model_cls, mesh_axes): + def test_eval_mode_uses_standard_path( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: """eval() mode should use standard path with KV cache support.""" model, config = create_model(model_name, config_cls, model_cls, mesh_axes) config.gradient_checkpointing = True @@ -99,7 +143,12 @@ def test_eval_mode_uses_standard_path(self, model_name, config_cls, model_cls, m @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) -def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): +def test_compute_logits( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], +) -> None: """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -126,7 +175,13 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) @pytest.mark.parametrize("chunk_size", [8, 16, 32]) -def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): +def test_chunked_logprobs( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + chunk_size: int, +) -> None: """Test that chunked and non-chunked compute_logprobs produce identical results.""" tokenizer = AutoTokenizer.from_pretrained(model_name) inputs = ["The capital of France is", "Hello world"] From a552dfcfaab4e7e0294cacabe67f8b7a764b138a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 14:22:43 -0800 Subject: [PATCH 076/133] fix --- .../tests/models/test_llama3_lora_training.py | 27 ++++++++++--- skyrl-tx/tests/models/test_qwen3.py | 38 +++++++++++++++---- .../tests/models/test_qwen3_lora_training.py | 27 ++++++++++--- skyrl-tx/tx/layers/lora.py | 34 ++++++++++++++--- skyrl-tx/tx/models/utils.py | 4 +- 5 files changed, 105 insertions(+), 25 deletions(-) diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index ed9e9f266..aba69a728 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -46,18 +46,33 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) # Helper to extract adapter params at specific index + # Decoder layer LoRA params have shape (num_layers, num_adapters, ...) + # Embed tokens LoRA params have shape (num_adapters, ...) def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + def extract(path, p): + path_str = str(path) + if "layers" in path_str: + return p[:, adapter_idx].copy() # Keep layer dimension + else: + return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) # Helper to extract out-of-rank params for an adapter def get_out_of_rank_params(params, adapter_idx, rank): def slice_param(path, p): - if "lora_A" in str(path): - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in str(path): - return p[adapter_idx, rank:, :].copy() + path_str = str(path) + is_stacked = "layers" in path_str + if "lora_A" in path_str: + if is_stacked: + return p[:, adapter_idx, :, rank:].copy() + else: + return p[adapter_idx, :, rank:].copy() + elif "lora_B" in path_str: + if is_stacked: + return p[:, adapter_idx, rank:, :].copy() + else: + return p[adapter_idx, rank:, :].copy() return p - return jax.tree.map_with_path(slice_param, params) # Save initial states diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 587e650a5..9e5fc9f95 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -92,7 +92,7 @@ def load_lora_weights( scaling: float, rank: int, ) -> None: - """Load LoRA weights from numpy arrays to JAX module.""" + """Load LoRA weights from numpy arrays to JAX module (non-stacked modules like embed_tokens).""" assert ( jax_module.lora_A is not None and jax_module.lora_B is not None @@ -105,6 +105,28 @@ def load_lora_weights( jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank) +def load_stacked_lora_weights( + jax_module: LoRAMixin, + layer_idx: int, + adapter_idx: int, + lora_A_weights: np.ndarray, + lora_B_weights: np.ndarray, + scaling: float, + rank: int, +) -> None: + """Load LoRA weights for a specific layer in stacked format (decoder layers).""" + assert ( + jax_module.lora_A is not None + and jax_module.lora_B is not None + and jax_module.lora_scaling is not None + and jax_module.lora_ranks is not None + ) + jax_module.lora_A[...] = jax_module.lora_A[...].at[layer_idx, adapter_idx].set(jnp.array(lora_A_weights)) + jax_module.lora_B[...] = jax_module.lora_B[...].at[layer_idx, adapter_idx].set(jnp.array(lora_B_weights)) + jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[layer_idx, adapter_idx].set(scaling) + jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[layer_idx, adapter_idx].set(rank) + + @pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)]) def test_qwen3_moe_layer_lora(ep: int, tp: int): """Test MoE LoRA by merging adapter into base weights and comparing outputs.""" @@ -245,17 +267,19 @@ def test_qwen3_lora(): rank=lora_config.r, ) - # Load layer LoRA weights - for i, layer in enumerate(model.model.layers): + # Load layer LoRA weights (stacked format) + for i in range(config.num_hidden_layers): hf_layer = hf_model.base_model.model.model.layers[i] - for module, projections in [ + for module_name, projections in [ ("mlp", ["gate_proj", "up_proj", "down_proj"]), ("self_attn", ["q_proj", "k_proj", "v_proj", "o_proj"]), ]: for proj_name in projections: - hf_proj = getattr(getattr(hf_layer, module), proj_name) - load_lora_weights( - getattr(getattr(layer, module), proj_name), + hf_proj = getattr(getattr(hf_layer, module_name), proj_name) + jax_proj = getattr(getattr(model.model.layers, module_name), proj_name) + load_stacked_lora_weights( + jax_proj, + layer_idx=i, adapter_idx=adapter_idx, lora_A_weights=hf_proj.lora_A["default"].weight.detach().numpy().T, lora_B_weights=hf_proj.lora_B["default"].weight.detach().numpy().T, diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index f0dd0aa80..c757c18f6 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -46,18 +46,33 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) # Helper to extract adapter params at specific index + # Decoder layer LoRA params have shape (num_layers, num_adapters, ...) + # Embed tokens LoRA params have shape (num_adapters, ...) def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + def extract(path, p): + path_str = str(path) + if "layers" in path_str: + return p[:, adapter_idx].copy() # Keep layer dimension + else: + return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) # Helper to extract out-of-rank params for an adapter def get_out_of_rank_params(params, adapter_idx, rank): def slice_param(path, p): - if "lora_A" in str(path): - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in str(path): - return p[adapter_idx, rank:, :].copy() + path_str = str(path) + is_stacked = "layers" in path_str + if "lora_A" in path_str: + if is_stacked: + return p[:, adapter_idx, :, rank:].copy() + else: + return p[adapter_idx, :, rank:].copy() + elif "lora_B" in path_str: + if is_stacked: + return p[:, adapter_idx, rank:, :].copy() + else: + return p[adapter_idx, rank:, :].copy() return p - return jax.tree.map_with_path(slice_param, params) # Save initial states diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 3ad54505c..0978ea296 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -343,20 +343,38 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 + # Check if this is a stacked layer parameter (shape has extra leading dimension) + # Stacked layers have shape (num_layers, num_adapters, ...) while + # non-stacked (embed_tokens) have shape (num_adapters, ...) + is_stacked = "layers" in normalized_path + key_name = path[-2].key if key_name == "lora_ranks": + if is_stacked: + return value.at[:, adapter_index].set(effective_rank) return value.at[adapter_index].set(effective_rank) if key_name == "lora_scaling": # Set scaling to 0.0 if rank is 0 - return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0) + scaling_value = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0 + if is_stacked: + return value.at[:, adapter_index].set(scaling_value) + return value.at[adapter_index].set(scaling_value) if key_name == "lora_A": # Reinitialize with he_uniform, then zero columns beyond rank - shape = value[adapter_index].shape - new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) - new_A = new_A.at[..., effective_rank:].set(0.0) - return value.at[adapter_index].set(new_A) + if is_stacked: + shape = value[:, adapter_index].shape + new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) + new_A = new_A.at[..., effective_rank:].set(0.0) + return value.at[:, adapter_index].set(new_A) + else: + shape = value[adapter_index].shape + new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) + new_A = new_A.at[..., effective_rank:].set(0.0) + return value.at[adapter_index].set(new_A) if key_name == "lora_B": # Explicitly zero lora_B + if is_stacked: + return value.at[:, adapter_index].set(0.0) return value.at[adapter_index].set(0.0) return value @@ -373,10 +391,16 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): state = nnx.state(model) def clear_adapter(path, value): + normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) + is_stacked = "layers" in normalized_path key = path[-2].key if key == "lora_ranks": + if is_stacked: + return value.at[:, adapter_index].set(0) return value.at[adapter_index].set(0) if key in ("lora_scaling", "lora_A", "lora_B"): + if is_stacked: + return value.at[:, adapter_index].set(0.0) return value.at[adapter_index].set(0.0) return value diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index e6a29114c..8d4938928 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -156,10 +156,12 @@ def body_fn(carry, layer_idx): # Reconstruct KVCache if kv_cache is not None and final_kv is not None: # Decode mode: use updated cache from carry + # Increment cache_position by the number of new tokens processed + new_cache_position = kv_cache.cache_position + positions.shape[1] new_kv_cache = KVCache( keys=final_kv[0], values=final_kv[1], - cache_position=kv_cache.cache_position, + cache_position=new_cache_position, ) else: # Prefill mode: build cache from collected K/V outputs From 687f2a5cacc4cd8cb15cd94f3f6aae5d1a478bdd Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 14:57:16 -0800 Subject: [PATCH 077/133] minor fixes --- skyrl-tx/tx/models/utils.py | 21 ++++++------- skyrl-tx/tx/utils/generator.py | 11 +------ skyrl-tx/tx/utils/models.py | 54 ++++++++++++++++++---------------- 3 files changed, 41 insertions(+), 45 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 8d4938928..7df7e171e 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -9,7 +9,7 @@ - KVCache must use stacked format: (num_layers, batch, seq, heads, dim) """ -from typing import TypeVar +from typing import Callable from flax import nnx import jax @@ -17,11 +17,9 @@ from tx.utils.generator import KVCache -T = TypeVar("T", bound=nnx.Module) - def create_stacked_layers( - create_layer_fn: callable, + create_layer_fn: Callable[[nnx.Rngs], nnx.Module], num_layers: int, rngs: nnx.Rngs, ) -> nnx.Module: @@ -85,8 +83,10 @@ def forward_layers( Returns: Tuple of: - Final hidden states of shape (batch, seq, hidden) - - List of intermediate hidden states (if output_hidden_states=True) - - Updated KV cache (if kv_cache was provided) + - List of intermediate hidden states (if output_hidden_states=True, else empty list) + - KV cache: In decode mode (kv_cache provided), returns the updated cache. + In prefill mode (kv_cache=None), returns a newly constructed cache from + layer outputs. Only None if num_layers=0. """ if num_layers == 0: return hidden_states, [], kv_cache @@ -128,9 +128,11 @@ def body_fn(carry, layer_idx): kv[1].at[layer_idx].set(v), ) - # Return updated carry and outputs for this iteration - # Always output (k, v) so we can build cache during prefill - # Output the layer OUTPUT (new_hs), not input, for hidden_states collection + # Return updated carry and outputs for this iteration. + # Note: We always output (k, v) because JAX scan requires fixed output structure. + # During decode (kv_cache provided), these are unused but the memory overhead is + # minimal since decode processes seq_len=1. During prefill, we need them to build + # the initial KV cache. hs_output = new_hs if output_hidden_states else None return (new_hs, new_kv), (hs_output, k, v) @@ -169,7 +171,6 @@ def body_fn(carry, layer_idx): new_kv_cache = KVCache.from_layer_outputs( keys=all_keys, values=all_values, - positions=positions, attention_mask=attention_mask, ) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 6afd261ab..e7b176871 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -30,7 +30,6 @@ class KVCache: def from_layer_outputs( keys: jax.Array, values: jax.Array, - positions: jax.Array, attention_mask: jax.Array, ) -> KVCache: """Create KVCache from stacked layer outputs after prefill. @@ -38,7 +37,6 @@ def from_layer_outputs( Args: keys: Stacked keys of shape (num_layers, batch, seq, num_kv_heads, head_dim). values: Stacked values of shape (num_layers, batch, seq, num_kv_heads, head_dim). - positions: Position indices of shape (batch, seq). attention_mask: Attention mask of shape (batch, seq). Returns: @@ -228,16 +226,9 @@ def _prefill_and_decode( last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :] prompt_logprobs_array = None - # Pad KV cache to max_length + # Pad KV cache to max_length (cache_position is already set by from_layer_outputs) kv_cache = outputs.kv_cache.pad_to_length(max_length) - # Update cache_position after prefill - kv_cache = KVCache( - keys=kv_cache.keys, - values=kv_cache.values, - cache_position=attention_mask.sum(axis=1).astype(jnp.int32), - ) - decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1]))) def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]: diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 2c3491e5f..66cda49f9 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -159,8 +159,8 @@ def load_safetensors( continue if _is_layer_param(path): - # Stack layer weights from individual layer tensors - layer_tensors = [] + # Pre-allocate array for stacked layer weights to avoid 2x memory from list + stack + stacked_tensor = np.empty(param.shape, dtype=param.dtype) for layer_idx in range(num_layers): key = _get_hf_key_for_layer(path, layer_idx) @@ -190,9 +190,7 @@ def load_safetensors( target_shape = param.shape[1:] tensor = tensor.reshape(target_shape) - layer_tensors.append(tensor) - - stacked_tensor = np.stack(layer_tensors, axis=0) + stacked_tensor[layer_idx] = tensor else: # Non-layer parameter - load directly key = _get_hf_key(path) @@ -389,21 +387,24 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: def extract_state(path: tuple, p: jnp.ndarray): if path[-2].key not in {"lora_A", "lora_B"}: return p - # For stacked layers, LoRA params have shape (num_layers, num_adapters, ...) - # We extract adapter_index from the adapter dimension + # LoRA param shapes: + # - 3D: Non-stacked linear/embed (A, in, R) or (A, R, out) + # - 4D: Stacked linear/embed (L, A, in, R) or non-stacked expert (A, E, in, R) + # - 5D: Stacked expert (L, A, E, in, R) + # We extract adapter_index from the adapter dimension (axis 1 for stacked, axis 0 otherwise) assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" + is_stacked = "layers" in [pk.key if hasattr(pk, "key") else str(pk) for pk in path] if path[-2].key == "lora_A": - # Shape: (L, A, in, R) or (A, in, R) -> extract [..., :, :rank] - if p.ndim == 4: # Stacked: (L, A, in, R) - return p[:, adapter_index, :, :rank] - else: # Non-stacked: (A, in, R) - return p[adapter_index, :, :rank] + if is_stacked: # (L, A, ..., R) + return p[:, adapter_index, ..., :rank] + else: # (A, ..., R) + return p[adapter_index, ..., :rank] if path[-2].key == "lora_B": - # Shape: (L, A, R, out) or (A, R, out) -> extract [..., :rank, :] - if p.ndim == 4: # Stacked: (L, A, R, out) - return p[:, adapter_index, :rank, :] - else: # Non-stacked: (A, R, out) - return p[adapter_index, :rank, :] + if is_stacked: # (L, A, ..., out) + return p[:, adapter_index, ..., :rank, :] + else: # (A, ..., out) + return p[adapter_index, ..., :rank, :] + return p # Defensive fallback (should not be reached) return jax.tree.map_with_path(extract_state, lora_params) @@ -418,17 +419,20 @@ def insert_adapter_state( def insert_state(path: tuple, p: jax.Array, new: jax.Array): if path[-2].key not in {"lora_A", "lora_B"}: return new + # See extract_adapter_state for shape documentation assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" + is_stacked = "layers" in [pk.key if hasattr(pk, "key") else str(pk) for pk in path] if path[-2].key == "lora_A": - if p.ndim == 4: # Stacked: (L, A, in, R) - return p.at[:, adapter_index, :, :rank].set(new) - else: # Non-stacked: (A, in, R) - return p.at[adapter_index, :, :rank].set(new) + if is_stacked: # (L, A, ..., R) + return p.at[:, adapter_index, ..., :rank].set(new) + else: # (A, ..., R) + return p.at[adapter_index, ..., :rank].set(new) elif path[-2].key == "lora_B": - if p.ndim == 4: # Stacked: (L, A, R, out) - return p.at[:, adapter_index, :rank, :].set(new) - else: # Non-stacked: (A, R, out) - return p.at[adapter_index, :rank, :].set(new) + if is_stacked: # (L, A, ..., out) + return p.at[:, adapter_index, ..., :rank, :].set(new) + else: # (A, ..., out) + return p.at[adapter_index, ..., :rank, :].set(new) + return new # Defensive fallback (should not be reached) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From 55a42e6f688be9b589d237bc19dc443956156b82 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 15:55:12 -0800 Subject: [PATCH 078/133] simplify and optimize forward_layers --- skyrl-tx/tx/models/utils.py | 103 +++++++++++------------------------- 1 file changed, 32 insertions(+), 71 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 7df7e171e..fca7c6645 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -62,7 +62,7 @@ def forward_layers( kv_cache: KVCache | None, output_hidden_states: bool, gradient_checkpointing: bool, -) -> tuple[jax.Array, list[jax.Array], KVCache | None]: +) -> tuple[jax.Array, list[jax.Array], KVCache]: """Unified forward pass through stacked decoder layers. Uses jax.lax.scan for both training and inference. When gradient_checkpointing=True, @@ -76,42 +76,30 @@ def forward_layers( attention_mask: Attention mask of shape (batch, seq). positions: Position indices of shape (batch, seq). adapter_indices: Optional LoRA adapter indices of shape (batch,). - kv_cache: Optional KV cache with stacked keys/values. + kv_cache: Optional KV cache for decode mode (None for prefill). output_hidden_states: Whether to return intermediate hidden states. gradient_checkpointing: Whether to use gradient checkpointing. Returns: - Tuple of: - - Final hidden states of shape (batch, seq, hidden) - - List of intermediate hidden states (if output_hidden_states=True, else empty list) - - KV cache: In decode mode (kv_cache provided), returns the updated cache. - In prefill mode (kv_cache=None), returns a newly constructed cache from - layer outputs. Only None if num_layers=0. + Tuple of (final_hidden_states, all_hidden_states, kv_cache). """ - if num_layers == 0: - return hidden_states, [], kv_cache + assert num_layers > 0, "num_layers must be positive" - # Split layers into graph definition and stacked state layer_graphdef, layer_state = nnx.split(layers) + is_decode = kv_cache is not None - # Prepare stacked KV cache - stacked_kv: tuple[jax.Array, jax.Array] | None = None - if kv_cache is not None: - stacked_kv = (kv_cache.keys, kv_cache.values) + def body_fn(hs, xs): + # Unpack xs based on mode (structure differs between prefill and decode) + if is_decode: + layer_idx, layer_k, layer_v = xs + layer_kv = (layer_k, layer_v) + else: + layer_idx = xs + layer_kv = None - def body_fn(carry, layer_idx): - hs, kv = carry + # Reconstruct layer module from stacked weights + layer = nnx.merge(layer_graphdef, jax.tree.map(lambda x: x[layer_idx], layer_state)) - # Extract this layer's weights by indexing into stacked state - layer_weights = jax.tree.map(lambda x: x[layer_idx], layer_state) - layer = nnx.merge(layer_graphdef, layer_weights) - - # Get this layer's KV cache slice - layer_kv = None - if kv is not None: - layer_kv = (kv[0][layer_idx], kv[1][layer_idx]) - - # Forward through layer new_hs, (k, v) = layer( hs, attention_mask=attention_mask, @@ -120,58 +108,31 @@ def body_fn(carry, layer_idx): kv_cache=layer_kv, ) - # Update stacked KV cache if provided - new_kv = kv - if kv is not None: - new_kv = ( - kv[0].at[layer_idx].set(k), - kv[1].at[layer_idx].set(v), - ) - - # Return updated carry and outputs for this iteration. - # Note: We always output (k, v) because JAX scan requires fixed output structure. - # During decode (kv_cache provided), these are unused but the memory overhead is - # minimal since decode processes seq_len=1. During prefill, we need them to build - # the initial KV cache. hs_output = new_hs if output_hidden_states else None - return (new_hs, new_kv), (hs_output, k, v) + return new_hs, (hs_output, k, v) - # Apply gradient checkpointing if requested if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - # Scan over layer indices - (final_hs, final_kv), (all_hs, all_keys, all_values) = jax.lax.scan( - body_fn, - (hidden_states, stacked_kv), - jnp.arange(num_layers), - ) - - # Collect hidden states if requested - all_hidden_states: list[jax.Array] = [] - if output_hidden_states: - # all_hs has shape (num_layers, batch, seq, hidden) containing output of each layer - # We want [embed, layer0_out, layer1_out, ..., layer(N-2)_out] - # The model will append the normed layer(N-1)_out after calling this function - all_hidden_states = [hidden_states] + [all_hs[i] for i in range(num_layers - 1)] - - # Reconstruct KVCache - if kv_cache is not None and final_kv is not None: - # Decode mode: use updated cache from carry - # Increment cache_position by the number of new tokens processed - new_cache_position = kv_cache.cache_position + positions.shape[1] + # Prepare scan inputs: in decode mode, pass per-layer caches via xs + # Scan automatically slices along axis 0, so each iteration gets one layer's cache + layer_indices = jnp.arange(num_layers) + xs = (layer_indices, kv_cache.keys, kv_cache.values) if is_decode else layer_indices + + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, xs) + + # [embed, layer0_out, ..., layer(N-2)_out]; final layer output gets normed by caller + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + + if is_decode: + # Decode mode: scan stacked the per-layer updated caches into (num_layers, ...) new_kv_cache = KVCache( - keys=final_kv[0], - values=final_kv[1], - cache_position=new_cache_position, - ) - else: - # Prefill mode: build cache from collected K/V outputs - # all_keys/all_values have shape (num_layers, batch, seq, heads, dim) - new_kv_cache = KVCache.from_layer_outputs( keys=all_keys, values=all_values, - attention_mask=attention_mask, + cache_position=kv_cache.cache_position + positions.shape[1], ) + else: + # Prefill mode: build cache from collected k,v outputs + new_kv_cache = KVCache.from_layer_outputs(all_keys, all_values, attention_mask) return final_hs, all_hidden_states, new_kv_cache From 38509ce405d2893d2ec6e9f077c245371fffbbfb Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 15:55:21 -0800 Subject: [PATCH 079/133] skip skyrl-train --- skyrl-tx/pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 33ea2c349..cc1aa8f52 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -93,6 +93,10 @@ tx = "tx.run.main:app" # The following is for supporting the skyrl-train dependency +[tool.uv] +# Exclude skyrl-train on macOS since it requires CUDA torch +exclude-dependencies = ["skyrl-train"] + [tool.uv.extra-build-dependencies] flash-attn = [{requirement = "torch", match-runtime = true}] transformer-engine = [{requirement = "torch", match-runtime = true}, "build_tools", "ninja"] @@ -104,4 +108,4 @@ flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"} [tool.uv.sources] # For now, just always use the current main branch, later it will be better to pin it to a released version. For development, you # can set it to your own development branch. -skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } +# skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } From 6d4d17db04ec0583c55b9d50d30835f57ef27feb Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 17:14:05 -0800 Subject: [PATCH 080/133] simplify models.py --- skyrl-tx/tx/utils/models.py | 190 +++++++++++------------------------- 1 file changed, 58 insertions(+), 132 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 66cda49f9..54d89f58d 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -81,40 +81,61 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: def _is_layer_param(path: tuple) -> bool: """Check if a parameter path corresponds to a stacked decoder layer weight.""" path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] - # Layer params have 'layers' in their path but not as part of another word return "layers" in path_strs -def _get_hf_key_for_layer(path: tuple, layer_idx: int) -> str: - """Convert a stacked layer param path to a per-layer HuggingFace key.""" +def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: + """Convert param path to HuggingFace key. If layer_idx provided, insert it after 'layers'.""" parts = [] for p in path: key = p.key if hasattr(p, "key") else str(p) - if key == "layers": + if key == "layers" and layer_idx is not None: parts.append(f"layers.{layer_idx}") elif key in ("kernel", "embedding"): parts.append("weight") elif key in ("lora_A", "lora_B"): - parts.append(key) - parts.append("weight") + parts.extend([key, "weight"]) else: parts.append(key) return ".".join(parts) -def _get_hf_key(path: tuple) -> str: - """Convert a non-layer param path to a HuggingFace key.""" - parts = [] - for p in path: - key = p.key if hasattr(p, "key") else str(p) - if key in ("kernel", "embedding"): - parts.append("weight") - elif key in ("lora_A", "lora_B"): - parts.append(key) - parts.append("weight") - else: - parts.append(key) - return ".".join(parts) +def _load_hf_tensor(tensors: dict, key: str, target_shape: tuple, num_experts: int | None) -> np.ndarray: + """Load tensor from HF format, handling experts, transpose, and reshape.""" + # Handle MoE expert weights (HF stores each expert separately) + if ".experts." in key and num_experts: + tensor = np.stack([ + tensors[key.replace(".experts.", f".experts.{i}.")].T + for i in range(num_experts) + ], axis=0) + else: + tensor = tensors[key] + if "embed_tokens" not in key: + tensor = tensor.T + + # Reshape attention projections to match model's grouped head format + if any(p in key for p in ("q_proj", "k_proj", "v_proj", "o_proj")): + tensor = tensor.reshape(target_shape) + + return tensor + + +def _save_hf_tensor(tensors: dict, key: str, param: np.ndarray, num_experts: int | None) -> None: + """Save tensor to HF format, handling experts, transpose, and reshape.""" + # Handle MoE expert weights + if ".experts." in key and num_experts: + for i in range(num_experts): + tensors[key.replace(".experts.", f".experts.{i}.")] = param[i].T + return + + # Reshape attention projections back to 2D + if any(p in key for p in ("q_proj", "k_proj", "v_proj")): + param = param.reshape(param.shape[0], -1) + elif "o_proj" in key: + param = param.reshape(-1, param.shape[-1]) + + # Transpose to HF format + tensors[key] = param if "embed_tokens" in key else param.T def load_safetensors( @@ -126,25 +147,13 @@ def load_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: - """Load safetensors weights into a model with stacked layers. - - For layer parameters, loads individual layer weights and stacks them. - For non-layer parameters, loads directly. - - Args: - checkpoint_dir: Directory containing safetensors files. - config: Model configuration. - model: Model with stacked layer weights (created with create_stacked_layers). - num_layers: Number of decoder layers. - skip_lora: Whether to skip LoRA parameters. - prefix: Prefix to remove from tensor keys. - filter_fn: Optional filter for which parameters to load. - """ + """Load safetensors weights into a model with stacked layers.""" tensors = {} for file in Path(checkpoint_dir).glob("*.safetensors"): tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} + num_experts = getattr(config, "num_experts", None) model_params = nnx.to_flat_state(nnx.state(model)) updates = [] @@ -153,69 +162,21 @@ def load_safetensors( continue path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] - - # Skip LoRA parameters if requested if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): continue if _is_layer_param(path): - # Pre-allocate array for stacked layer weights to avoid 2x memory from list + stack + # Stack per-layer weights from HF format stacked_tensor = np.empty(param.shape, dtype=param.dtype) - for layer_idx in range(num_layers): - key = _get_hf_key_for_layer(path, layer_idx) - - # Handle expert weights (MoE) - HF stores each expert separately - # Our model has shape (num_experts, in, out), HF has experts.{idx}.*.weight - if ".experts." in key and hasattr(config, "num_experts"): - num_experts = config.num_experts - expert_tensors = [] - for expert_idx in range(num_experts): - # Insert expert index: experts.gate_proj -> experts.0.gate_proj - expert_key = key.replace(".experts.", f".experts.{expert_idx}.") - if expert_key in tensors: - expert_tensors.append(tensors[expert_key].T) - if expert_tensors: - tensor = np.stack(expert_tensors, axis=0) - else: - raise KeyError(f"Expert weights not found for {key}") - else: - tensor = tensors[key] - # Transpose linear weights (HF uses [out, in], we use [in, out]) - if "embed_tokens" not in key: - tensor = tensor.T - - # Reshape attention projections if needed - if any(proj in key for proj in ("q_proj", "k_proj", "v_proj", "o_proj")): - # param.shape[1:] gives the target shape without the layer axis - target_shape = param.shape[1:] - tensor = tensor.reshape(target_shape) - - stacked_tensor[layer_idx] = tensor + for i in range(num_layers): + key = _path_to_hf_key(path, layer_idx=i) + stacked_tensor[i] = _load_hf_tensor(tensors, key, param.shape[1:], num_experts) else: - # Non-layer parameter - load directly - key = _get_hf_key(path) - - if ".experts." in key and hasattr(config, "num_experts"): - num_experts = config.num_experts - expert_tensors = [] - for expert_idx in range(num_experts): - expert_key = key.replace(".experts.", f".experts.{expert_idx}.") - if expert_key in tensors: - expert_tensors.append(tensors[expert_key].T) - if expert_tensors: - stacked_tensor = np.stack(expert_tensors, axis=0) - else: - raise KeyError(f"Expert weights not found for {key}") - else: - stacked_tensor = tensors[key] - if "embed_tokens" not in key: - stacked_tensor = stacked_tensor.T - - assert param.shape == stacked_tensor.shape, ( - f"Shape mismatch for {path}: expected {param.shape}, got {stacked_tensor.shape}" - ) - sharded_tensor = jax.device_put(stacked_tensor.astype(param.dtype), param.sharding) - updates.append((path, sharded_tensor)) + key = _path_to_hf_key(path) + stacked_tensor = _load_hf_tensor(tensors, key, param.shape, num_experts) + + assert param.shape == stacked_tensor.shape, f"Shape mismatch for {path}" + updates.append((path, jax.device_put(stacked_tensor.astype(param.dtype), param.sharding))) nnx.update(model, nnx.from_flat_state(updates)) @@ -228,16 +189,8 @@ def save_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: - """Save model weights to safetensors, unstacking layer weights for HF compatibility. - - Args: - config: Model configuration. - model: Model with stacked layer weights. - filename: Output safetensors file path. - num_layers: Number of decoder layers. - prefix: Prefix to add to tensor keys. - filter_fn: Optional filter for which parameters to save. - """ + """Save model weights to safetensors, unstacking layer weights for HF compatibility.""" + num_experts = getattr(config, "num_experts", None) model_params = nnx.to_flat_state(nnx.state(model)) tensors = {} @@ -250,39 +203,12 @@ def save_safetensors( if _is_layer_param(path): # Unstack and save as individual layer weights - for layer_idx in range(num_layers): - key = prefix + _get_hf_key_for_layer(path, layer_idx) - layer_param = param[layer_idx] - - # Handle expert weights (MoE) - save each expert separately for HF compatibility - if ".experts." in key and hasattr(config, "num_experts"): - for expert_idx in range(config.num_experts): - expert_key = key.replace(".experts.", f".experts.{expert_idx}.") - tensors[expert_key] = layer_param[expert_idx].T - else: - # Reshape attention projections back to 2D - if "q_proj" in key or "k_proj" in key or "v_proj" in key: - layer_param = layer_param.reshape(layer_param.shape[0], -1) - elif "o_proj" in key: - layer_param = layer_param.reshape(-1, layer_param.shape[-1]) - - # Transpose back to HF format - if "embed_tokens" not in key: - layer_param = layer_param.T - tensors[key] = layer_param + for i in range(num_layers): + key = prefix + _path_to_hf_key(path, layer_idx=i) + _save_hf_tensor(tensors, key, param[i], num_experts) else: - # Non-layer parameter - save directly - key = prefix + _get_hf_key(path) - - if ".experts." in key and hasattr(config, "num_experts"): - for expert_idx in range(config.num_experts): - expert_key = key.replace(".experts.", f".experts.{expert_idx}.") - tensors[expert_key] = param[expert_idx].T - else: - tensor = param - if "embed_tokens" not in key: - tensor = tensor.T - tensors[key] = tensor + key = prefix + _path_to_hf_key(path) + _save_hf_tensor(tensors, key, param, num_experts) # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: From 846aa967da62ca541c87592a27d708a42b73234d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 17:27:22 -0800 Subject: [PATCH 081/133] clean up lora.py --- skyrl-tx/tx/layers/lora.py | 65 ++++++++++++++------------------------ 1 file changed, 23 insertions(+), 42 deletions(-) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index f7c89fdd5..cdc0e2cfa 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -16,16 +16,20 @@ def _get_sharding_spec(arr: jax.Array): Use jax.typeof() to get sharding info from traced arrays. """ if isinstance(arr, Tracer): - # For traced arrays, use jax.typeof to get the abstract value with sharding aval = jax.typeof(arr) if hasattr(aval, "sharding") and aval.sharding is not None: return aval.sharding.spec return None - else: - # For concrete arrays, access sharding directly - if arr.sharding is not None: - return arr.sharding.spec - return None + if arr.sharding is not None: + return arr.sharding.spec + return None + + +def _adapter_index(is_stacked: bool, adapter_index: int): + """Return index for accessing an adapter. Stacked params have layers as first dim.""" + # Stacked layers have shape (num_layers, num_adapters, ...), + # non-stacked (embed_tokens) have shape (num_adapters, ...) + return (slice(None), adapter_index) if is_stacked else (adapter_index,) class LoRAMixin: @@ -364,39 +368,22 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - # Check if this is a stacked layer parameter (shape has extra leading dimension) - # Stacked layers have shape (num_layers, num_adapters, ...) while - # non-stacked (embed_tokens) have shape (num_adapters, ...) - is_stacked = "layers" in normalized_path + idx = _adapter_index("layers" in normalized_path, adapter_index) key_name = path[-2].key if key_name == "lora_ranks": - if is_stacked: - return value.at[:, adapter_index].set(effective_rank) - return value.at[adapter_index].set(effective_rank) + return value.at[idx].set(effective_rank) if key_name == "lora_scaling": - # Set scaling to 0.0 if rank is 0 - scaling_value = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0 - if is_stacked: - return value.at[:, adapter_index].set(scaling_value) - return value.at[adapter_index].set(scaling_value) + scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0 + return value.at[idx].set(scaling) if key_name == "lora_A": # Reinitialize with he_uniform, then zero columns beyond rank - if is_stacked: - shape = value[:, adapter_index].shape - new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) - new_A = new_A.at[..., effective_rank:].set(0.0) - return value.at[:, adapter_index].set(new_A) - else: - shape = value[adapter_index].shape - new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) - new_A = new_A.at[..., effective_rank:].set(0.0) - return value.at[adapter_index].set(new_A) + new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype) + new_A = new_A.at[..., effective_rank:].set(0.0) + return value.at[idx].set(new_A) if key_name == "lora_B": # Explicitly zero lora_B - if is_stacked: - return value.at[:, adapter_index].set(0.0) - return value.at[adapter_index].set(0.0) + return value.at[idx].set(0.0) return value updated_state = jax.tree.map_with_path(init_adapter, state) @@ -412,18 +399,12 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): state = nnx.state(model) def clear_adapter(path, value): - normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) - is_stacked = "layers" in normalized_path key = path[-2].key - if key == "lora_ranks": - if is_stacked: - return value.at[:, adapter_index].set(0) - return value.at[adapter_index].set(0) - if key in ("lora_scaling", "lora_A", "lora_B"): - if is_stacked: - return value.at[:, adapter_index].set(0.0) - return value.at[adapter_index].set(0.0) - return value + if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): + return value + normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) + idx = _adapter_index("layers" in normalized_path, adapter_index) + return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) nnx.update(model, updated_state) From 521734341484e630c27941745f22953235e1afec Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 29 Jan 2026 17:41:09 -0800 Subject: [PATCH 082/133] fix tests/utils --- skyrl-tx/tests/utils/test_generator.py | 5 +++-- skyrl-tx/tests/utils/test_models.py | 10 +++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 89bc637be..f4cbe3421 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -52,8 +52,9 @@ def __call__( if kv_cache is None: # Prefill: deterministic hidden_states (which equal logits) hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) - keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] - values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] + # Stacked format: (num_layers, batch, seq, heads, dim) - use 1 layer for this dummy model + keys = jnp.zeros((1, batch_size, seq_len, 1, 1), dtype=jnp.float32) + values = jnp.zeros((1, batch_size, seq_len, 1, 1), dtype=jnp.float32) # Per-sequence cache_position (all same length in this test) cache_position = ( attention_mask.sum(axis=1) if attention_mask is not None else jnp.full((batch_size,), seq_len) diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index 70c177fe3..2c74950af 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -55,14 +55,18 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat adapter_config = LoraConfig(rank=rank, alpha=alpha, seed=0) # Set LoRA weights to random values for testing (to catch transpose bugs) - q_proj = model.model.layers[0].self_attn.q_proj + # layers is now stacked, so access directly (not subscriptable) + # LoRA weights have shape (num_layers, num_adapters, ...) for stacked layers + q_proj = model.model.layers.self_attn.q_proj rng1, rng2 = jax.random.split(jax.random.PRNGKey(42)) q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) # Store expected values (trimmed to rank and transposed) - expected_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank].T) - expected_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :].T) + # For stacked layers: shape is (num_layers, num_adapters, in_dim, rank) for lora_A + # We have 1 layer, so index [0] for layer, then adapter_index + expected_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank].T) + expected_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :].T) # Save and verify checkpoint exists models.save_lora_checkpoint(model, base_model_name, adapter_config, adapter_index, output_path) From 6bf3cae1dcbe5ced87c72a3d8b77102dc90749a1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 10:12:48 -0800 Subject: [PATCH 083/133] Update tests and load_safetensors for stacked layer format - Add _is_stacked_layer_param helper to distinguish stacked vs non-stacked paths - Update load_safetensors/save_safetensors to handle both formats - Add num_layers argument to load_safetensors calls - Use Auto axis types in test mesh to avoid sharding errors - Update KV cache assertions for stacked array format Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_deepseekv3.py | 2 +- .../models/test_deepseekv3_lora_training.py | 4 +-- skyrl-tx/tests/models/test_models_common.py | 9 ++++-- skyrl-tx/tx/tinker/backends/jax.py | 2 +- skyrl-tx/tx/utils/models.py | 31 ++++++++++++++----- 5 files changed, 34 insertions(+), 14 deletions(-) diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py index 188917e12..1a33e0987 100644 --- a/skyrl-tx/tests/models/test_deepseekv3.py +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -53,7 +53,7 @@ def test_deepseekv3(tp: int): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model) + load_safetensors(tmp, config, model, config.num_hidden_layers) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py index ab1038d2b..bbb181c13 100644 --- a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -62,7 +62,7 @@ def test_lora_training_moe_rank_normalized(): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model) + load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) # Set different ranks for each adapter (0: rank 16, 1: rank 8) # For routed experts with 256 experts: effective rank = max(1, rank // 256) = 1 @@ -152,7 +152,7 @@ def test_lora_training_high_rank(): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model) + load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 1ae7fad95..1edcf5e6c 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -34,8 +34,10 @@ def create_model( """Create model with random weights for testing.""" base_config = AutoConfig.from_pretrained(model_name) config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, **config_kwargs) - mesh_kwargs = {"axis_types": mesh_axis_types} if mesh_axis_types else {} - mesh = jax.make_mesh((1, 1), mesh_axes, **mesh_kwargs) + # Default to Auto axis types to avoid sharding resolution errors + if mesh_axis_types is None: + mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes) + mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=mesh_axis_types) with jax.set_mesh(mesh): model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed)) return model, config @@ -140,7 +142,8 @@ def test_eval_mode_uses_standard_path( out = model(input_ids, attention_mask=attention_mask) # KV cache should be populated (checkpointed path returns empty) - assert len(out.kv_cache.keys) == config.num_hidden_layers + # keys is a stacked array with shape (num_layers, batch, seq, heads, dim) + assert out.kv_cache.keys.shape[0] == config.num_hidden_layers @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index bd1e16da7..3547b3ba2 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -185,7 +185,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ) with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True): self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, self.model_config, self.model) + load_safetensors(checkpoint_path, self.model_config, self.model, self.model.model.num_layers) # Split model into LoRA and non-LoRA parameters self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, self.model.is_lora_param, ...) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 54d89f58d..630a8628d 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -79,11 +79,26 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: def _is_layer_param(path: tuple) -> bool: - """Check if a parameter path corresponds to a stacked decoder layer weight.""" + """Check if a parameter path corresponds to a decoder layer weight.""" path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] return "layers" in path_strs +def _is_stacked_layer_param(path: tuple) -> bool: + """Check if a parameter path corresponds to a STACKED decoder layer weight. + + Stacked layers (Qwen3/Llama3) have paths like: ('model', 'layers', 'self_attn', ...) + Non-stacked layers (DeepSeekV3) have paths like: ('model', 'layers', '0', 'self_attn', ...) + """ + path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] + if "layers" not in path_strs: + return False + layers_idx = path_strs.index("layers") + if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit(): + return False # Non-stacked: path already contains layer index + return True # Stacked: no layer index in path + + def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: """Convert param path to HuggingFace key. If layer_idx provided, insert it after 'layers'.""" parts = [] @@ -153,7 +168,7 @@ def load_safetensors( tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} - num_experts = getattr(config, "num_experts", None) + num_experts = config.get_num_experts() model_params = nnx.to_flat_state(nnx.state(model)) updates = [] @@ -165,13 +180,14 @@ def load_safetensors( if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): continue - if _is_layer_param(path): - # Stack per-layer weights from HF format + if _is_stacked_layer_param(path): + # Stack per-layer weights from HF format (stacked layers like Qwen3/Llama3) stacked_tensor = np.empty(param.shape, dtype=param.dtype) for i in range(num_layers): key = _path_to_hf_key(path, layer_idx=i) stacked_tensor[i] = _load_hf_tensor(tensors, key, param.shape[1:], num_experts) else: + # Non-stacked layers (like DeepSeekV3) or non-layer params key = _path_to_hf_key(path) stacked_tensor = _load_hf_tensor(tensors, key, param.shape, num_experts) @@ -190,7 +206,7 @@ def save_safetensors( filter_fn: Callable[[tuple], bool] | None = None, ) -> None: """Save model weights to safetensors, unstacking layer weights for HF compatibility.""" - num_experts = getattr(config, "num_experts", None) + num_experts = config.get_num_experts() model_params = nnx.to_flat_state(nnx.state(model)) tensors = {} @@ -201,12 +217,13 @@ def save_safetensors( if filter_fn is not None and not filter_fn(path): continue - if _is_layer_param(path): - # Unstack and save as individual layer weights + if _is_stacked_layer_param(path): + # Unstack and save as individual layer weights (stacked layers like Qwen3/Llama3) for i in range(num_layers): key = prefix + _path_to_hf_key(path, layer_idx=i) _save_hf_tensor(tensors, key, param[i], num_experts) else: + # Non-stacked layers (like DeepSeekV3) or non-layer params key = prefix + _path_to_hf_key(path) _save_hf_tensor(tensors, key, param, num_experts) From 801458b477edd18d0f5d0b034d55959a2996bdca Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 10:24:04 -0800 Subject: [PATCH 084/133] Add workarounds for non-stacked DeepSeekV3 layers - Add KVCache.update() to stack list-based KV outputs from non-stacked models - Add _is_stacked_path() in lora.py to correctly index LoRA params These workarounds allow DeepSeekV3 to work with the new stacked layer format used by Qwen3/Llama3, without modifying the DeepSeekV3 model itself. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/layers/lora.py | 21 +++++++++++++++++++-- skyrl-tx/tx/utils/generator.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index cdc0e2cfa..2b6d76074 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -32,6 +32,23 @@ def _adapter_index(is_stacked: bool, adapter_index: int): return (slice(None), adapter_index) if is_stacked else (adapter_index,) +def _is_stacked_path(normalized_path: tuple[str | int, ...]) -> bool: + """Check if a parameter path corresponds to a STACKED decoder layer weight. + + Stacked layers (Qwen3/Llama3) have paths like: ('model', 'layers', 'self_attn', ...) + Non-stacked layers (DeepSeekV3) have paths like: ('model', 'layers', 0, 'self_attn', ...) + """ + if "layers" not in normalized_path: + return False + layers_idx = normalized_path.index("layers") + if layers_idx + 1 < len(normalized_path): + next_elem = normalized_path[layers_idx + 1] + # Check if next element is a layer index (int or numeric string) + if isinstance(next_elem, int) or (isinstance(next_elem, str) and next_elem.isdigit()): + return False # Non-stacked: path already contains layer index + return True # Stacked: no layer index in path + + class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. This mixin adds LoRA parameters (lora_A, lora_B) and methods to apply @@ -368,7 +385,7 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - idx = _adapter_index("layers" in normalized_path, adapter_index) + idx = _adapter_index(_is_stacked_path(normalized_path), adapter_index) key_name = path[-2].key if key_name == "lora_ranks": @@ -403,7 +420,7 @@ def clear_adapter(path, value): if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): return value normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) - idx = _adapter_index("layers" in normalized_path, adapter_index) + idx = _adapter_index(_is_stacked_path(normalized_path), adapter_index) return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index e7b176871..05a78a861 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -99,6 +99,34 @@ def pad_to_length(self, max_length: int) -> KVCache: cache_position=self.cache_position, ) + @staticmethod + def update( + kv_cache: KVCache | None, + keys: list[jax.Array], + values: list[jax.Array], + positions: jax.Array, + attention_mask: jax.Array, + ) -> KVCache: + """Create KVCache from list of per-layer outputs (for non-stacked models like DeepSeekV3). + + Args: + kv_cache: Existing KVCache (None during prefill). + keys: List of key arrays per layer. + values: List of value arrays per layer. + positions: Position indices with shape (batch, seq_len). + attention_mask: Attention mask with shape (batch, seq_len). + + Returns: + New KVCache with stacked keys/values and computed cache_position. + """ + stacked_keys = jnp.stack(keys, axis=0) + stacked_values = jnp.stack(values, axis=0) + if kv_cache is not None: + cache_position = kv_cache.cache_position + positions.shape[1] + else: + cache_position = attention_mask.sum(axis=1).astype(jnp.int32) + return KVCache(keys=stacked_keys, values=stacked_values, cache_position=cache_position) + @property def num_layers(self) -> int: """Number of layers in the cache.""" From e7bab9399bb4c8286bf02f38fd8584ca64e37d72 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 10:24:26 -0800 Subject: [PATCH 085/133] Revert "Add workarounds for non-stacked DeepSeekV3 layers" This reverts commit 801458b477edd18d0f5d0b034d55959a2996bdca. --- skyrl-tx/tx/layers/lora.py | 21 ++------------------- skyrl-tx/tx/utils/generator.py | 28 ---------------------------- 2 files changed, 2 insertions(+), 47 deletions(-) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 2b6d76074..cdc0e2cfa 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -32,23 +32,6 @@ def _adapter_index(is_stacked: bool, adapter_index: int): return (slice(None), adapter_index) if is_stacked else (adapter_index,) -def _is_stacked_path(normalized_path: tuple[str | int, ...]) -> bool: - """Check if a parameter path corresponds to a STACKED decoder layer weight. - - Stacked layers (Qwen3/Llama3) have paths like: ('model', 'layers', 'self_attn', ...) - Non-stacked layers (DeepSeekV3) have paths like: ('model', 'layers', 0, 'self_attn', ...) - """ - if "layers" not in normalized_path: - return False - layers_idx = normalized_path.index("layers") - if layers_idx + 1 < len(normalized_path): - next_elem = normalized_path[layers_idx + 1] - # Check if next element is a layer index (int or numeric string) - if isinstance(next_elem, int) or (isinstance(next_elem, str) and next_elem.isdigit()): - return False # Non-stacked: path already contains layer index - return True # Stacked: no layer index in path - - class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. This mixin adds LoRA parameters (lora_A, lora_B) and methods to apply @@ -385,7 +368,7 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - idx = _adapter_index(_is_stacked_path(normalized_path), adapter_index) + idx = _adapter_index("layers" in normalized_path, adapter_index) key_name = path[-2].key if key_name == "lora_ranks": @@ -420,7 +403,7 @@ def clear_adapter(path, value): if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): return value normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) - idx = _adapter_index(_is_stacked_path(normalized_path), adapter_index) + idx = _adapter_index("layers" in normalized_path, adapter_index) return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 05a78a861..e7b176871 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -99,34 +99,6 @@ def pad_to_length(self, max_length: int) -> KVCache: cache_position=self.cache_position, ) - @staticmethod - def update( - kv_cache: KVCache | None, - keys: list[jax.Array], - values: list[jax.Array], - positions: jax.Array, - attention_mask: jax.Array, - ) -> KVCache: - """Create KVCache from list of per-layer outputs (for non-stacked models like DeepSeekV3). - - Args: - kv_cache: Existing KVCache (None during prefill). - keys: List of key arrays per layer. - values: List of value arrays per layer. - positions: Position indices with shape (batch, seq_len). - attention_mask: Attention mask with shape (batch, seq_len). - - Returns: - New KVCache with stacked keys/values and computed cache_position. - """ - stacked_keys = jnp.stack(keys, axis=0) - stacked_values = jnp.stack(values, axis=0) - if kv_cache is not None: - cache_position = kv_cache.cache_position + positions.shape[1] - else: - cache_position = attention_mask.sum(axis=1).astype(jnp.int32) - return KVCache(keys=stacked_keys, values=stacked_values, cache_position=cache_position) - @property def num_layers(self) -> int: """Number of layers in the cache.""" From c18747de49c575f9ed641257da460dd93cdc26f6 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 11:15:37 -0800 Subject: [PATCH 086/133] Implement split stacked layers for DeepSeekV3 - Split DeepseekV3DecoderLayer into DenseDecoderLayer and MoEDecoderLayer - Use create_stacked_layers/forward_layers for both layer groups - Add _get_layer_group_info for HF weight loading with layer offsets - Update LoRA adapter indexing to handle dense_layers/moe_layers paths - Fix dtype preservation in MoE routing weights - Update tests for stacked adapter extraction This enables gradient checkpointing and unified forward pass for DeepSeekV3, matching the architecture used by Qwen3/Llama3. Co-Authored-By: Claude Opus 4.5 --- .../models/test_deepseekv3_lora_training.py | 32 ++++- skyrl-tx/tx/layers/lora.py | 6 +- skyrl-tx/tx/models/deepseekv3.py | 127 +++++++++++++++--- skyrl-tx/tx/utils/models.py | 65 +++++++-- 4 files changed, 188 insertions(+), 42 deletions(-) diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py index bbb181c13..054abc56a 100644 --- a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -27,6 +27,15 @@ def _is_routed_expert_path(path) -> bool: return False +def _is_stacked_path(path) -> bool: + """Check if path is for stacked layers (dense_layers or moe_layers).""" + for p in path: + key = p.key if hasattr(p, "key") else str(p) + if key in ("dense_layers", "moe_layers"): + return True + return False + + def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): """Extract out-of-rank params, using effective rank for routed expert layers.""" @@ -38,11 +47,18 @@ def slice_param(path, p): else: effective_rank = rank + # For stacked layers, adapter index is dim 1; for non-stacked, it's dim 0 + is_stacked = _is_stacked_path(path) + if "lora_A" in path_str: - # lora_A shape: [adapters, ..., max_rank] - slice last dim + # lora_A shape: [layers, adapters, ..., max_rank] (stacked) or [adapters, ..., max_rank] + if is_stacked: + return p[:, adapter_idx, ..., effective_rank:].copy() return p[adapter_idx, ..., effective_rank:].copy() elif "lora_B" in path_str: - # lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim + # lora_B shape: [layers, adapters, ..., max_rank, out] (stacked) or [adapters, ..., max_rank, out] + if is_stacked: + return p[:, adapter_idx, ..., effective_rank:, :].copy() return p[adapter_idx, ..., effective_rank:, :].copy() return p @@ -86,7 +102,11 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + def extract(path, p): + if _is_stacked_path(path): + return p[:, adapter_idx].copy() + return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) num_experts = config.n_routed_experts @@ -173,7 +193,11 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + def extract(path, p): + if _is_stacked_path(path): + return p[:, adapter_idx].copy() + return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) num_experts = config.n_routed_experts diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index cdc0e2cfa..259372baf 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -368,7 +368,8 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - idx = _adapter_index("layers" in normalized_path, adapter_index) + is_stacked = any(name in normalized_path for name in ("layers", "dense_layers", "moe_layers")) + idx = _adapter_index(is_stacked, adapter_index) key_name = path[-2].key if key_name == "lora_ranks": @@ -403,7 +404,8 @@ def clear_adapter(path, value): if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): return value normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) - idx = _adapter_index("layers" in normalized_path, adapter_index) + is_stacked = any(name in normalized_path for name in ("layers", "dense_layers", "moe_layers")) + idx = _adapter_index(is_stacked, adapter_index) return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index a2e48abdf..70d1f2649 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -9,6 +9,7 @@ from tx.layers.layernorm import RMSNorm from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput +from tx.models.utils import create_stacked_layers, forward_layers from tx.utils.generator import GeneratorMixin, KVCache from tx.utils.logits_processor import LogitsProcessorMixin, LMHead @@ -387,6 +388,8 @@ def __call__( router_logits = self.gate(hidden_states_flat) top_k_weights, top_k_index = self._compute_routing(router_logits) + # Cast routing weights to hidden_states dtype to preserve dtype through the forward pass + top_k_weights = top_k_weights.astype(hidden_states.dtype) expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat) shared_output = self.shared_experts( @@ -398,18 +401,13 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): + """Base decoder layer with shared attributes and forward pass.""" - def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) - # Use dense MLP for initial layers, MoE for the rest - if layer_idx >= config.first_k_dense_replace: - self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) - else: - self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) - def __call__( self, hidden_states: jax.Array, @@ -438,10 +436,30 @@ def __call__( return hidden_states, updated_cache +class DeepseekV3DenseDecoderLayer(DeepseekV3DecoderLayer): + """Dense decoder layer (uses MLP, no MoE).""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + super().__init__(config, dtype=dtype, rngs=rngs) + self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) + + +class DeepseekV3MoEDecoderLayer(DeepseekV3DecoderLayer): + """MoE decoder layer (uses sparse MoE instead of dense MLP).""" + + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + super().__init__(config, dtype=dtype, rngs=rngs) + self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) + + class DeepseekV3Model(nnx.Module): + training: bool = False def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_dense_layers = config.first_k_dense_replace + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -453,12 +471,23 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)), rngs=rngs, ) - self.layers = nnx.List( - [ - DeepseekV3DecoderLayer(config, layer_idx=i, dtype=dtype, rngs=rngs) - for i in range(config.num_hidden_layers) - ] - ) + + # Create stacked dense layers (layers 0 to first_k_dense_replace - 1) + if self.num_dense_layers > 0: + def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DenseDecoderLayer: + return DeepseekV3DenseDecoderLayer(config, dtype=dtype, rngs=rngs) + self.dense_layers = create_stacked_layers(create_dense_layer, self.num_dense_layers, rngs) + else: + self.dense_layers = None + + # Create stacked MoE layers (layers first_k_dense_replace to num_hidden_layers - 1) + if self.num_moe_layers > 0: + def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3MoEDecoderLayer: + return DeepseekV3MoEDecoderLayer(config, dtype=dtype, rngs=rngs) + self.moe_layers = create_stacked_layers(create_moe_layer, self.num_moe_layers, rngs) + else: + self.moe_layers = None + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -477,29 +506,77 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) + # Split KV cache for dense and MoE layers + dense_kv_cache = None + moe_kv_cache = None + if kv_cache is not None: + if self.num_dense_layers > 0: + dense_kv_cache = KVCache( + keys=kv_cache.keys[:self.num_dense_layers], + values=kv_cache.values[:self.num_dense_layers], + cache_position=kv_cache.cache_position, + ) + if self.num_moe_layers > 0: + moe_kv_cache = KVCache( + keys=kv_cache.keys[self.num_dense_layers:], + values=kv_cache.values[self.num_dense_layers:], + cache_position=kv_cache.cache_position, + ) + + # Forward through dense layers + dense_kv_result = None + if self.dense_layers is not None: + hidden_states, dense_hidden_states, dense_kv_result = forward_layers( + self.dense_layers, + hidden_states, + self.num_dense_layers, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=dense_kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + ) + all_hidden_states.extend(dense_hidden_states) - hidden_states, (k, v) = layer( + # Forward through MoE layers + moe_kv_result = None + if self.moe_layers is not None: + hidden_states, moe_hidden_states, moe_kv_result = forward_layers( + self.moe_layers, hidden_states, + self.num_moe_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), + kv_cache=moe_kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, ) - updated_keys.append(k) - updated_values.append(v) + all_hidden_states.extend(moe_hidden_states) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states.append(hidden_states) + # Merge KV caches from dense and MoE layers + if dense_kv_result is not None and moe_kv_result is not None: + new_kv_cache = KVCache( + keys=jnp.concatenate([dense_kv_result.keys, moe_kv_result.keys], axis=0), + values=jnp.concatenate([dense_kv_result.values, moe_kv_result.values], axis=0), + cache_position=moe_kv_result.cache_position, + ) + elif dense_kv_result is not None: + new_kv_cache = dense_kv_result + elif moe_kv_result is not None: + new_kv_cache = moe_kv_result + else: + new_kv_cache = None + return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -527,6 +604,12 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head + def train(self, **attributes): + return super().train(training=True, **attributes) + + def eval(self, **attributes): + return super().eval(training=False, **attributes) + @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 630a8628d..e57173d51 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -87,10 +87,18 @@ def _is_layer_param(path: tuple) -> bool: def _is_stacked_layer_param(path: tuple) -> bool: """Check if a parameter path corresponds to a STACKED decoder layer weight. - Stacked layers (Qwen3/Llama3) have paths like: ('model', 'layers', 'self_attn', ...) - Non-stacked layers (DeepSeekV3) have paths like: ('model', 'layers', '0', 'self_attn', ...) + Stacked layers have paths like: + - Qwen3/Llama3: ('model', 'layers', 'self_attn', ...) + - DeepSeekV3 dense: ('model', 'dense_layers', 'self_attn', ...) + - DeepSeekV3 MoE: ('model', 'moe_layers', 'self_attn', ...) + + Non-stacked layers have paths like: ('model', 'layers', '0', 'self_attn', ...) """ path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] + # Check for split stacked layer names (DeepSeekV3) + if "dense_layers" in path_strs or "moe_layers" in path_strs: + return True + # Check for regular stacked layers (Qwen3/Llama3) if "layers" not in path_strs: return False layers_idx = path_strs.index("layers") @@ -99,12 +107,33 @@ def _is_stacked_layer_param(path: tuple) -> bool: return True # Stacked: no layer index in path +def _get_layer_group_info(path: tuple, config: ModelConfig) -> tuple[str, int]: + """Get layer group name and starting layer index for a stacked param path. + + Returns: + Tuple of (layer_name_for_hf_key, layer_offset) where: + - layer_name_for_hf_key is 'layers' (HF always uses 'layers') + - layer_offset is the starting layer index for this group + """ + path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] + if "dense_layers" in path_strs: + return "layers", 0 + elif "moe_layers" in path_strs: + return "layers", getattr(config, "first_k_dense_replace", 0) + else: + return "layers", 0 + + def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: - """Convert param path to HuggingFace key. If layer_idx provided, insert it after 'layers'.""" + """Convert param path to HuggingFace key. If layer_idx provided, insert it after 'layers'. + + Handles split stacked layer names (dense_layers, moe_layers) by converting them to 'layers'. + """ parts = [] for p in path: key = p.key if hasattr(p, "key") else str(p) - if key == "layers" and layer_idx is not None: + # Handle split stacked layer names - convert to 'layers' with index + if key in ("layers", "dense_layers", "moe_layers") and layer_idx is not None: parts.append(f"layers.{layer_idx}") elif key in ("kernel", "embedding"): parts.append("weight") @@ -181,13 +210,16 @@ def load_safetensors( continue if _is_stacked_layer_param(path): - # Stack per-layer weights from HF format (stacked layers like Qwen3/Llama3) + # Stack per-layer weights from HF format + # Infer layer count from param shape and get offset for split stacked layers + stacked_layer_count = param.shape[0] + _, layer_offset = _get_layer_group_info(path, config) stacked_tensor = np.empty(param.shape, dtype=param.dtype) - for i in range(num_layers): - key = _path_to_hf_key(path, layer_idx=i) + for i in range(stacked_layer_count): + key = _path_to_hf_key(path, layer_idx=layer_offset + i) stacked_tensor[i] = _load_hf_tensor(tensors, key, param.shape[1:], num_experts) else: - # Non-stacked layers (like DeepSeekV3) or non-layer params + # Non-stacked layers or non-layer params key = _path_to_hf_key(path) stacked_tensor = _load_hf_tensor(tensors, key, param.shape, num_experts) @@ -218,12 +250,15 @@ def save_safetensors( continue if _is_stacked_layer_param(path): - # Unstack and save as individual layer weights (stacked layers like Qwen3/Llama3) - for i in range(num_layers): - key = prefix + _path_to_hf_key(path, layer_idx=i) + # Unstack and save as individual layer weights + # Infer layer count from param shape and get offset for split stacked layers + stacked_layer_count = param.shape[0] + _, layer_offset = _get_layer_group_info(path, config) + for i in range(stacked_layer_count): + key = prefix + _path_to_hf_key(path, layer_idx=layer_offset + i) _save_hf_tensor(tensors, key, param[i], num_experts) else: - # Non-stacked layers (like DeepSeekV3) or non-layer params + # Non-stacked layers or non-layer params key = prefix + _path_to_hf_key(path) _save_hf_tensor(tensors, key, param, num_experts) @@ -336,7 +371,8 @@ def extract_state(path: tuple, p: jnp.ndarray): # - 5D: Stacked expert (L, A, E, in, R) # We extract adapter_index from the adapter dimension (axis 1 for stacked, axis 0 otherwise) assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - is_stacked = "layers" in [pk.key if hasattr(pk, "key") else str(pk) for pk in path] + path_strs = [pk.key if hasattr(pk, "key") else str(pk) for pk in path] + is_stacked = any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) if path[-2].key == "lora_A": if is_stacked: # (L, A, ..., R) return p[:, adapter_index, ..., :rank] @@ -364,7 +400,8 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): return new # See extract_adapter_state for shape documentation assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - is_stacked = "layers" in [pk.key if hasattr(pk, "key") else str(pk) for pk in path] + path_strs = [pk.key if hasattr(pk, "key") else str(pk) for pk in path] + is_stacked = any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) if path[-2].key == "lora_A": if is_stacked: # (L, A, ..., R) return p.at[:, adapter_index, ..., :rank].set(new) From 650c926eb1f8adf5a5f28fd676c81c2b9df5cc48 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 11:18:38 -0800 Subject: [PATCH 087/133] Remove unused train/eval methods from all models These methods were added to distinguish training/inference paths but are no longer needed with the unified forward_layers approach. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/deepseekv3.py | 7 ------- skyrl-tx/tx/models/llama3.py | 7 ------- skyrl-tx/tx/models/qwen3.py | 7 ------- 3 files changed, 21 deletions(-) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 70d1f2649..524db56bd 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -453,7 +453,6 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs class DeepseekV3Model(nnx.Module): - training: bool = False def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config @@ -604,12 +603,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def train(self, **attributes): - return super().train(training=True, **attributes) - - def eval(self, **attributes): - return super().eval(training=False, **attributes) - @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 01ed8ee69..4c9d8c9d2 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -190,7 +190,6 @@ def __call__( class Llama3Model(nnx.Module): - training: bool = False def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config @@ -277,12 +276,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def train(self, **attributes): - return super().train(training=True, **attributes) - - def eval(self, **attributes): - return super().eval(training=False, **attributes) - @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 03914e668..5be6fb0f1 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -305,7 +305,6 @@ def __call__( class Qwen3Model(nnx.Module): - training: bool = False def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config @@ -392,12 +391,6 @@ def get_lm_head(self) -> LMHead: """Return the lm_head callable for logits computation.""" return self.lm_head - def train(self, **attributes): - return super().train(training=True, **attributes) - - def eval(self, **attributes): - return super().eval(training=False, **attributes) - @staticmethod def is_lora_param(path: tuple, _value) -> bool: """Return True if a parameter path corresponds to LoRA weights.""" From c669b34dbfc51ffef75fb1d0cda3df6a840e81cd Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 11:20:50 -0800 Subject: [PATCH 088/133] Remove .train()/.eval() calls no longer needed Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 2 -- skyrl-tx/tx/tinker/backends/jax.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 1edcf5e6c..12d9854b6 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -80,7 +80,6 @@ def _forward( model, config = create_model(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing) input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - model.train() out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) return model, config, out @@ -138,7 +137,6 @@ def test_eval_mode_uses_standard_path( input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) - model.eval() out = model(input_ids, attention_mask=attention_mask) # KV cache should be populated (checkpointed path returns empty) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 3547b3ba2..d07397f59 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -246,7 +246,7 @@ def _model_forward( target_ids: jax.Array, ) -> jax.Array: """Forward pass and logprobs computation.""" - model = nnx.merge(graphdef, lora_params, non_lora_params).train() + model = nnx.merge(graphdef, lora_params, non_lora_params) output = model( input_ids, attention_mask=attention_mask, From 68f82dfe71ff01e305ac31bf751bcf323040ba6c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 11:37:44 -0800 Subject: [PATCH 089/133] Fix outdated test name and improve dtype cast comment - Rename test_eval_mode_uses_standard_path to test_kv_cache_with_checkpointing - Clarify dtype cast comment in DeepSeekV3 MoE routing Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/test_models_common.py | 6 ++---- skyrl-tx/tx/models/deepseekv3.py | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 12d9854b6..590d0ecbb 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -55,7 +55,6 @@ def load_model( """Load model from pre-saved weights directory.""" model, config = create_model( model_name, config_cls, model_cls, mesh_axes, - mesh_axis_types=(jax.sharding.AxisType.Auto,) * 2, loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, ) @@ -122,14 +121,14 @@ def test_hidden_states_length_matches( for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") - def test_eval_mode_uses_standard_path( + def test_kv_cache_with_checkpointing( self, model_name: str, config_cls: type[ModelConfig], model_cls: type[ModelForCausalLM], mesh_axes: tuple[str, str], ) -> None: - """eval() mode should use standard path with KV cache support.""" + """KV cache should be populated even with gradient checkpointing enabled.""" model, config = create_model(model_name, config_cls, model_cls, mesh_axes) config.gradient_checkpointing = True @@ -139,7 +138,6 @@ def test_eval_mode_uses_standard_path( out = model(input_ids, attention_mask=attention_mask) - # KV cache should be populated (checkpointed path returns empty) # keys is a stacked array with shape (num_layers, batch, seq, heads, dim) assert out.kv_cache.keys.shape[0] == config.num_hidden_layers diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 524db56bd..bad6b0d47 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -388,7 +388,8 @@ def __call__( router_logits = self.gate(hidden_states_flat) top_k_weights, top_k_index = self._compute_routing(router_logits) - # Cast routing weights to hidden_states dtype to preserve dtype through the forward pass + # _compute_routing uses float32 for softmax stability; cast back to model dtype + # to maintain consistent dtypes through jax.lax.scan in forward_layers top_k_weights = top_k_weights.astype(hidden_states.dtype) expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat) From 301f7dcbd7ac6e478a5e5fb09176f36a7171cbf8 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 11:46:03 -0800 Subject: [PATCH 090/133] Refactor: remove unused code and consolidate stacked path utilities 1. Remove unused _is_layer_param function from tx/utils/models.py 2. Remove unused num_layers parameter from load_safetensors/save_safetensors 3. Add is_stacked_lora_path() shared utility for LoRA adapter indexing 4. Create tests/models/lora_test_utils.py with shared test helpers: - get_adapter_params, get_out_of_rank_params, verify_params_unchanged - get_moe_out_of_rank_params for MoE-specific rank handling 5. Update all test files to use shared utilities Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/models/lora_test_utils.py | 83 ++++++++++++++ skyrl-tx/tests/models/test_deepseekv3.py | 2 +- .../models/test_deepseekv3_lora_training.py | 103 +++--------------- skyrl-tx/tests/models/test_llama3.py | 2 +- .../tests/models/test_llama3_lora_training.py | 40 +------ skyrl-tx/tests/models/test_models_common.py | 2 +- skyrl-tx/tests/models/test_qwen3.py | 4 +- skyrl-tx/tests/models/test_qwen3_generate.py | 4 +- .../tests/models/test_qwen3_lora_training.py | 40 +------ skyrl-tx/tx/layers/lora.py | 9 +- skyrl-tx/tx/tinker/backends/jax.py | 2 +- skyrl-tx/tx/utils/models.py | 26 +++-- 12 files changed, 130 insertions(+), 187 deletions(-) create mode 100644 skyrl-tx/tests/models/lora_test_utils.py diff --git a/skyrl-tx/tests/models/lora_test_utils.py b/skyrl-tx/tests/models/lora_test_utils.py new file mode 100644 index 000000000..24b506d0d --- /dev/null +++ b/skyrl-tx/tests/models/lora_test_utils.py @@ -0,0 +1,83 @@ +"""Shared test utilities for LoRA training tests.""" + +import jax +import jax.numpy as jnp + +from tx.utils.models import is_stacked_lora_path + + +def get_adapter_params(params, adapter_idx: int): + """Extract adapter params at a specific index. + + Decoder layer LoRA params have shape (num_layers, num_adapters, ...). + Embed tokens LoRA params have shape (num_adapters, ...). + """ + def extract(path, p): + if is_stacked_lora_path(path): + return p[:, adapter_idx].copy() + return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) + + +def get_out_of_rank_params(params, adapter_idx: int, rank: int): + """Extract out-of-rank params for an adapter. + + Returns the portion of LoRA weights beyond the effective rank, + which should remain unchanged during training. + """ + def slice_param(path, p): + path_str = str(path) + is_stacked = is_stacked_lora_path(path) + if "lora_A" in path_str: + if is_stacked: + return p[:, adapter_idx, ..., rank:].copy() + return p[adapter_idx, ..., rank:].copy() + elif "lora_B" in path_str: + if is_stacked: + return p[:, adapter_idx, ..., rank:, :].copy() + return p[adapter_idx, ..., rank:, :].copy() + return p + return jax.tree.map_with_path(slice_param, params) + + +def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str): + """Verify that params haven't changed between initial and final state.""" + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + +def is_routed_expert_path(path) -> bool: + """Check if path is for routed experts (not shared_experts).""" + keys = [] + for p in path: + if hasattr(p, "key"): + keys.append(str(p.key)) + elif hasattr(p, "name"): + keys.append(str(p.name)) + for i, key in enumerate(keys): + if key == "experts" and i > 0 and keys[i - 1] == "mlp": + return True + return False + + +def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): + """Extract out-of-rank params for MoE models. + + For routed experts, uses effective rank = max(1, rank // num_experts). + """ + def slice_param(path, p): + path_str = str(path) + effective_rank = max(1, rank // num_experts) if is_routed_expert_path(path) else rank + is_stacked = is_stacked_lora_path(path) + if "lora_A" in path_str: + if is_stacked: + return p[:, adapter_idx, ..., effective_rank:].copy() + return p[adapter_idx, ..., effective_rank:].copy() + elif "lora_B" in path_str: + if is_stacked: + return p[:, adapter_idx, ..., effective_rank:, :].copy() + return p[adapter_idx, ..., effective_rank:, :].copy() + return p + return jax.tree.map_with_path(slice_param, params) diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py index 1a33e0987..188917e12 100644 --- a/skyrl-tx/tests/models/test_deepseekv3.py +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -53,7 +53,7 @@ def test_deepseekv3(tp: int): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model, config.num_hidden_layers) + load_safetensors(tmp, config, model) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py index 054abc56a..3ff2b7510 100644 --- a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -11,58 +11,11 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig - -def _is_routed_expert_path(path) -> bool: - """Disambiguate shared_experts and experts""" - keys = [] - for p in path: - if hasattr(p, "key"): - keys.append(str(p.key)) - elif hasattr(p, "name"): - keys.append(str(p.name)) - - for i, key in enumerate(keys): - if key == "experts" and i > 0 and keys[i - 1] == "mlp": - return True - return False - - -def _is_stacked_path(path) -> bool: - """Check if path is for stacked layers (dense_layers or moe_layers).""" - for p in path: - key = p.key if hasattr(p, "key") else str(p) - if key in ("dense_layers", "moe_layers"): - return True - return False - - -def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): - """Extract out-of-rank params, using effective rank for routed expert layers.""" - - def slice_param(path, p): - path_str = str(path) - - if _is_routed_expert_path(path): - effective_rank = max(1, rank // num_experts) - else: - effective_rank = rank - - # For stacked layers, adapter index is dim 1; for non-stacked, it's dim 0 - is_stacked = _is_stacked_path(path) - - if "lora_A" in path_str: - # lora_A shape: [layers, adapters, ..., max_rank] (stacked) or [adapters, ..., max_rank] - if is_stacked: - return p[:, adapter_idx, ..., effective_rank:].copy() - return p[adapter_idx, ..., effective_rank:].copy() - elif "lora_B" in path_str: - # lora_B shape: [layers, adapters, ..., max_rank, out] (stacked) or [adapters, ..., max_rank, out] - if is_stacked: - return p[:, adapter_idx, ..., effective_rank:, :].copy() - return p[adapter_idx, ..., effective_rank:, :].copy() - return p - - return jax.tree.map_with_path(slice_param, params) +from tests.models.lora_test_utils import ( + get_adapter_params, + get_moe_out_of_rank_params, + verify_params_unchanged, +) def test_lora_training_moe_rank_normalized(): @@ -78,7 +31,7 @@ def test_lora_training_moe_rank_normalized(): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) + load_safetensors(checkpoint_path, config, model) # Set different ranks for each adapter (0: rank 16, 1: rank 8) # For routed experts with 256 experts: effective rank = max(1, rank // 256) = 1 @@ -101,19 +54,12 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) - def get_adapter_params(params, adapter_idx): - def extract(path, p): - if _is_stacked_path(path): - return p[:, adapter_idx].copy() - return p[adapter_idx].copy() - return jax.tree.map_with_path(extract, params) - num_experts = config.n_routed_experts # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) - initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) - initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) initial_loss = None @@ -136,12 +82,6 @@ def loss_for_lora(lora_params): final_loss = float(loss) - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - assert final_loss < initial_loss, f"Loss did not decrease: {initial_loss} -> {final_loss}" # Verify unused adapter was not modified @@ -149,11 +89,11 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix): verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") # Verify out-of-rank params were not modified - final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) verify_params_unchanged( initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" ) - final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) verify_params_unchanged( initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" ) @@ -172,7 +112,7 @@ def test_lora_training_high_rank(): ) with jax.set_mesh(mesh): model = DeepseekV3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) + load_safetensors(checkpoint_path, config, model) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) @@ -192,13 +132,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) - def get_adapter_params(params, adapter_idx): - def extract(path, p): - if _is_stacked_path(path): - return p[:, adapter_idx].copy() - return p[adapter_idx].copy() - return jax.tree.map_with_path(extract, params) - num_experts = config.n_routed_experts # Save initial states for all unused adapters @@ -207,8 +140,8 @@ def extract(path, p): initial_adapter_4_params = get_adapter_params(lora_params, 4) # Save out-of-rank params for adapters 0 and 1 - initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) - initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) # Training loop for step in range(10): @@ -224,12 +157,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify unused adapters (2, 3, 4) were not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") @@ -241,11 +168,11 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix): verify_params_unchanged(initial_adapter_4_params, final_adapter_4_params, "Adapter 4 was modified") # Verify out-of-rank params were not modified - final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) verify_params_unchanged( initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" ) - final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) verify_params_unchanged( initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" ) diff --git a/skyrl-tx/tests/models/test_llama3.py b/skyrl-tx/tests/models/test_llama3.py index 7913839c5..fa195567f 100644 --- a/skyrl-tx/tests/models/test_llama3.py +++ b/skyrl-tx/tests/models/test_llama3.py @@ -42,7 +42,7 @@ def test_llama3(tp: int): mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model, config.num_hidden_layers) + load_safetensors(tmp, config, model) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index aba69a728..a04fa5f60 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -11,6 +11,8 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig +from tests.models.lora_test_utils import get_adapter_params, get_out_of_rank_params, verify_params_unchanged + def test_lora_training(): base_model = "unsloth/Llama-3.2-1B" @@ -21,7 +23,7 @@ def test_lora_training(): mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) + load_safetensors(checkpoint_path, config, model) # Set different ranks for each adapter (0: rank 16, 1: rank 8) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) @@ -45,36 +47,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): # that we want to compute gradients for graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) - # Helper to extract adapter params at specific index - # Decoder layer LoRA params have shape (num_layers, num_adapters, ...) - # Embed tokens LoRA params have shape (num_adapters, ...) - def get_adapter_params(params, adapter_idx): - def extract(path, p): - path_str = str(path) - if "layers" in path_str: - return p[:, adapter_idx].copy() # Keep layer dimension - else: - return p[adapter_idx].copy() - return jax.tree.map_with_path(extract, params) - - # Helper to extract out-of-rank params for an adapter - def get_out_of_rank_params(params, adapter_idx, rank): - def slice_param(path, p): - path_str = str(path) - is_stacked = "layers" in path_str - if "lora_A" in path_str: - if is_stacked: - return p[:, adapter_idx, :, rank:].copy() - else: - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in path_str: - if is_stacked: - return p[:, adapter_idx, rank:, :].copy() - else: - return p[adapter_idx, rank:, :].copy() - return p - return jax.tree.map_with_path(slice_param, params) - # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) initial_adapter_0_out_of_rank = get_out_of_rank_params(lora_params, 0, 16) @@ -94,12 +66,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify adapter 2 (unused) was not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 590d0ecbb..612df15c2 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -58,7 +58,7 @@ def load_model( loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, ) - load_safetensors(tmp_dir, config, model, config.num_hidden_layers) + load_safetensors(tmp_dir, config, model) return model diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 9e5fc9f95..dcf2680b9 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -43,7 +43,7 @@ def test_qwen3(tp: int): mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model, config.num_hidden_layers) + load_safetensors(tmp, config, model) outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) assert outputs.hidden_states is not None @@ -240,7 +240,7 @@ def test_qwen3_lora(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(base_tmp, config, model, config.num_hidden_layers) + load_safetensors(base_tmp, config, model) # Get outputs from all HF models hf_outputs_list = [] diff --git a/skyrl-tx/tests/models/test_qwen3_generate.py b/skyrl-tx/tests/models/test_qwen3_generate.py index 7579d823d..8b950d535 100644 --- a/skyrl-tx/tests/models/test_qwen3_generate.py +++ b/skyrl-tx/tests/models/test_qwen3_generate.py @@ -49,7 +49,7 @@ def test_qwen3_generate(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model, config.num_hidden_layers) + load_safetensors(tmp, config, model) sampling_params = [ types.SamplingParams(max_tokens=10, temperature=0.0, seed=42), @@ -149,7 +149,7 @@ def test_qwen3_generate_speed(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=jnp.bfloat16, rngs=nnx.Rngs(0)) - load_safetensors(tmp, config, model, config.num_hidden_layers) + load_safetensors(tmp, config, model) sampling_params = [types.SamplingParams(max_tokens=50, temperature=0.0, seed=42) for i in range(len(inputs))] # Warmup diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index c757c18f6..a5873f506 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -11,6 +11,8 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig +from tests.models.lora_test_utils import get_adapter_params, get_out_of_rank_params, verify_params_unchanged + def test_lora_training(): base_model = "Qwen/Qwen3-0.6B" @@ -21,7 +23,7 @@ def test_lora_training(): mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): model = Qwen3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, config, model, config.num_hidden_layers) + load_safetensors(checkpoint_path, config, model) # Set different ranks for each adapter (0: rank 16, 1: rank 8) init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) @@ -45,36 +47,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): # that we want to compute gradients for graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) - # Helper to extract adapter params at specific index - # Decoder layer LoRA params have shape (num_layers, num_adapters, ...) - # Embed tokens LoRA params have shape (num_adapters, ...) - def get_adapter_params(params, adapter_idx): - def extract(path, p): - path_str = str(path) - if "layers" in path_str: - return p[:, adapter_idx].copy() # Keep layer dimension - else: - return p[adapter_idx].copy() - return jax.tree.map_with_path(extract, params) - - # Helper to extract out-of-rank params for an adapter - def get_out_of_rank_params(params, adapter_idx, rank): - def slice_param(path, p): - path_str = str(path) - is_stacked = "layers" in path_str - if "lora_A" in path_str: - if is_stacked: - return p[:, adapter_idx, :, rank:].copy() - else: - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in path_str: - if is_stacked: - return p[:, adapter_idx, rank:, :].copy() - else: - return p[adapter_idx, rank:, :].copy() - return p - return jax.tree.map_with_path(slice_param, params) - # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) initial_adapter_0_out_of_rank = get_out_of_rank_params(lora_params, 0, 16) @@ -94,12 +66,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify adapter 2 (unused) was not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 259372baf..574ddb99f 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -3,7 +3,7 @@ from jax import numpy as jnp from jax.core import Tracer -from tx.utils.models import filter_lora +from tx.utils.models import filter_lora, is_stacked_lora_path from tx.layers.util import Param, prepare_routing, ragged_dot from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig @@ -368,8 +368,7 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - is_stacked = any(name in normalized_path for name in ("layers", "dense_layers", "moe_layers")) - idx = _adapter_index(is_stacked, adapter_index) + idx = _adapter_index(is_stacked_lora_path(path), adapter_index) key_name = path[-2].key if key_name == "lora_ranks": @@ -403,9 +402,7 @@ def clear_adapter(path, value): key = path[-2].key if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): return value - normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) - is_stacked = any(name in normalized_path for name in ("layers", "dense_layers", "moe_layers")) - idx = _adapter_index(is_stacked, adapter_index) + idx = _adapter_index(is_stacked_lora_path(path), adapter_index) return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index d07397f59..7287f7a1d 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -185,7 +185,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ) with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True): self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, self.model_config, self.model, self.model.model.num_layers) + load_safetensors(checkpoint_path, self.model_config, self.model) # Split model into LoRA and non-LoRA parameters self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, self.model.is_lora_param, ...) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index e57173d51..2df833988 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -78,10 +78,20 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") -def _is_layer_param(path: tuple) -> bool: - """Check if a parameter path corresponds to a decoder layer weight.""" +def is_stacked_lora_path(path: tuple) -> bool: + """Check if a parameter path is for stacked layer weights (for LoRA indexing). + + Stacked layer params have the adapter dimension at axis 1: (num_layers, num_adapters, ...). + Non-stacked params (e.g., embed_tokens) have adapter dimension at axis 0: (num_adapters, ...). + + Args: + path: Parameter path tuple (can be nnx path objects or strings). + + Returns: + True if the path contains 'layers', 'dense_layers', or 'moe_layers'. + """ path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] - return "layers" in path_strs + return any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) def _is_stacked_layer_param(path: tuple) -> bool: @@ -186,7 +196,6 @@ def load_safetensors( checkpoint_dir: str | os.PathLike, config: ModelConfig, model: nnx.Module, - num_layers: int, skip_lora: bool = True, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, @@ -233,7 +242,6 @@ def save_safetensors( config: ModelConfig, model: nnx.Module, filename: Path, - num_layers: int, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: @@ -301,7 +309,6 @@ def load_lora_checkpoint( temp_dir, model.config, adapter_lora_params, - model.model.num_layers, skip_lora=False, prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), @@ -337,7 +344,6 @@ def save_lora_checkpoint( model.config, adapter_lora_params, temp_dir / "adapter_model.safetensors", - model.model.num_layers, prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), ) @@ -371,8 +377,7 @@ def extract_state(path: tuple, p: jnp.ndarray): # - 5D: Stacked expert (L, A, E, in, R) # We extract adapter_index from the adapter dimension (axis 1 for stacked, axis 0 otherwise) assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - path_strs = [pk.key if hasattr(pk, "key") else str(pk) for pk in path] - is_stacked = any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) + is_stacked = is_stacked_lora_path(path) if path[-2].key == "lora_A": if is_stacked: # (L, A, ..., R) return p[:, adapter_index, ..., :rank] @@ -400,8 +405,7 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): return new # See extract_adapter_state for shape documentation assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - path_strs = [pk.key if hasattr(pk, "key") else str(pk) for pk in path] - is_stacked = any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) + is_stacked = is_stacked_lora_path(path) if path[-2].key == "lora_A": if is_stacked: # (L, A, ..., R) return p.at[:, adapter_index, ..., :rank].set(new) From 6abe6e7c7d77051b72610a4f7d5dbf09822b2e7a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 12:09:45 -0800 Subject: [PATCH 091/133] Fix tinker tests for stacked layer access Update test_jax_backend.py to use stacked layer indexing: - layers.self_attn.q_proj instead of layers[0].self_attn.q_proj - Access adapter params with [layer_idx, adapter_idx] instead of [adapter_idx] Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_jax_backend.py | 28 ++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 3543c7378..5ffa5d60c 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -108,17 +108,18 @@ def test_clear_lora_adapter(): # Verify adapter has non-zero rank after creation model = backend.model - lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj - assert lora_layer.lora_ranks[adapter_idx] > 0 + # With stacked layers, lora_ranks has shape (num_layers, num_adapters) + lora_layer: LoRALinear = model.model.layers.self_attn.q_proj + assert lora_layer.lora_ranks[0, adapter_idx] > 0 # Delete the model (calls clear_lora_adapter internally) backend.delete_model(model_id) - # Verify adapter state is zeroed - assert lora_layer.lora_ranks[adapter_idx] == 0 - assert lora_layer.lora_scaling[adapter_idx] == 0.0 - assert (lora_layer.lora_A[adapter_idx] == 0.0).all() - assert (lora_layer.lora_B[adapter_idx] == 0.0).all() + # Verify adapter state is zeroed (check layer 0) + assert lora_layer.lora_ranks[0, adapter_idx] == 0 + assert lora_layer.lora_scaling[0, adapter_idx] == 0.0 + assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all() + assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all() def make_fwd_bwd_input(token_lists: list[list[int]]) -> types.ForwardBackwardInput: @@ -534,20 +535,21 @@ def test_adapter_reuse_initializes_lora_adapter(): # (slot 0 is reserved for base model) backend = create_backend(max_lora_adapters=2) model = backend.model - lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj + # With stacked layers, lora_A has shape (num_layers, num_adapters, in_features, max_rank) + lora_layer: LoRALinear = model.model.layers.self_attn.q_proj # Create first model model_id_1 = "model_1" adapter_idx = create_model(backend, model_id_1) - # Verify lora_A is non-zero after creation + # Verify lora_A is non-zero after creation (check layer 0) assert not ( - lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform (non-zero)" # Delete the model (clears both lora_A and lora_B to zeros) backend.delete_model(model_id_1) - assert (lora_layer.lora_A[adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" + assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" # Create a new model that reuses the same adapter slot model_id_2 = "model_2" @@ -556,11 +558,11 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_A is initialized (non-zero) assert not ( - lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform after adapter reuse" # Verify lora_B is zeros - assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" + assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all(), "lora_B should be zeros" def test_mixed_train_unembed_adapters(): From 4cdd7dc754bd49afaf32e222fdb3f52ba05652d3 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 12:12:49 -0800 Subject: [PATCH 092/133] lint --- skyrl-tx/tests/models/lora_test_utils.py | 6 ++++++ skyrl-tx/tests/models/test_models_common.py | 21 ++++++++++++++++----- skyrl-tx/tx/layers/lora.py | 12 ++++++------ skyrl-tx/tx/models/deepseekv3.py | 12 ++++++++---- skyrl-tx/tx/utils/models.py | 6 ++---- 5 files changed, 38 insertions(+), 19 deletions(-) diff --git a/skyrl-tx/tests/models/lora_test_utils.py b/skyrl-tx/tests/models/lora_test_utils.py index 24b506d0d..507b5d9c6 100644 --- a/skyrl-tx/tests/models/lora_test_utils.py +++ b/skyrl-tx/tests/models/lora_test_utils.py @@ -12,10 +12,12 @@ def get_adapter_params(params, adapter_idx: int): Decoder layer LoRA params have shape (num_layers, num_adapters, ...). Embed tokens LoRA params have shape (num_adapters, ...). """ + def extract(path, p): if is_stacked_lora_path(path): return p[:, adapter_idx].copy() return p[adapter_idx].copy() + return jax.tree.map_with_path(extract, params) @@ -25,6 +27,7 @@ def get_out_of_rank_params(params, adapter_idx: int, rank: int): Returns the portion of LoRA weights beyond the effective rank, which should remain unchanged during training. """ + def slice_param(path, p): path_str = str(path) is_stacked = is_stacked_lora_path(path) @@ -37,6 +40,7 @@ def slice_param(path, p): return p[:, adapter_idx, ..., rank:, :].copy() return p[adapter_idx, ..., rank:, :].copy() return p + return jax.tree.map_with_path(slice_param, params) @@ -67,6 +71,7 @@ def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: For routed experts, uses effective rank = max(1, rank // num_experts). """ + def slice_param(path, p): path_str = str(path) effective_rank = max(1, rank // num_experts) if is_routed_expert_path(path) else rank @@ -80,4 +85,5 @@ def slice_param(path, p): return p[:, adapter_idx, ..., effective_rank:, :].copy() return p[adapter_idx, ..., effective_rank:, :].copy() return p + return jax.tree.map_with_path(slice_param, params) diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 612df15c2..53d2db389 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -54,7 +54,10 @@ def load_model( ) -> ModelForCausalLM: """Load model from pre-saved weights directory.""" model, config = create_model( - model_name, config_cls, model_cls, mesh_axes, + model_name, + config_cls, + model_cls, + mesh_axes, loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, ) @@ -76,7 +79,9 @@ def _forward( ) -> tuple[ModelForCausalLM, ModelConfig, CausalLMOutput]: """Create model, run forward pass, and return (model, config, out).""" batch_size, seq_len = 2, 8 - model, config = create_model(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing) + model, config = create_model( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing + ) input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) @@ -108,18 +113,24 @@ def test_hidden_states_length_matches( mesh_axes: tuple[str, str], ) -> None: """Both paths should return same number of hidden states.""" - _, config, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False, output_hidden_states=True) + _, config, out = self._forward( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False, output_hidden_states=True + ) hidden_states_no_ckpt = out.hidden_states num_hidden_layers = config.num_hidden_layers del out - _, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True, output_hidden_states=True) + _, _, out = self._forward( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True, output_hidden_states=True + ) hidden_states_ckpt = out.hidden_states del out assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1 for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): - np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") + np.testing.assert_allclose( + hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" + ) def test_kv_cache_with_checkpointing( self, diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 574ddb99f..c0a3f6a10 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -159,9 +159,9 @@ def __init__( rngs=rngs, ) sharding = _get_sharding_spec(self.embedding[...]) - assert sharding is not None, ( - "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init" - ) + assert ( + sharding is not None + ), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init" self.init_lora( max_lora_adapters=max_lora_adapters, @@ -229,9 +229,9 @@ def __init__( rngs=rngs, ) sharding = _get_sharding_spec(self.kernel[...]) - assert sharding is not None, ( - "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init" - ) + assert ( + sharding is not None + ), "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init" self.init_lora( max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank, diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index bad6b0d47..b5991c352 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -474,16 +474,20 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs # Create stacked dense layers (layers 0 to first_k_dense_replace - 1) if self.num_dense_layers > 0: + def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DenseDecoderLayer: return DeepseekV3DenseDecoderLayer(config, dtype=dtype, rngs=rngs) + self.dense_layers = create_stacked_layers(create_dense_layer, self.num_dense_layers, rngs) else: self.dense_layers = None # Create stacked MoE layers (layers first_k_dense_replace to num_hidden_layers - 1) if self.num_moe_layers > 0: + def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3MoEDecoderLayer: return DeepseekV3MoEDecoderLayer(config, dtype=dtype, rngs=rngs) + self.moe_layers = create_stacked_layers(create_moe_layer, self.num_moe_layers, rngs) else: self.moe_layers = None @@ -513,14 +517,14 @@ def __call__( if kv_cache is not None: if self.num_dense_layers > 0: dense_kv_cache = KVCache( - keys=kv_cache.keys[:self.num_dense_layers], - values=kv_cache.values[:self.num_dense_layers], + keys=kv_cache.keys[: self.num_dense_layers], + values=kv_cache.values[: self.num_dense_layers], cache_position=kv_cache.cache_position, ) if self.num_moe_layers > 0: moe_kv_cache = KVCache( - keys=kv_cache.keys[self.num_dense_layers:], - values=kv_cache.values[self.num_dense_layers:], + keys=kv_cache.keys[self.num_dense_layers :], + values=kv_cache.values[self.num_dense_layers :], cache_position=kv_cache.cache_position, ) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 2df833988..e720f3b56 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -158,10 +158,7 @@ def _load_hf_tensor(tensors: dict, key: str, target_shape: tuple, num_experts: i """Load tensor from HF format, handling experts, transpose, and reshape.""" # Handle MoE expert weights (HF stores each expert separately) if ".experts." in key and num_experts: - tensor = np.stack([ - tensors[key.replace(".experts.", f".experts.{i}.")].T - for i in range(num_experts) - ], axis=0) + tensor = np.stack([tensors[key.replace(".experts.", f".experts.{i}.")].T for i in range(num_experts)], axis=0) else: tensor = tensors[key] if "embed_tokens" not in key: @@ -273,6 +270,7 @@ def save_safetensors( # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: from jax.experimental import multihost_utils + tensors = {k: multihost_utils.process_allgather(v, tiled=True) for k, v in tensors.items()} if jax.process_index() == 0: From 3fd1420c00ffad7f3706bfe926834c11dc9d23ab Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 12:32:15 -0800 Subject: [PATCH 093/133] Fix AccumulatedGradients indexing for stacked layer params The get_mean and reset_adapter methods assumed gradients had shape (num_adapters, ...), but stacked layers have shape (num_layers, num_adapters, ...). Use is_stacked_lora_path to detect and index correctly for each case. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/tinker/backends/jax.py | 36 ++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 7287f7a1d..80cb6dfff 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -52,6 +52,7 @@ insert_adapter_state, round_up_seq_len, resolve_model_path, + is_stacked_lora_path, ) from tx.utils.log import logger @@ -124,17 +125,38 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated ) def get_mean(self, adapter_index: jax.Array) -> nnx.State: - """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" + """Compute mean gradients for a specific adapter, with zeros for all other adapters. + + Handles both stacked (num_layers, num_adapters, ...) and non-stacked (num_adapters, ...) params. + """ count = self.counts[adapter_index] - return jax.tree.map( - lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), - self.grad_sum, - ) + + def compute_mean(path, g): + if is_stacked_lora_path(path): + # Stacked: (num_layers, num_adapters, ...) -> index as [:, adapter_index] + return jnp.zeros_like(g).at[:, adapter_index].set(g[:, adapter_index] / count.astype(g.dtype)) + else: + # Non-stacked: (num_adapters, ...) -> index as [adapter_index] + return jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)) + + return jax.tree.map_with_path(compute_mean, self.grad_sum) def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": - """Reset gradients and count for a specific adapter.""" + """Reset gradients and count for a specific adapter. + + Handles both stacked (num_layers, num_adapters, ...) and non-stacked (num_adapters, ...) params. + """ + + def reset_grad(path, g): + if is_stacked_lora_path(path): + # Stacked: (num_layers, num_adapters, ...) -> index as [:, adapter_index] + return g.at[:, adapter_index].set(0.0) + else: + # Non-stacked: (num_adapters, ...) -> index as [adapter_index] + return g.at[adapter_index].set(0.0) + return AccumulatedGradients( - grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), + grad_sum=jax.tree.map_with_path(reset_grad, self.grad_sum), counts=self.counts.at[adapter_index].set(0), ) From acb98fd6d4ef08a0e2147a2cfdd36381db857446 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 12:36:30 -0800 Subject: [PATCH 094/133] revert pyproject --- skyrl-tx/pyproject.toml | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index cc1aa8f52..587d9c19e 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -27,11 +27,11 @@ dependencies = [ [project.optional-dependencies] gpu = [ - "jax[cuda12]>=0.7.2", + "jax[cuda12]>=0.7.2; sys_platform == 'linux'", ] tpu = [ - "jax[tpu]>=0.7.2", + "jax[tpu]>=0.7.2; sys_platform == 'linux'", ] tinker = [ @@ -61,14 +61,15 @@ azure = [ # respectively. jax = [ - "jax[cuda12]>=0.7.2", + "jax[cuda12]>=0.7.2; sys_platform == 'linux'", ] skyrl_train = [ # We currently need the extra pin on the python version # here since skyrl-train pins on python version 3.12, # hopefully in the future we can remove that. - "skyrl-train[vllm]; python_version == '3.12'", + # skyrl-train[vllm] requires CUDA packages which are Linux-only. + "skyrl-train[vllm]; python_version == '3.12' and sys_platform == 'linux'", ] dev = [ @@ -94,8 +95,11 @@ tx = "tx.run.main:app" # The following is for supporting the skyrl-train dependency [tool.uv] -# Exclude skyrl-train on macOS since it requires CUDA torch -exclude-dependencies = ["skyrl-train"] +# Resolve for both Linux (production) and macOS (dev) +required-environments = [ + "sys_platform == 'linux'", + "sys_platform == 'darwin' and platform_machine == 'arm64'", +] [tool.uv.extra-build-dependencies] flash-attn = [{requirement = "torch", match-runtime = true}] @@ -105,7 +109,26 @@ transformer-engine-torch = [{requirement = "torch", match-runtime = true}, "buil [tool.uv.extra-build-variables] flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"} +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + [tool.uv.sources] # For now, just always use the current main branch, later it will be better to pin it to a released version. For development, you # can set it to your own development branch. -# skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } +skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } +# Use CUDA torch on Linux, CPU torch on macOS (must match skyrl-train config) +torch = [ + { index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, + { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" }, +] +torchvision = [ + { index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, + { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" }, +] From 8cfe622115e1c0a1a612c6b3825fa9340294bad1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 13:19:27 -0800 Subject: [PATCH 095/133] Refactor: extract _lora_slice helper to reduce duplication Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/utils/models.py | 55 ++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index e720f3b56..fb82329a6 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -362,31 +362,32 @@ def get_optimizer(optimizer_name: OptimizerName, optimizer_args: dict) -> optax. raise ValueError("The 'learning_rate' key must be provided in optimizer_args.") +def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple: + """Return slice tuple for extracting/inserting LoRA params. + + LoRA param shapes: + - 3D: Non-stacked linear/embed (A, in, R) or (A, R, out) + - 4D: Stacked linear/embed (L, A, in, R) or non-stacked expert (A, E, in, R) + - 5D: Stacked expert (L, A, E, in, R) + """ + # Adapter index: axis 1 for stacked (L, A, ...), axis 0 for non-stacked (A, ...) + adapter_idx = (slice(None), adapter_index) if is_stacked else (adapter_index,) + # Rank slice: lora_A has rank at last dim, lora_B has rank at second-to-last + rank_slice = (Ellipsis, slice(None, rank)) if is_lora_A else (Ellipsis, slice(None, rank), slice(None)) + return adapter_idx + rank_slice + + @nnx.jit(static_argnames=("adapter_index", "rank")) def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: int) -> nnx.GraphState: "Helper function to extract the adapter parameters for a specific adapter index." def extract_state(path: tuple, p: jnp.ndarray): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return p - # LoRA param shapes: - # - 3D: Non-stacked linear/embed (A, in, R) or (A, R, out) - # - 4D: Stacked linear/embed (L, A, in, R) or non-stacked expert (A, E, in, R) - # - 5D: Stacked expert (L, A, E, in, R) - # We extract adapter_index from the adapter dimension (axis 1 for stacked, axis 0 otherwise) assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - is_stacked = is_stacked_lora_path(path) - if path[-2].key == "lora_A": - if is_stacked: # (L, A, ..., R) - return p[:, adapter_index, ..., :rank] - else: # (A, ..., R) - return p[adapter_index, ..., :rank] - if path[-2].key == "lora_B": - if is_stacked: # (L, A, ..., out) - return p[:, adapter_index, ..., :rank, :] - else: # (A, ..., out) - return p[adapter_index, ..., :rank, :] - return p # Defensive fallback (should not be reached) + idx = _lora_slice(is_stacked_lora_path(path), adapter_index, rank, is_lora_A=(key == "lora_A")) + return p[idx] return jax.tree.map_with_path(extract_state, lora_params) @@ -399,22 +400,12 @@ def insert_adapter_state( "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." def insert_state(path: tuple, p: jax.Array, new: jax.Array): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return new - # See extract_adapter_state for shape documentation assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - is_stacked = is_stacked_lora_path(path) - if path[-2].key == "lora_A": - if is_stacked: # (L, A, ..., R) - return p.at[:, adapter_index, ..., :rank].set(new) - else: # (A, ..., R) - return p.at[adapter_index, ..., :rank].set(new) - elif path[-2].key == "lora_B": - if is_stacked: # (L, A, ..., out) - return p.at[:, adapter_index, ..., :rank, :].set(new) - else: # (A, ..., out) - return p.at[adapter_index, ..., :rank, :].set(new) - return new # Defensive fallback (should not be reached) + idx = _lora_slice(is_stacked_lora_path(path), adapter_index, rank, is_lora_A=(key == "lora_A")) + return p.at[idx].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From a8a3e52568b222ba982a84e2de70ad62f7b06ff5 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 14:05:43 -0800 Subject: [PATCH 096/133] Add tests for stacked layer utilities - Add parametrized test for is_stacked_lora_path covering stacked (layers, dense_layers, moe_layers) and non-stacked paths - Add roundtrip test for extract/insert_adapter_state with stacked layers - Add DeepSeekV3 gradient checkpointing test for split stacking Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tests/models/test_deepseekv3.py | 48 ++++++++++++++ skyrl-tx/tests/utils/test_models.py | 83 ++++++++++++++++++++++++ 2 files changed, 131 insertions(+) diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py index 23a15a639..2b18a4b83 100644 --- a/skyrl-tx/tests/models/test_deepseekv3.py +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int): output_merged = moe_layer_merged(x_sample) assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) + + +def test_deepseekv3_gradient_checkpointing(): + """Test that gradient checkpointing produces identical outputs for DeepSeekV3. + + DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests + that gradient checkpointing works correctly with heterogeneous layer types. + """ + model_name = "yujiepan/deepseek-v3-tiny-random" + base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) + + batch_size, seq_len = 2, 8 + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + results = {} + for use_checkpointing in [False, True]: + config = DeepseekV3Config( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + gradient_checkpointing=use_checkpointing, + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) + logits = model.compute_logits(out.last_hidden_state) + + results[use_checkpointing] = { + "logits": np.array(logits), + "hidden_states": [np.array(hs) for hs in out.hidden_states], + "kv_cache_shape": out.kv_cache.keys.shape, + } + + # Verify outputs match + np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6) + + # Verify hidden states match + assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"]) + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])): + np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") + + # Verify KV cache shape is correct (num_layers, batch, seq, heads, dim) + assert results[True]["kv_cache_shape"][0] == config.num_hidden_layers diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index 2c74950af..747ab0e66 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -11,11 +11,14 @@ from peft import PeftModel from transformers import AutoConfig, AutoModelForCausalLM +from jax.tree_util import DictKey + from tx.layers.lora import init_lora_adapter from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM from tx.tinker.types import LoraConfig from tx.utils import models +from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_lora_path from tx.utils.storage import download_and_unpack @@ -86,3 +89,83 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat assert torch.allclose(lora_A, torch.from_numpy(expected_lora_A), atol=1e-6) assert torch.allclose(lora_B, torch.from_numpy(expected_lora_B), atol=1e-6) + + +@pytest.mark.parametrize( + "path,expected", + [ + # Stacked paths (DictKey) + ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ((DictKey(key="model"), DictKey(key="dense_layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ((DictKey(key="model"), DictKey(key="moe_layers"), DictKey(key="mlp"), DictKey(key="lora_A")), True), + # Non-stacked paths (DictKey) + ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False), + ((DictKey(key="lm_head"), DictKey(key="lora_A")), False), + # String paths + (("model", "layers", "self_attn", "lora_A"), True), + (("model", "embed_tokens", "lora_A"), False), + ], + ids=["layers", "dense_layers", "moe_layers", "embed_tokens", "lm_head", "str_layers", "str_embed"], +) +def test_is_stacked_lora_path(path, expected): + """Test is_stacked_lora_path correctly identifies stacked vs non-stacked paths.""" + assert is_stacked_lora_path(path) is expected + + +def test_extract_insert_adapter_state_roundtrip(): + """Test that extract_adapter_state and insert_adapter_state are inverses.""" + base_model_name = "Qwen/Qwen3-0.6B" + rank, alpha, adapter_index = 8, 16, 2 + _, _, model = create_test_model(base_model_name, rank, alpha, adapter_index) + + # Set LoRA weights to random values + q_proj = model.model.layers.self_attn.q_proj + rng1, rng2 = jax.random.split(jax.random.PRNGKey(123)) + q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) + q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) + + # Split model to get lora_params + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) + + # Store original values for comparison + original_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank]) + original_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :]) + + # Extract adapter state + extracted = extract_adapter_state(adapter_index, lora_params, rank) + + # Verify extracted shape is correct (no adapter dimension) + for path, leaf in jax.tree.leaves_with_path(extracted): + key = path[-2].key if hasattr(path[-2], "key") else str(path[-2]) + if key in {"lora_A", "lora_B"}: + # Stacked: should have (num_layers, ...) not (num_layers, num_adapters, ...) + if is_stacked_lora_path(path): + assert leaf.shape[0] == 1 # num_layers + assert leaf.ndim == 3 # (layers, in_dim, rank) or (layers, rank, out_dim) + + # Zero out the adapter's weights + q_proj.lora_A[...] = q_proj.lora_A[...].at[0, adapter_index].set(0) + q_proj.lora_B[...] = q_proj.lora_B[...].at[0, adapter_index].set(0) + + # Verify weights are zeroed + assert np.allclose(q_proj.lora_A[...][0, adapter_index], 0) + assert np.allclose(q_proj.lora_B[...][0, adapter_index], 0) + + # Re-split to get updated lora_params + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) + + # Insert extracted state back (modifies lora_params in-place via nnx.update) + insert_adapter_state(adapter_index, lora_params, extracted, rank) + + # Verify weights are restored by checking lora_params directly + for path, leaf in jax.tree.leaves_with_path(lora_params): + key = path[-2].key if hasattr(path[-2], "key") else str(path[-2]) + # leaf is a state wrapper with .value, or can be an array directly + arr = leaf.value if hasattr(leaf, "value") else leaf + if "q_proj" in str(path) and key == "lora_A": + restored_lora_A = np.array(arr[0, adapter_index, :, :rank]) + elif "q_proj" in str(path) and key == "lora_B": + restored_lora_B = np.array(arr[0, adapter_index, :rank, :]) + + assert np.allclose(original_lora_A, restored_lora_A), "lora_A not restored correctly" + assert np.allclose(original_lora_B, restored_lora_B), "lora_B not restored correctly" From e3ed933b1725d592b94f87ddda3a3aa5e48b0fe4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 14:08:42 -0800 Subject: [PATCH 097/133] Add mlp type annotation to DeepseekV3DecoderLayer base class Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/models/deepseekv3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 9cb8d07b9..40382468d 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -410,6 +410,8 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): """Base decoder layer with shared attributes and forward pass.""" + mlp: DeepseekV3MLP | DeepseekV3MoE # Set by subclasses + def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) From 7d5bf5b085e48e704171497a3cd29acf179c2527 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 15:29:08 -0800 Subject: [PATCH 098/133] Fix Qwen3 MoE softmax ordering to match HuggingFace Apply softmax to all router logits before top-k selection, not after. This matches HF's implementation and fixes ~1.3x output scaling error. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/models/qwen3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 5be6fb0f1..912be0bfc 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -192,8 +192,8 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> def __call__( self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None ) -> jax.Array: - routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok) - routing_weights = nnx.softmax(routing_weights, axis=-1) + routing_weights = nnx.softmax(router_logits, axis=-1) + routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=self.config.num_experts_per_tok) num_experts = self.config.num_experts num_experts_per_tok = self.config.num_experts_per_tok From 3651dec63a67c6a77c315b3c7aaa50b38d444c06 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 15:40:12 -0800 Subject: [PATCH 099/133] Address PR review feedback - Refactor lora_test_utils.py to reduce duplication with _slice_out_of_rank helper - Simplify DeepseekV3 decoder layers by passing mlp_cls instead of subclassing - Add KVCache.split() and concatenate() methods for layer-wise cache operations Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tests/models/lora_test_utils.py | 49 +++++++++++------------- skyrl-tx/tx/models/deepseekv3.py | 38 +++++++----------- skyrl-tx/tx/utils/generator.py | 40 +++++++++++++++++++ 3 files changed, 76 insertions(+), 51 deletions(-) diff --git a/skyrl-tx/tests/models/lora_test_utils.py b/skyrl-tx/tests/models/lora_test_utils.py index 507b5d9c6..00c9077d8 100644 --- a/skyrl-tx/tests/models/lora_test_utils.py +++ b/skyrl-tx/tests/models/lora_test_utils.py @@ -21,29 +21,35 @@ def extract(path, p): return jax.tree.map_with_path(extract, params) -def get_out_of_rank_params(params, adapter_idx: int, rank: int): - """Extract out-of-rank params for an adapter. +def _slice_out_of_rank(params, adapter_idx: int, get_rank): + """Extract out-of-rank params using a rank function. - Returns the portion of LoRA weights beyond the effective rank, - which should remain unchanged during training. + Args: + params: LoRA parameters tree. + adapter_idx: Adapter index to extract. + get_rank: Function (path) -> int returning effective rank for that path. """ def slice_param(path, p): path_str = str(path) + if "lora_A" not in path_str and "lora_B" not in path_str: + return p + rank = get_rank(path) is_stacked = is_stacked_lora_path(path) if "lora_A" in path_str: - if is_stacked: - return p[:, adapter_idx, ..., rank:].copy() - return p[adapter_idx, ..., rank:].copy() - elif "lora_B" in path_str: - if is_stacked: - return p[:, adapter_idx, ..., rank:, :].copy() - return p[adapter_idx, ..., rank:, :].copy() - return p + idx = (slice(None), adapter_idx, ..., slice(rank, None)) if is_stacked else (adapter_idx, ..., slice(rank, None)) + else: # lora_B + idx = (slice(None), adapter_idx, ..., slice(rank, None), slice(None)) if is_stacked else (adapter_idx, ..., slice(rank, None), slice(None)) + return p[idx].copy() return jax.tree.map_with_path(slice_param, params) +def get_out_of_rank_params(params, adapter_idx: int, rank: int): + """Extract out-of-rank params for an adapter.""" + return _slice_out_of_rank(params, adapter_idx, lambda _: rank) + + def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str): """Verify that params haven't changed between initial and final state.""" for (path, initial), (_, final) in zip( @@ -52,7 +58,7 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str) assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" -def is_routed_expert_path(path) -> bool: +def _is_routed_expert_path(path) -> bool: """Check if path is for routed experts (not shared_experts).""" keys = [] for p in path: @@ -72,18 +78,7 @@ def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: For routed experts, uses effective rank = max(1, rank // num_experts). """ - def slice_param(path, p): - path_str = str(path) - effective_rank = max(1, rank // num_experts) if is_routed_expert_path(path) else rank - is_stacked = is_stacked_lora_path(path) - if "lora_A" in path_str: - if is_stacked: - return p[:, adapter_idx, ..., effective_rank:].copy() - return p[adapter_idx, ..., effective_rank:].copy() - elif "lora_B" in path_str: - if is_stacked: - return p[:, adapter_idx, ..., effective_rank:, :].copy() - return p[adapter_idx, ..., effective_rank:, :].copy() - return p + def get_rank(path): + return max(1, rank // num_experts) if _is_routed_expert_path(path) else rank - return jax.tree.map_with_path(slice_param, params) + return _slice_out_of_rank(params, adapter_idx, get_rank) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 40382468d..0f19ded95 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -408,14 +408,20 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): - """Base decoder layer with shared attributes and forward pass.""" + """Decoder layer supporting both dense MLP and sparse MoE.""" - mlp: DeepseekV3MLP | DeepseekV3MoE # Set by subclasses - - def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + def __init__( + self, + config: DeepseekV3Config, + *, + mlp_cls: type[DeepseekV3MLP] | type[DeepseekV3MoE], + dtype: jnp.dtype, + rngs: nnx.Rngs, + ) -> None: self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) + self.mlp = mlp_cls(config, dtype=dtype, rngs=rngs) def __call__( self, @@ -445,22 +451,6 @@ def __call__( return hidden_states, updated_cache -class DeepseekV3DenseDecoderLayer(DeepseekV3DecoderLayer): - """Dense decoder layer (uses MLP, no MoE).""" - - def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - super().__init__(config, dtype=dtype, rngs=rngs) - self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) - - -class DeepseekV3MoEDecoderLayer(DeepseekV3DecoderLayer): - """MoE decoder layer (uses sparse MoE instead of dense MLP).""" - - def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - super().__init__(config, dtype=dtype, rngs=rngs) - self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) - - class DeepseekV3Model(nnx.Module): def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: @@ -483,8 +473,8 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs # Create stacked dense layers (layers 0 to first_k_dense_replace - 1) if self.num_dense_layers > 0: - def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DenseDecoderLayer: - return DeepseekV3DenseDecoderLayer(config, dtype=dtype, rngs=rngs) + def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) self.dense_layers = create_stacked_layers(create_dense_layer, self.num_dense_layers, rngs) else: @@ -493,8 +483,8 @@ def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DenseDecoderLayer: # Create stacked MoE layers (layers first_k_dense_replace to num_hidden_layers - 1) if self.num_moe_layers > 0: - def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3MoEDecoderLayer: - return DeepseekV3MoEDecoderLayer(config, dtype=dtype, rngs=rngs) + def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) self.moe_layers = create_stacked_layers(create_moe_layer, self.num_moe_layers, rngs) else: diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index e7b176871..c32f4a661 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -114,6 +114,46 @@ def seq_len(self) -> int: """Current sequence length.""" return self.keys.shape[2] + def split(self, layer_idx: int) -> tuple[KVCache, KVCache]: + """Split the cache at a layer index. + + Args: + layer_idx: Layer index to split at. + + Returns: + Tuple of (first_cache, second_cache) where first_cache contains + layers [0, layer_idx) and second_cache contains layers [layer_idx, num_layers). + """ + return ( + KVCache( + keys=self.keys[:layer_idx], + values=self.values[:layer_idx], + cache_position=self.cache_position, + ), + KVCache( + keys=self.keys[layer_idx:], + values=self.values[layer_idx:], + cache_position=self.cache_position, + ), + ) + + @staticmethod + def concatenate(first: KVCache, second: KVCache) -> KVCache: + """Concatenate two caches along the layer dimension. + + Args: + first: First cache (earlier layers). + second: Second cache (later layers). + + Returns: + Combined KVCache with all layers. + """ + return KVCache( + keys=jnp.concatenate([first.keys, second.keys], axis=0), + values=jnp.concatenate([first.values, second.values], axis=0), + cache_position=second.cache_position, + ) + @jax.tree_util.register_dataclass @dataclass From b6a6f9588f72a45393f3d34a79551be11fd622fc Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 15:51:45 -0800 Subject: [PATCH 100/133] Add get_adapter_idx to consolidate stacked/non-stacked indexing Introduces get_adapter_idx(path, adapter_index) that encapsulates the stacked vs non-stacked adapter indexing logic. Removes duplicate if/else patterns across the codebase. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tests/models/lora_test_utils.py | 15 ++++------ skyrl-tx/tx/layers/lora.py | 13 ++------ skyrl-tx/tx/tinker/backends/jax.py | 28 +++++------------ skyrl-tx/tx/utils/models.py | 38 ++++++++++++------------ 4 files changed, 35 insertions(+), 59 deletions(-) diff --git a/skyrl-tx/tests/models/lora_test_utils.py b/skyrl-tx/tests/models/lora_test_utils.py index 00c9077d8..b83d583d6 100644 --- a/skyrl-tx/tests/models/lora_test_utils.py +++ b/skyrl-tx/tests/models/lora_test_utils.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from tx.utils.models import is_stacked_lora_path +from tx.utils.models import get_adapter_idx def get_adapter_params(params, adapter_idx: int): @@ -14,9 +14,8 @@ def get_adapter_params(params, adapter_idx: int): """ def extract(path, p): - if is_stacked_lora_path(path): - return p[:, adapter_idx].copy() - return p[adapter_idx].copy() + idx = get_adapter_idx(path, adapter_idx) + return p[idx].copy() return jax.tree.map_with_path(extract, params) @@ -35,12 +34,10 @@ def slice_param(path, p): if "lora_A" not in path_str and "lora_B" not in path_str: return p rank = get_rank(path) - is_stacked = is_stacked_lora_path(path) + idx = get_adapter_idx(path, adapter_idx) if "lora_A" in path_str: - idx = (slice(None), adapter_idx, ..., slice(rank, None)) if is_stacked else (adapter_idx, ..., slice(rank, None)) - else: # lora_B - idx = (slice(None), adapter_idx, ..., slice(rank, None), slice(None)) if is_stacked else (adapter_idx, ..., slice(rank, None), slice(None)) - return p[idx].copy() + return p[idx + (..., slice(rank, None))].copy() + return p[idx + (..., slice(rank, None), slice(None))].copy() return jax.tree.map_with_path(slice_param, params) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index c0a3f6a10..7a4a4c6aa 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -3,7 +3,7 @@ from jax import numpy as jnp from jax.core import Tracer -from tx.utils.models import filter_lora, is_stacked_lora_path +from tx.utils.models import filter_lora, get_adapter_idx from tx.layers.util import Param, prepare_routing, ragged_dot from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig @@ -25,13 +25,6 @@ def _get_sharding_spec(arr: jax.Array): return None -def _adapter_index(is_stacked: bool, adapter_index: int): - """Return index for accessing an adapter. Stacked params have layers as first dim.""" - # Stacked layers have shape (num_layers, num_adapters, ...), - # non-stacked (embed_tokens) have shape (num_adapters, ...) - return (slice(None), adapter_index) if is_stacked else (adapter_index,) - - class LoRAMixin: """A mixin for flax NNX modules to add multi-adapter LoRA support. This mixin adds LoRA parameters (lora_A, lora_B) and methods to apply @@ -368,7 +361,7 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 - idx = _adapter_index(is_stacked_lora_path(path), adapter_index) + idx = get_adapter_idx(path, adapter_index) key_name = path[-2].key if key_name == "lora_ranks": @@ -402,7 +395,7 @@ def clear_adapter(path, value): key = path[-2].key if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): return value - idx = _adapter_index(is_stacked_lora_path(path), adapter_index) + idx = get_adapter_idx(path, adapter_index) return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 80cb6dfff..6eb8ab52a 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -52,7 +52,7 @@ insert_adapter_state, round_up_seq_len, resolve_model_path, - is_stacked_lora_path, + get_adapter_idx, ) from tx.utils.log import logger @@ -125,35 +125,21 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated ) def get_mean(self, adapter_index: jax.Array) -> nnx.State: - """Compute mean gradients for a specific adapter, with zeros for all other adapters. - - Handles both stacked (num_layers, num_adapters, ...) and non-stacked (num_adapters, ...) params. - """ + """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" count = self.counts[adapter_index] def compute_mean(path, g): - if is_stacked_lora_path(path): - # Stacked: (num_layers, num_adapters, ...) -> index as [:, adapter_index] - return jnp.zeros_like(g).at[:, adapter_index].set(g[:, adapter_index] / count.astype(g.dtype)) - else: - # Non-stacked: (num_adapters, ...) -> index as [adapter_index] - return jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)) + idx = get_adapter_idx(path, adapter_index) + return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype)) return jax.tree.map_with_path(compute_mean, self.grad_sum) def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": - """Reset gradients and count for a specific adapter. - - Handles both stacked (num_layers, num_adapters, ...) and non-stacked (num_adapters, ...) params. - """ + """Reset gradients and count for a specific adapter.""" def reset_grad(path, g): - if is_stacked_lora_path(path): - # Stacked: (num_layers, num_adapters, ...) -> index as [:, adapter_index] - return g.at[:, adapter_index].set(0.0) - else: - # Non-stacked: (num_adapters, ...) -> index as [adapter_index] - return g.at[adapter_index].set(0.0) + idx = get_adapter_idx(path, adapter_index) + return g.at[idx].set(0.0) return AccumulatedGradients( grad_sum=jax.tree.map_with_path(reset_grad, self.grad_sum), diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index fb82329a6..9355cbc9e 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -94,6 +94,17 @@ def is_stacked_lora_path(path: tuple) -> bool: return any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) +def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: + """Return index tuple for accessing an adapter at the given path. + + Stacked layer params have shape (num_layers, num_adapters, ...) -> index as [:, adapter_index]. + Non-stacked params (embed_tokens) have shape (num_adapters, ...) -> index as [adapter_index]. + """ + if is_stacked_lora_path(path): + return (slice(None), adapter_index) + return (adapter_index,) + + def _is_stacked_layer_param(path: tuple) -> bool: """Check if a parameter path corresponds to a STACKED decoder layer weight. @@ -362,21 +373,6 @@ def get_optimizer(optimizer_name: OptimizerName, optimizer_args: dict) -> optax. raise ValueError("The 'learning_rate' key must be provided in optimizer_args.") -def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple: - """Return slice tuple for extracting/inserting LoRA params. - - LoRA param shapes: - - 3D: Non-stacked linear/embed (A, in, R) or (A, R, out) - - 4D: Stacked linear/embed (L, A, in, R) or non-stacked expert (A, E, in, R) - - 5D: Stacked expert (L, A, E, in, R) - """ - # Adapter index: axis 1 for stacked (L, A, ...), axis 0 for non-stacked (A, ...) - adapter_idx = (slice(None), adapter_index) if is_stacked else (adapter_index,) - # Rank slice: lora_A has rank at last dim, lora_B has rank at second-to-last - rank_slice = (Ellipsis, slice(None, rank)) if is_lora_A else (Ellipsis, slice(None, rank), slice(None)) - return adapter_idx + rank_slice - - @nnx.jit(static_argnames=("adapter_index", "rank")) def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: int) -> nnx.GraphState: "Helper function to extract the adapter parameters for a specific adapter index." @@ -386,8 +382,10 @@ def extract_state(path: tuple, p: jnp.ndarray): if key not in {"lora_A", "lora_B"}: return p assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - idx = _lora_slice(is_stacked_lora_path(path), adapter_index, rank, is_lora_A=(key == "lora_A")) - return p[idx] + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p[idx + (..., slice(None, rank))] + return p[idx + (..., slice(None, rank), slice(None))] return jax.tree.map_with_path(extract_state, lora_params) @@ -404,8 +402,10 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): if key not in {"lora_A", "lora_B"}: return new assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" - idx = _lora_slice(is_stacked_lora_path(path), adapter_index, rank, is_lora_A=(key == "lora_A")) - return p.at[idx].set(new) + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p.at[idx + (..., slice(None, rank))].set(new) + return p.at[idx + (..., slice(None, rank), slice(None))].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From 6f8e486efc3b7ae0294fda7359ae366590314951 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 16:00:03 -0800 Subject: [PATCH 101/133] Revert "Fix Qwen3 MoE softmax ordering to match HuggingFace" This reverts commit 7d5bf5b085e48e704171497a3cd29acf179c2527. --- skyrl-tx/tx/models/qwen3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 912be0bfc..5be6fb0f1 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -192,8 +192,8 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> def __call__( self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None ) -> jax.Array: - routing_weights = nnx.softmax(router_logits, axis=-1) - routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=self.config.num_experts_per_tok) + routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok) + routing_weights = nnx.softmax(routing_weights, axis=-1) num_experts = self.config.num_experts num_experts_per_tok = self.config.num_experts_per_tok From 2f2f7652a2cd76f29075b4aa0e2ade1499b909c9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 16:04:09 -0800 Subject: [PATCH 102/133] Remove redundant _is_stacked_layer_param function Use is_stacked_lora_path directly since we always stack layers. The digit check for non-stacked format is no longer needed. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/utils/models.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 9355cbc9e..c976fe0b8 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -105,29 +105,6 @@ def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: return (adapter_index,) -def _is_stacked_layer_param(path: tuple) -> bool: - """Check if a parameter path corresponds to a STACKED decoder layer weight. - - Stacked layers have paths like: - - Qwen3/Llama3: ('model', 'layers', 'self_attn', ...) - - DeepSeekV3 dense: ('model', 'dense_layers', 'self_attn', ...) - - DeepSeekV3 MoE: ('model', 'moe_layers', 'self_attn', ...) - - Non-stacked layers have paths like: ('model', 'layers', '0', 'self_attn', ...) - """ - path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] - # Check for split stacked layer names (DeepSeekV3) - if "dense_layers" in path_strs or "moe_layers" in path_strs: - return True - # Check for regular stacked layers (Qwen3/Llama3) - if "layers" not in path_strs: - return False - layers_idx = path_strs.index("layers") - if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit(): - return False # Non-stacked: path already contains layer index - return True # Stacked: no layer index in path - - def _get_layer_group_info(path: tuple, config: ModelConfig) -> tuple[str, int]: """Get layer group name and starting layer index for a stacked param path. @@ -226,7 +203,7 @@ def load_safetensors( if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): continue - if _is_stacked_layer_param(path): + if is_stacked_lora_path(path): # Stack per-layer weights from HF format # Infer layer count from param shape and get offset for split stacked layers stacked_layer_count = param.shape[0] @@ -265,7 +242,7 @@ def save_safetensors( if filter_fn is not None and not filter_fn(path): continue - if _is_stacked_layer_param(path): + if is_stacked_lora_path(path): # Unstack and save as individual layer weights # Infer layer count from param shape and get offset for split stacked layers stacked_layer_count = param.shape[0] From ab1a7c9789fcec0a17008cf3bb6eafe737ab290c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 16:11:23 -0800 Subject: [PATCH 103/133] Use KVCache.split() and concatenate() in DeepseekV3 Make split() and concatenate() handle None for empty layer groups. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/models/deepseekv3.py | 26 ++--------------------- skyrl-tx/tx/utils/generator.py | 36 ++++++++++++++++++-------------- 2 files changed, 22 insertions(+), 40 deletions(-) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 0f19ded95..a7415cd1e 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -513,18 +513,7 @@ def __call__( dense_kv_cache = None moe_kv_cache = None if kv_cache is not None: - if self.num_dense_layers > 0: - dense_kv_cache = KVCache( - keys=kv_cache.keys[: self.num_dense_layers], - values=kv_cache.values[: self.num_dense_layers], - cache_position=kv_cache.cache_position, - ) - if self.num_moe_layers > 0: - moe_kv_cache = KVCache( - keys=kv_cache.keys[self.num_dense_layers :], - values=kv_cache.values[self.num_dense_layers :], - cache_position=kv_cache.cache_position, - ) + dense_kv_cache, moe_kv_cache = kv_cache.split(self.num_dense_layers) # Forward through dense layers dense_kv_result = None @@ -563,18 +552,7 @@ def __call__( all_hidden_states.append(hidden_states) # Merge KV caches from dense and MoE layers - if dense_kv_result is not None and moe_kv_result is not None: - new_kv_cache = KVCache( - keys=jnp.concatenate([dense_kv_result.keys, moe_kv_result.keys], axis=0), - values=jnp.concatenate([dense_kv_result.values, moe_kv_result.values], axis=0), - cache_position=moe_kv_result.cache_position, - ) - elif dense_kv_result is not None: - new_kv_cache = dense_kv_result - elif moe_kv_result is not None: - new_kv_cache = moe_kv_result - else: - new_kv_cache = None + new_kv_cache = KVCache.concatenate(dense_kv_result, moe_kv_result) return ModelOutput( last_hidden_state=hidden_states, diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index c32f4a661..f48f45198 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -114,7 +114,7 @@ def seq_len(self) -> int: """Current sequence length.""" return self.keys.shape[2] - def split(self, layer_idx: int) -> tuple[KVCache, KVCache]: + def split(self, layer_idx: int) -> tuple[KVCache | None, KVCache | None]: """Split the cache at a layer index. Args: @@ -123,31 +123,35 @@ def split(self, layer_idx: int) -> tuple[KVCache, KVCache]: Returns: Tuple of (first_cache, second_cache) where first_cache contains layers [0, layer_idx) and second_cache contains layers [layer_idx, num_layers). + Returns None for empty splits. """ - return ( - KVCache( - keys=self.keys[:layer_idx], - values=self.values[:layer_idx], - cache_position=self.cache_position, - ), - KVCache( - keys=self.keys[layer_idx:], - values=self.values[layer_idx:], - cache_position=self.cache_position, - ), + first = None if layer_idx == 0 else KVCache( + keys=self.keys[:layer_idx], + values=self.values[:layer_idx], + cache_position=self.cache_position, + ) + second = None if layer_idx == self.num_layers else KVCache( + keys=self.keys[layer_idx:], + values=self.values[layer_idx:], + cache_position=self.cache_position, ) + return first, second @staticmethod - def concatenate(first: KVCache, second: KVCache) -> KVCache: + def concatenate(first: KVCache | None, second: KVCache | None) -> KVCache | None: """Concatenate two caches along the layer dimension. Args: - first: First cache (earlier layers). - second: Second cache (later layers). + first: First cache (earlier layers), or None. + second: Second cache (later layers), or None. Returns: - Combined KVCache with all layers. + Combined KVCache, or the non-None input, or None if both are None. """ + if first is None: + return second + if second is None: + return first return KVCache( keys=jnp.concatenate([first.keys, second.keys], axis=0), values=jnp.concatenate([first.values, second.values], axis=0), From 9635e4d28ad4f5753a76eb6fedac2d9ddff9993f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 16:19:26 -0800 Subject: [PATCH 104/133] lint --- skyrl-tx/tx/utils/generator.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index f48f45198..6c0651991 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -125,15 +125,23 @@ def split(self, layer_idx: int) -> tuple[KVCache | None, KVCache | None]: layers [0, layer_idx) and second_cache contains layers [layer_idx, num_layers). Returns None for empty splits. """ - first = None if layer_idx == 0 else KVCache( - keys=self.keys[:layer_idx], - values=self.values[:layer_idx], - cache_position=self.cache_position, + first = ( + None + if layer_idx == 0 + else KVCache( + keys=self.keys[:layer_idx], + values=self.values[:layer_idx], + cache_position=self.cache_position, + ) ) - second = None if layer_idx == self.num_layers else KVCache( - keys=self.keys[layer_idx:], - values=self.values[layer_idx:], - cache_position=self.cache_position, + second = ( + None + if layer_idx == self.num_layers + else KVCache( + keys=self.keys[layer_idx:], + values=self.values[layer_idx:], + cache_position=self.cache_position, + ) ) return first, second From 1bf80be292d2faa8b78a5201b70be5881c044d03 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 17:32:24 -0800 Subject: [PATCH 105/133] fix --- skyrl-tx/tx/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index fca7c6645..228781132 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -44,7 +44,7 @@ def create_stacked_layers( """ @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=(0,), out_axes=0) + @nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None}) def vmapped_create(rngs: nnx.Rngs): return create_layer_fn(rngs) From 3abaa7cd80bb89a39c25243fe766c24cd5074b99 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 18:02:42 -0800 Subject: [PATCH 106/133] skip kv cache for training --- skyrl-tx/tx/models/deepseekv3.py | 5 +++++ skyrl-tx/tx/models/llama3.py | 4 ++++ skyrl-tx/tx/models/qwen3.py | 4 ++++ skyrl-tx/tx/models/types.py | 8 ++++---- skyrl-tx/tx/models/utils.py | 29 ++++++++++++++++------------- skyrl-tx/tx/tinker/backends/jax.py | 1 + 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index a7415cd1e..c13e7efab 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -501,6 +501,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -528,6 +529,7 @@ def __call__( kv_cache=dense_kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, ) all_hidden_states.extend(dense_hidden_states) @@ -544,6 +546,7 @@ def __call__( kv_cache=moe_kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, ) all_hidden_states.extend(moe_hidden_states) @@ -598,6 +601,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -609,6 +613,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 4c9d8c9d2..cf68076ca 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -221,6 +221,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -238,6 +239,7 @@ def __call__( kv_cache=kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, ) hidden_states = self.norm(hidden_states) @@ -290,6 +292,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -301,6 +304,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 5be6fb0f1..db349ba7b 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -336,6 +336,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -353,6 +354,7 @@ def __call__( kv_cache=kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, ) hidden_states = self.norm(hidden_states) @@ -405,6 +407,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -416,6 +419,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index 8067c9f8a..16d0241d5 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -24,12 +24,12 @@ class ModelOutput: Attributes: last_hidden_state: The last hidden state from the model. - kv_cache: The updated key-value cache. + kv_cache: The updated key-value cache (None during training). hidden_states: All hidden states if output_hidden_states=True. """ last_hidden_state: jax.Array - kv_cache: KVCache + kv_cache: KVCache | None hidden_states: list[jax.Array] | None = None @@ -40,10 +40,10 @@ class CausalLMOutput: Attributes: last_hidden_state: The last hidden state from the model. - kv_cache: The updated key-value cache. + kv_cache: The updated key-value cache (None during training). hidden_states: All hidden states, if output_hidden_states=True. """ last_hidden_state: jax.Array - kv_cache: KVCache + kv_cache: KVCache | None hidden_states: list[jax.Array] | None = None diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 228781132..7cf5e999c 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -1,8 +1,8 @@ """Utility functions for model forward passes with stacked decoder layers. -This module provides a unified forward_layers function that works for both training -(with gradient checkpointing) and inference. The key insight is that jax.checkpoint -is a no-op when not computing gradients, so we can use the same scan-based code path. +This module provides: +- create_stacked_layers: Create decoder layers with stacked weights using nnx.vmap +- forward_layers: Unified forward pass using scan (skips KV cache during training) Prerequisites: - Layers must be created with nnx.vmap (stacked weights) @@ -62,12 +62,9 @@ def forward_layers( kv_cache: KVCache | None, output_hidden_states: bool, gradient_checkpointing: bool, -) -> tuple[jax.Array, list[jax.Array], KVCache]: - """Unified forward pass through stacked decoder layers. - - Uses jax.lax.scan for both training and inference. When gradient_checkpointing=True, - wraps the body function with jax.checkpoint. This is a no-op during inference - (when not computing gradients), so we can use a single code path. + is_training: bool = False, +) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Unified forward pass through stacked decoder layers using scan. Args: layers: Stacked decoder layers (created with create_stacked_layers/nnx.vmap). @@ -78,10 +75,12 @@ def forward_layers( adapter_indices: Optional LoRA adapter indices of shape (batch,). kv_cache: Optional KV cache for decode mode (None for prefill). output_hidden_states: Whether to return intermediate hidden states. - gradient_checkpointing: Whether to use gradient checkpointing. + gradient_checkpointing: Whether to use gradient checkpointing (training only). + is_training: Whether in training mode. Skips KV cache to save memory. Returns: Tuple of (final_hidden_states, all_hidden_states, kv_cache). + kv_cache is None when is_training=True. """ assert num_layers > 0, "num_layers must be positive" @@ -99,7 +98,6 @@ def body_fn(hs, xs): # Reconstruct layer module from stacked weights layer = nnx.merge(layer_graphdef, jax.tree.map(lambda x: x[layer_idx], layer_state)) - new_hs, (k, v) = layer( hs, attention_mask=attention_mask, @@ -107,8 +105,11 @@ def body_fn(hs, xs): adapter_indices=adapter_indices, kv_cache=layer_kv, ) - hs_output = new_hs if output_hidden_states else None + + if is_training: + # Avoid accumulating large KV tensors for training. + k = v = None return new_hs, (hs_output, k, v) if gradient_checkpointing: @@ -124,7 +125,9 @@ def body_fn(hs, xs): # [embed, layer0_out, ..., layer(N-2)_out]; final layer output gets normed by caller all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] - if is_decode: + if is_training: + new_kv_cache = None + elif is_decode: # Decode mode: scan stacked the per-layer updated caches into (num_layers, ...) new_kv_cache = KVCache( keys=all_keys, diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 6eb8ab52a..744c70d98 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -259,6 +259,7 @@ def _model_forward( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, + is_training=True, ) return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) From 209f959167c82e39f25786d45c6b36981ccfd5d3 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 30 Jan 2026 19:19:06 -0800 Subject: [PATCH 107/133] Fix shard_map_ep PartitionSpec length mismatch for extracted layers When a single layer is extracted from stacked layers via x[layer_idx], the tensor loses a dimension but the PartitionSpec metadata still has the extra leading None (from vmap transform_metadata). Truncate the spec from the beginning to match the actual tensor rank. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/layers/util.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..e0f596d94 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -96,12 +96,21 @@ def shard_map_ep(module: nnx.Module, func, *args): *args: Arguments to pass to func (replicated across shards). """ graphdef, state = nnx.split(module) - # Extract only 'ep' dims from PartitionSpecs, replacing others with None - state_specs = jax.tree.map( - lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s, - nnx.get_partition_spec(state), - is_leaf=lambda x: isinstance(x, PartitionSpec), - ) + + def make_ep_spec(spec, value): + """Create a PartitionSpec with only 'ep' dims, truncated to match tensor rank.""" + if not isinstance(spec, PartitionSpec): + return spec + # When a layer is extracted from stacked layers via x[layer_idx], the tensor + # loses a dimension but the PartitionSpec metadata still has the extra leading None. + # Truncate the spec to match the actual tensor rank. + arr = value.value if hasattr(value, "value") else value + rank = len(arr.shape) if hasattr(arr, "shape") else 0 + truncated = tuple(spec)[-rank:] if rank > 0 else () + return PartitionSpec(*(p if p == "ep" else None for p in truncated)) + + partition_specs = nnx.get_partition_spec(state) + state_specs = jax.tree.map(make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec)) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) From 5122c2c481bd1d2eed65943bf22d7de9ffbaefa1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 30 Jan 2026 22:04:47 -0800 Subject: [PATCH 108/133] remove closure --- skyrl-tx/tx/models/utils.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 7cf5e999c..091146cc4 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -13,7 +13,6 @@ from flax import nnx import jax -from jax import numpy as jnp from tx.utils.generator import KVCache @@ -88,16 +87,16 @@ def forward_layers( is_decode = kv_cache is not None def body_fn(hs, xs): - # Unpack xs based on mode (structure differs between prefill and decode) + # Unpack xs: scan automatically slices the leading dimension of layer_state if is_decode: - layer_idx, layer_k, layer_v = xs + layer_params, layer_k, layer_v = xs layer_kv = (layer_k, layer_v) else: - layer_idx = xs + layer_params = xs layer_kv = None - # Reconstruct layer module from stacked weights - layer = nnx.merge(layer_graphdef, jax.tree.map(lambda x: x[layer_idx], layer_state)) + # Merge using the sliced params directly - no manual gather needed + layer = nnx.merge(layer_graphdef, layer_params) new_hs, (k, v) = layer( hs, attention_mask=attention_mask, @@ -115,10 +114,10 @@ def body_fn(hs, xs): if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - # Prepare scan inputs: in decode mode, pass per-layer caches via xs - # Scan automatically slices along axis 0, so each iteration gets one layer's cache - layer_indices = jnp.arange(num_layers) - xs = (layer_indices, kv_cache.keys, kv_cache.values) if is_decode else layer_indices + # Pass layer_state as xs so scan handles the slicing automatically. + # This avoids capturing layer_state as a closure and manually gathering, + # which causes slow XLA compilation with jax.checkpoint. + xs = (layer_state, kv_cache.keys, kv_cache.values) if is_decode else layer_state final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, xs) From 2c0c3e9815e2deb3cbc9bdf39755c1ccc0ae8894 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 2 Feb 2026 19:30:53 -0800 Subject: [PATCH 109/133] Fix create_stacked_layers to avoid vmap memory overhead Replace nnx.vmap with individual layer creation + jnp.stack. vmap breaks eager sharding, causing ~4x memory overhead due to full model replication instead of tensor-parallel sharding. The new approach: - Creates layers individually with a Python loop (respects eager sharding) - Stacks parameters using jit with donate_argnums to reduce peak memory - Preserves correct sharding specs on stacked arrays Memory improvement (per GPU, Qwen3-4B with tp=8): - nnx.List baseline: 1461 MiB - Old vmap approach: 4533 MiB - New loop+stack: 2485 MiB Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/utils.py | 80 +++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 091146cc4..886924093 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -1,11 +1,11 @@ """Utility functions for model forward passes with stacked decoder layers. This module provides: -- create_stacked_layers: Create decoder layers with stacked weights using nnx.vmap +- create_stacked_layers: Create decoder layers with stacked weights - forward_layers: Unified forward pass using scan (skips KV cache during training) Prerequisites: -- Layers must be created with nnx.vmap (stacked weights) +- Layers must be created with create_stacked_layers (stacked weights) - KVCache must use stacked format: (num_layers, batch, seq, heads, dim) """ @@ -22,11 +22,15 @@ def create_stacked_layers( num_layers: int, rngs: nnx.Rngs, ) -> nnx.Module: - """Create stacked decoder layers using nnx.vmap. + """Create stacked decoder layers by creating individual layers and stacking their parameters. This creates a single module object where all parameters have shape (num_layers, ...). This enables efficient scanning over layers without runtime stacking. + Note: We avoid using nnx.vmap for layer creation because vmap breaks eager sharding, + causing ~4x memory overhead. Instead, we create layers individually (which respects + eager sharding) and then stack their parameters with jnp.stack. + Args: create_layer_fn: Function that takes rngs and returns a single layer module. num_layers: Number of layers to create. @@ -41,13 +45,73 @@ def create_stacked_layers( >>> layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) >>> # layers.self_attn.q_proj.kernel.shape == (num_layers, hidden, head_dim*num_heads) """ + import warnings + from functools import partial + + import jax.numpy as jnp + import jax.random + from jax.sharding import NamedSharding, PartitionSpec + + # Split the RNG key to get unique keys for each layer + base_key = rngs.params() + layer_keys = jax.random.split(base_key, num_layers) + + # Get the current mesh for sharding + mesh = jax.sharding.get_mesh() + + # Create all layers individually - this respects eager sharding + layers = [create_layer_fn(nnx.Rngs(layer_keys[i])) for i in range(num_layers)] + + # Get graphdef from first layer (all layers have same structure) + graphdef, first_state = nnx.split(layers[0]) + + # Extract flattened states from all layers + states = [nnx.split(layer)[1] for layer in layers] + del layers + + flat_states = [jax.tree_util.tree_flatten(s)[0] for s in states] + treedef = jax.tree_util.tree_flatten(states[0])[1] + del states + + # Stack each parameter array using jit with donate_argnums for memory efficiency. + # This tells XLA to try to reuse input buffers for the output, reducing peak memory. + stacked_flat = [] + for i in range(len(flat_states[0])): + # Get arrays for this parameter across all layers + arrays = [flat_states[j][i] for j in range(num_layers)] + + # Get original sharding spec and extend it for the stacked dimension + original_sharding = arrays[0].sharding + if hasattr(original_sharding, "spec"): + original_spec = original_sharding.spec + # Prepend None for the new layer dimension + new_spec = PartitionSpec(None, *original_spec) + new_sharding = NamedSharding(mesh, new_spec) + + # Use jit with donate_argnums and out_shardings for memory-efficient stacking. + # The donation hints help XLA manage memory better during the stacking operation. + @partial(jax.jit, donate_argnums=tuple(range(num_layers)), out_shardings=new_sharding) + def do_stack(*arrs): + return jnp.stack(arrs, axis=0) + + # Suppress donation warnings since we expect some buffers can't be donated + # (stacking changes array shapes so direct donation isn't always possible) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Some donated buffers were not usable") + stacked = do_stack(*arrays) + else: + stacked = jnp.stack(arrays, axis=0) + stacked_flat.append(stacked) + del arrays + + del flat_states - @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None}) - def vmapped_create(rngs: nnx.Rngs): - return create_layer_fn(rngs) + # Reconstruct the state tree with stacked arrays + stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) + del stacked_flat - return vmapped_create(rngs) + # Merge back into a module with stacked parameters + return nnx.merge(graphdef, stacked_state) def forward_layers( From 40f99d4e05be1f754632f717c7f0c092bc9de949 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 13:50:16 -0800 Subject: [PATCH 110/133] Optimize create_stacked_layers to avoid 2x peak memory Instead of creating all layers then stacking (which requires holding both original arrays and stacked arrays simultaneously), pre-allocate the stacked arrays and copy each layer's params directly using dynamic_update_slice. This keeps only one layer in memory at a time. Memory improvement during layer creation: - Old: JAX peak ~2098 MiB (originals + stacked arrays) - New: JAX peak ~1316 MiB (stacked arrays + 1 layer) Also adds memory logging via nvidia-smi and JAX memory_stats for debugging memory usage throughout the layer creation process. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/utils.py | 136 +++++++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 42 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 886924093..8f6752bff 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -9,6 +9,8 @@ - KVCache must use stacked format: (num_layers, batch, seq, heads, dim) """ +import logging +import subprocess from typing import Callable from flax import nnx @@ -16,20 +18,51 @@ from tx.utils.generator import KVCache +logger = logging.getLogger(__name__) + + +def _log_mem(label: str): + """Log GPU memory usage via nvidia-smi and JAX memory stats.""" + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + timeout=5, + ) + nvidia_mem = max(int(x) for x in result.stdout.strip().split("\n")) + except Exception: + nvidia_mem = -1 + + try: + # Get JAX's view of memory usage + devices = jax.devices() + jax_mems = [] + for d in devices: + stats = d.memory_stats() + if stats: + # bytes_in_use is the actual memory used by JAX arrays + jax_mems.append(stats.get("bytes_in_use", 0) / 1024 / 1024) + jax_mem = max(jax_mems) if jax_mems else -1 + except Exception: + jax_mem = -1 + + logger.info(f"[MEM] {label}: nvidia={nvidia_mem} MiB, jax={jax_mem:.1f} MiB") + def create_stacked_layers( create_layer_fn: Callable[[nnx.Rngs], nnx.Module], num_layers: int, rngs: nnx.Rngs, ) -> nnx.Module: - """Create stacked decoder layers by creating individual layers and stacking their parameters. + """Create stacked decoder layers by creating one layer at a time and copying to pre-allocated arrays. This creates a single module object where all parameters have shape (num_layers, ...). This enables efficient scanning over layers without runtime stacking. - Note: We avoid using nnx.vmap for layer creation because vmap breaks eager sharding, - causing ~4x memory overhead. Instead, we create layers individually (which respects - eager sharding) and then stack their parameters with jnp.stack. + Memory optimization: Instead of creating all layers then stacking (which requires 2x memory), + we pre-allocate the stacked arrays and copy each layer's params directly, keeping only + one layer in memory at a time. Args: create_layer_fn: Function that takes rngs and returns a single layer module. @@ -45,13 +78,14 @@ def create_stacked_layers( >>> layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) >>> # layers.self_attn.q_proj.kernel.shape == (num_layers, hidden, head_dim*num_heads) """ - import warnings from functools import partial import jax.numpy as jnp import jax.random from jax.sharding import NamedSharding, PartitionSpec + _log_mem("create_stacked_layers:start") + # Split the RNG key to get unique keys for each layer base_key = rngs.params() layer_keys = jax.random.split(base_key, num_layers) @@ -59,59 +93,77 @@ def create_stacked_layers( # Get the current mesh for sharding mesh = jax.sharding.get_mesh() - # Create all layers individually - this respects eager sharding - layers = [create_layer_fn(nnx.Rngs(layer_keys[i])) for i in range(num_layers)] + # Step 1: Create first layer to get structure and shapes + first_layer = create_layer_fn(nnx.Rngs(layer_keys[0])) + graphdef, first_state = nnx.split(first_layer) + flat_first, treedef = jax.tree_util.tree_flatten(first_state) - # Get graphdef from first layer (all layers have same structure) - graphdef, first_state = nnx.split(layers[0]) + num_params = len(flat_first) + logger.info(f"[MEM] Creating {num_layers} layers with {num_params} params each") + _log_mem("create_stacked_layers:after_first_layer") - # Extract flattened states from all layers - states = [nnx.split(layer)[1] for layer in layers] - del layers - - flat_states = [jax.tree_util.tree_flatten(s)[0] for s in states] - treedef = jax.tree_util.tree_flatten(states[0])[1] - del states - - # Stack each parameter array using jit with donate_argnums for memory efficiency. - # This tells XLA to try to reuse input buffers for the output, reducing peak memory. + # Step 2: Pre-allocate stacked arrays with proper sharding stacked_flat = [] - for i in range(len(flat_states[0])): - # Get arrays for this parameter across all layers - arrays = [flat_states[j][i] for j in range(num_layers)] - - # Get original sharding spec and extend it for the stacked dimension - original_sharding = arrays[0].sharding + for arr in flat_first: + # Determine sharding for stacked array + original_sharding = arr.sharding if hasattr(original_sharding, "spec"): original_spec = original_sharding.spec - # Prepend None for the new layer dimension new_spec = PartitionSpec(None, *original_spec) new_sharding = NamedSharding(mesh, new_spec) + else: + new_sharding = None - # Use jit with donate_argnums and out_shardings for memory-efficient stacking. - # The donation hints help XLA manage memory better during the stacking operation. - @partial(jax.jit, donate_argnums=tuple(range(num_layers)), out_shardings=new_sharding) - def do_stack(*arrs): - return jnp.stack(arrs, axis=0) - - # Suppress donation warnings since we expect some buffers can't be donated - # (stacking changes array shapes so direct donation isn't always possible) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="Some donated buffers were not usable") - stacked = do_stack(*arrays) + # Pre-allocate with zeros + stacked_shape = (num_layers,) + arr.shape + if new_sharding is not None: + stacked = jax.device_put(jnp.zeros(stacked_shape, dtype=arr.dtype), new_sharding) else: - stacked = jnp.stack(arrays, axis=0) + stacked = jnp.zeros(stacked_shape, dtype=arr.dtype) stacked_flat.append(stacked) - del arrays - del flat_states + _log_mem("create_stacked_layers:after_preallocate") + + # Step 3: Copy first layer's params to slice 0 + @jax.jit + def copy_to_slice(stacked, arr, idx): + return jax.lax.dynamic_update_slice(stacked, arr[None], (idx,) + (0,) * arr.ndim) + + for param_idx in range(num_params): + stacked_flat[param_idx] = copy_to_slice(stacked_flat[param_idx], flat_first[param_idx], 0) + + # Free first layer + del first_layer, first_state, flat_first + _log_mem("create_stacked_layers:after_layer_0") + + # Step 4: Create remaining layers one at a time, copy params, then free + for layer_idx in range(1, num_layers): + layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx])) + _, state = nnx.split(layer) + flat_state, _ = jax.tree_util.tree_flatten(state) + + # Copy each param to the appropriate slice + for param_idx in range(num_params): + stacked_flat[param_idx] = copy_to_slice( + stacked_flat[param_idx], flat_state[param_idx], layer_idx + ) + + # Free this layer immediately + del layer, state, flat_state + + if layer_idx == num_layers - 1 or (layer_idx + 1) % 6 == 0: + _log_mem(f"create_stacked_layers:after_layer_{layer_idx}") + + _log_mem("create_stacked_layers:after_all_layers") - # Reconstruct the state tree with stacked arrays + # Step 5: Reconstruct the state tree with stacked arrays stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) del stacked_flat # Merge back into a module with stacked parameters - return nnx.merge(graphdef, stacked_state) + result = nnx.merge(graphdef, stacked_state) + _log_mem("create_stacked_layers:end") + return result def forward_layers( From bceff5fd5566e49997c0bb0218ea02f1c277803b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 16:00:29 -0800 Subject: [PATCH 111/133] Use KV cache as scan carry for buffer donation Pass KV cache keys/values as part of scan carry instead of xs, enabling JAX buffer donation for effective in-place updates. This reduces peak memory during decode from 10793 MiB to 6697 MiB (38% reduction) by avoiding duplicate cache allocation. Also unifies the body_fn and scan call for prefill/decode paths. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/utils.py | 53 +++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index 8f6752bff..af5f02821 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -202,16 +202,16 @@ def forward_layers( layer_graphdef, layer_state = nnx.split(layers) is_decode = kv_cache is not None - def body_fn(hs, xs): - # Unpack xs: scan automatically slices the leading dimension of layer_state - if is_decode: - layer_params, layer_k, layer_v = xs - layer_kv = (layer_k, layer_v) + def body_fn(carry, layer_params): + hs, cache_keys, cache_values, layer_idx = carry + + # Extract layer's cache slice if available + if cache_keys is not None: + layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) else: - layer_params = xs layer_kv = None - # Merge using the sliced params directly - no manual gather needed + # Forward through layer layer = nnx.merge(layer_graphdef, layer_params) new_hs, (k, v) = layer( hs, @@ -220,37 +220,38 @@ def body_fn(hs, xs): adapter_indices=adapter_indices, kv_cache=layer_kv, ) + hs_output = new_hs if output_hidden_states else None - if is_training: - # Avoid accumulating large KV tensors for training. + # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) + if cache_keys is not None: + cache_keys = cache_keys.at[layer_idx].set(k) + cache_values = cache_values.at[layer_idx].set(v) + k = v = None # Don't accumulate in output - cache is in carry + elif is_training: k = v = None - return new_hs, (hs_output, k, v) + + return (new_hs, cache_keys, cache_values, layer_idx + 1), (hs_output, k, v) if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - # Pass layer_state as xs so scan handles the slicing automatically. - # This avoids capturing layer_state as a closure and manually gathering, - # which causes slow XLA compilation with jax.checkpoint. - xs = (layer_state, kv_cache.keys, kv_cache.values) if is_decode else layer_state + cache_keys = kv_cache.keys if kv_cache else None + cache_values = kv_cache.values if kv_cache else None + init_carry = (hidden_states, cache_keys, cache_values, 0) - final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, xs) + (final_hs, final_keys, final_values, _), (all_hs, all_keys, all_values) = jax.lax.scan( + body_fn, init_carry, layer_state + ) - # [embed, layer0_out, ..., layer(N-2)_out]; final layer output gets normed by caller - all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] - - if is_training: - new_kv_cache = None - elif is_decode: - # Decode mode: scan stacked the per-layer updated caches into (num_layers, ...) + if is_decode: new_kv_cache = KVCache( - keys=all_keys, - values=all_values, + keys=final_keys, + values=final_values, cache_position=kv_cache.cache_position + positions.shape[1], ) else: - # Prefill mode: build cache from collected k,v outputs - new_kv_cache = KVCache.from_layer_outputs(all_keys, all_values, attention_mask) + new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask) + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] return final_hs, all_hidden_states, new_kv_cache From 08ec23aa811977f1202d52e2062beaedaae95e31 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 16:54:05 -0800 Subject: [PATCH 112/133] Simplify create_stacked_layers while preserving memory efficiency - Remove logging/debugging code (_log_mem, subprocess calls) - Use .at[idx].set() instead of dynamic_update_slice (cleaner syntax) - Keep donate_argnums=(0,) for buffer reuse (key to memory efficiency) - Reduce code from ~80 lines to ~40 lines Memory benchmark unchanged at 6697 MiB (vs 10797 MiB with vmap). Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/utils.py | 124 +++++++----------------------------- 1 file changed, 24 insertions(+), 100 deletions(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index af5f02821..fd53a1b15 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -9,60 +9,28 @@ - KVCache must use stacked format: (num_layers, batch, seq, heads, dim) """ -import logging -import subprocess +import functools from typing import Callable from flax import nnx import jax +import jax.numpy as jnp from tx.utils.generator import KVCache -logger = logging.getLogger(__name__) - - -def _log_mem(label: str): - """Log GPU memory usage via nvidia-smi and JAX memory stats.""" - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"], - capture_output=True, - text=True, - timeout=5, - ) - nvidia_mem = max(int(x) for x in result.stdout.strip().split("\n")) - except Exception: - nvidia_mem = -1 - - try: - # Get JAX's view of memory usage - devices = jax.devices() - jax_mems = [] - for d in devices: - stats = d.memory_stats() - if stats: - # bytes_in_use is the actual memory used by JAX arrays - jax_mems.append(stats.get("bytes_in_use", 0) / 1024 / 1024) - jax_mem = max(jax_mems) if jax_mems else -1 - except Exception: - jax_mem = -1 - - logger.info(f"[MEM] {label}: nvidia={nvidia_mem} MiB, jax={jax_mem:.1f} MiB") - def create_stacked_layers( create_layer_fn: Callable[[nnx.Rngs], nnx.Module], num_layers: int, rngs: nnx.Rngs, ) -> nnx.Module: - """Create stacked decoder layers by creating one layer at a time and copying to pre-allocated arrays. + """Create stacked decoder layers by creating layers individually and stacking. This creates a single module object where all parameters have shape (num_layers, ...). This enables efficient scanning over layers without runtime stacking. - Memory optimization: Instead of creating all layers then stacking (which requires 2x memory), - we pre-allocate the stacked arrays and copy each layer's params directly, keeping only - one layer in memory at a time. + Note: We avoid nnx.vmap because it breaks eager sharding, causing ~4x memory overhead. + We also avoid jnp.stack because it creates a temporary full replica before resharding. Args: create_layer_fn: Function that takes rngs and returns a single layer module. @@ -78,92 +46,48 @@ def create_stacked_layers( >>> layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) >>> # layers.self_attn.q_proj.kernel.shape == (num_layers, hidden, head_dim*num_heads) """ - from functools import partial - - import jax.numpy as jnp - import jax.random from jax.sharding import NamedSharding, PartitionSpec - _log_mem("create_stacked_layers:start") - - # Split the RNG key to get unique keys for each layer - base_key = rngs.params() - layer_keys = jax.random.split(base_key, num_layers) - - # Get the current mesh for sharding + layer_keys = jax.random.split(rngs.params(), num_layers) mesh = jax.sharding.get_mesh() - # Step 1: Create first layer to get structure and shapes + # Create first layer to get structure and shapes first_layer = create_layer_fn(nnx.Rngs(layer_keys[0])) graphdef, first_state = nnx.split(first_layer) flat_first, treedef = jax.tree_util.tree_flatten(first_state) - num_params = len(flat_first) - logger.info(f"[MEM] Creating {num_layers} layers with {num_params} params each") - _log_mem("create_stacked_layers:after_first_layer") - - # Step 2: Pre-allocate stacked arrays with proper sharding + # Pre-allocate stacked arrays with correct sharding stacked_flat = [] for arr in flat_first: - # Determine sharding for stacked array + stacked_shape = (num_layers,) + arr.shape original_sharding = arr.sharding if hasattr(original_sharding, "spec"): - original_spec = original_sharding.spec - new_spec = PartitionSpec(None, *original_spec) - new_sharding = NamedSharding(mesh, new_spec) + new_spec = PartitionSpec(None, *original_sharding.spec) + stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) else: - new_sharding = None - - # Pre-allocate with zeros - stacked_shape = (num_layers,) + arr.shape - if new_sharding is not None: - stacked = jax.device_put(jnp.zeros(stacked_shape, dtype=arr.dtype), new_sharding) - else: - stacked = jnp.zeros(stacked_shape, dtype=arr.dtype) + stacked = jnp.zeros(stacked_shape, arr.dtype) stacked_flat.append(stacked) - _log_mem("create_stacked_layers:after_preallocate") - - # Step 3: Copy first layer's params to slice 0 - @jax.jit + # JIT with donate_argnums enables buffer reuse + @functools.partial(jax.jit, donate_argnums=(0,)) def copy_to_slice(stacked, arr, idx): - return jax.lax.dynamic_update_slice(stacked, arr[None], (idx,) + (0,) * arr.ndim) + return stacked.at[idx].set(arr) - for param_idx in range(num_params): - stacked_flat[param_idx] = copy_to_slice(stacked_flat[param_idx], flat_first[param_idx], 0) + # Copy first layer's params to slot 0 + for i, arr in enumerate(flat_first): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat_first[i], 0) - # Free first layer - del first_layer, first_state, flat_first - _log_mem("create_stacked_layers:after_layer_0") - - # Step 4: Create remaining layers one at a time, copy params, then free + # Create remaining layers one at a time and copy params for layer_idx in range(1, num_layers): layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx])) _, state = nnx.split(layer) - flat_state, _ = jax.tree_util.tree_flatten(state) - - # Copy each param to the appropriate slice - for param_idx in range(num_params): - stacked_flat[param_idx] = copy_to_slice( - stacked_flat[param_idx], flat_state[param_idx], layer_idx - ) + flat, _ = jax.tree_util.tree_flatten(state) + for i, arr in enumerate(flat): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) - # Free this layer immediately - del layer, state, flat_state - - if layer_idx == num_layers - 1 or (layer_idx + 1) % 6 == 0: - _log_mem(f"create_stacked_layers:after_layer_{layer_idx}") - - _log_mem("create_stacked_layers:after_all_layers") - - # Step 5: Reconstruct the state tree with stacked arrays + # Reconstruct and merge stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) - del stacked_flat - - # Merge back into a module with stacked parameters - result = nnx.merge(graphdef, stacked_state) - _log_mem("create_stacked_layers:end") - return result + return nnx.merge(graphdef, stacked_state) def forward_layers( From 98d54291caf56d419a73910e033ffc4a40256f30 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 19:16:07 -0800 Subject: [PATCH 113/133] Sync NNX sharding metadata after stacking layers tree_unflatten creates Variables with metadata from the original treedef, which doesn't include the stacked sharding. NNX APIs (get_partition_spec, Optimizer) read from 'sharding_names' metadata rather than array.sharding, so we sync them after unflatten. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/models/utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py index fd53a1b15..49036e9e3 100644 --- a/skyrl-tx/tx/models/utils.py +++ b/skyrl-tx/tx/models/utils.py @@ -85,8 +85,23 @@ def copy_to_slice(stacked, arr, idx): for i, arr in enumerate(flat): stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) - # Reconstruct and merge + # Reconstruct state from stacked arrays. + # tree_unflatten creates new Variables with values from stacked_flat, + # but metadata (including sharding_names) comes from treedef (the original first layer). stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) + + # Sync NNX sharding metadata with actual array sharding. + # The arrays have correct stacked sharding from device_put (line 66), but NNX APIs + # (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata instead. + def update_sharding_metadata(var): + if isinstance(var, nnx.Variable) and hasattr(var.value, "sharding"): + array_sharding = var.value.sharding + if hasattr(array_sharding, "spec"): + var.set_metadata("sharding_names", tuple(array_sharding.spec)) + return var + + jax.tree.map(update_sharding_metadata, stacked_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + return nnx.merge(graphdef, stacked_state) From 8bac19386072c0e862e44037bf3e05106f274bf5 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 4 Feb 2026 20:57:19 -0800 Subject: [PATCH 114/133] Integrate StackedDecoderLayers abstraction with unstack_state approach - Add StackedDecoderLayers class with ArrayRef write-through views - Add unstack_state() for checkpoint loading transformation - Update all models (Llama3, Qwen3, DeepSeekV3) to use StackedDecoderLayers - Simplify load_safetensors and save_safetensors using unstack_state - Update is_stacked_lora_path to detect _stacked in paths - Delete tx/models/utils.py (moved to tx/layers/stacked.py) Passes 35/42 tests. Known issues: - DeepSeekV3 checkpoint loading needs path remapping for split layers - Will refactor to direct access pattern (Option 3) to fix --- skyrl-tx/tx/layers/stacked.py | 243 +++++++++++++++++++++++++++++++ skyrl-tx/tx/models/deepseekv3.py | 14 +- skyrl-tx/tx/models/llama3.py | 8 +- skyrl-tx/tx/models/qwen3.py | 8 +- skyrl-tx/tx/models/utils.py | 196 ------------------------- skyrl-tx/tx/utils/models.py | 162 ++++++--------------- 6 files changed, 301 insertions(+), 330 deletions(-) create mode 100644 skyrl-tx/tx/layers/stacked.py delete mode 100644 skyrl-tx/tx/models/utils.py diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py new file mode 100644 index 000000000..f54ba14a5 --- /dev/null +++ b/skyrl-tx/tx/layers/stacked.py @@ -0,0 +1,243 @@ +"""StackedDecoderLayers module for efficient transformer layer stacking.""" + +import functools +from typing import Callable + +from flax import nnx +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec + +from tx.utils.generator import KVCache + + +class ArrayRef(nnx.Variable): + """A Variable providing a view into an indexed slice of a parent Variable.""" + + def __init__(self, parent: nnx.Variable, idx: int): + super().__init__(parent[idx]) + self.set_metadata("_parent", parent) + self.set_metadata("_idx", idx) + + def __getitem__(self, key): + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + return parent[idx] if key is Ellipsis else parent[idx][key] + + def set_raw_value(self, value, **kwargs): + """Write through to parent when value is set.""" + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + parent[...] = parent[...].at[idx].set(value) + super().set_raw_value(value, **kwargs) + + @property + def shape(self): + return self.get_metadata("_parent")[self.get_metadata("_idx")].shape + + +class StackedDecoderLayers(nnx.Module): + """Decoder layers with stacked weights for efficient scan-based forward pass. + + Parameters are stored in stacked format (num_layers, ...). The forward pass + uses jax.lax.scan for all modes (training/prefill/decode) with KV cache as + scan carry for efficient buffer donation. + + This class encapsulates both layer creation and forward pass logic. + """ + + def __init__( + self, + create_layer_fn: Callable[[nnx.Rngs], nnx.Module], + num_layers: int, + rngs: nnx.Rngs, + ): + """Create stacked decoder layers. + + This creates a single _stacked module where all parameters have shape (num_layers, ...). + Layers are created individually and stacked to avoid nnx.vmap memory overhead. + + Args: + create_layer_fn: Function that takes rngs and returns a single layer module. + num_layers: Number of layers to create. + rngs: Random number generators for initialization. + """ + self.num_layers = num_layers + + layer_keys = jax.random.split(rngs.params(), num_layers) + mesh = jax.sharding.get_mesh() + + # Create first layer to get structure and shapes + first_layer = create_layer_fn(nnx.Rngs(layer_keys[0])) + graphdef, first_state = nnx.split(first_layer) + flat_first, treedef = jax.tree_util.tree_flatten(first_state) + + # Pre-allocate stacked arrays with correct sharding + stacked_flat = [] + for arr in flat_first: + stacked_shape = (num_layers,) + arr.shape + original_sharding = arr.sharding + if hasattr(original_sharding, "spec"): + new_spec = PartitionSpec(None, *original_sharding.spec) + stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) + else: + stacked = jnp.zeros(stacked_shape, arr.dtype) + stacked_flat.append(stacked) + + # JIT with donate_argnums enables buffer reuse + @functools.partial(jax.jit, donate_argnums=(0,)) + def copy_to_slice(stacked, arr, idx): + return stacked.at[idx].set(arr) + + # Copy first layer's params to slot 0 + for i, arr in enumerate(flat_first): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat_first[i], 0) + + # Create remaining layers one at a time and copy params + for layer_idx in range(1, num_layers): + layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx])) + _, state = nnx.split(layer) + flat, _ = jax.tree_util.tree_flatten(state) + for i, arr in enumerate(flat): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) + + # Reconstruct and merge + stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) + self._stacked = nnx.merge(graphdef, stacked_state) + + def __len__(self) -> int: + """Return the number of layers.""" + return self.num_layers + + def __getitem__(self, index: int) -> nnx.Module: + """Get view into layer at index (stays synced with stacked state).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + graphdef, state = nnx.split(self._stacked) + layer_state = jax.tree.map( + lambda x: ArrayRef(x, index), + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + return nnx.merge(graphdef, layer_state) + + def __iter__(self): + """Iterate over individual layers (for testing/weight loading).""" + for i in range(self.num_layers): + yield self[i] + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + gradient_checkpointing: bool, + is_training: bool = False, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward pass through all layers using scan. + + Uses jax.lax.scan for all modes (training/prefill/decode). For decode mode, + the KV cache is passed as scan carry for efficient buffer donation. + + Args: + hidden_states: Input hidden states of shape (batch, seq, hidden). + attention_mask: Attention mask of shape (batch, seq). + positions: Position indices of shape (batch, seq). + adapter_indices: Optional LoRA adapter indices of shape (batch,). + kv_cache: Optional KV cache for decode mode (None for prefill). + output_hidden_states: Whether to return intermediate hidden states. + gradient_checkpointing: Whether to use gradient checkpointing. + is_training: Whether in training mode. Skips KV cache to save memory. + + Returns: + Tuple of (final_hidden_states, all_hidden_states, kv_cache). + kv_cache is None when is_training=True. + """ + assert self.num_layers > 0, "num_layers must be positive" + + graphdef, state = nnx.split(self._stacked) + is_decode = kv_cache is not None + + def body_fn(carry, layer_params): + hs, cache_keys, cache_values, layer_idx = carry + + # Extract layer's cache slice if available + if cache_keys is not None: + layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) + else: + layer_kv = None + + # Forward through layer + layer = nnx.merge(graphdef, layer_params) + new_hs, (k, v) = layer( + hs, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + + hs_output = new_hs if output_hidden_states else None + + # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) + if cache_keys is not None: + cache_keys = cache_keys.at[layer_idx].set(k) + cache_values = cache_values.at[layer_idx].set(v) + k = v = None # Don't accumulate in output - cache is in carry + elif is_training: + k = v = None + + return (new_hs, cache_keys, cache_values, layer_idx + 1), (hs_output, k, v) + + if gradient_checkpointing: + body_fn = jax.checkpoint(body_fn) + + cache_keys = kv_cache.keys if kv_cache else None + cache_values = kv_cache.values if kv_cache else None + init_carry = (hidden_states, cache_keys, cache_values, 0) + + (final_hs, final_keys, final_values, _), (all_hs, all_keys, all_values) = jax.lax.scan( + body_fn, init_carry, state + ) + + if is_decode: + new_kv_cache = KVCache( + keys=final_keys, + values=final_values, + cache_position=kv_cache.cache_position + positions.shape[1], + ) + else: + new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask) + + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + return final_hs, all_hidden_states, new_kv_cache + + +def unstack_state(module: nnx.Module) -> nnx.GraphState: + """Transform stacked layer state to unstacked ArrayRef views. + + Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. + Each entry is an ArrayRef that writes through to the original stacked variable. + + This is useful for checkpoint loading where weights are stored per-layer. + + Args: + module: Module containing StackedDecoderLayers. + + Returns: + GraphState with unstacked paths and ArrayRef views. + """ + expanded = [] + for path, var in nnx.to_flat_state(nnx.state(module)): + if "_stacked" not in path: + expanded.append((path, var)) + continue + + idx = path.index("_stacked") + for i in range(var[...].shape[0]): + new_path = path[:idx] + (str(i),) + path[idx + 1 :] + expanded.append((new_path, ArrayRef(var, i))) + + return nnx.from_flat_state(expanded) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index c64a446f7..8d01855f2 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -7,9 +7,9 @@ from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep from tx.layers.layernorm import RMSNorm +from tx.layers.stacked import StackedDecoderLayers from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput -from tx.models.utils import create_stacked_layers, forward_layers from tx.utils.generator import GeneratorMixin, KVCache from tx.utils.logits_processor import LogitsProcessorMixin, LMHead @@ -489,7 +489,7 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) - self.dense_layers = create_stacked_layers(create_dense_layer, self.num_dense_layers, rngs) + self.dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) else: self.dense_layers = None @@ -499,7 +499,7 @@ def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) - self.moe_layers = create_stacked_layers(create_moe_layer, self.num_moe_layers, rngs) + self.moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) else: self.moe_layers = None @@ -532,10 +532,8 @@ def __call__( # Forward through dense layers dense_kv_result = None if self.dense_layers is not None: - hidden_states, dense_hidden_states, dense_kv_result = forward_layers( - self.dense_layers, + hidden_states, dense_hidden_states, dense_kv_result = self.dense_layers( hidden_states, - self.num_dense_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, @@ -549,10 +547,8 @@ def __call__( # Forward through MoE layers moe_kv_result = None if self.moe_layers is not None: - hidden_states, moe_hidden_states, moe_kv_result = forward_layers( - self.moe_layers, + hidden_states, moe_hidden_states, moe_kv_result = self.moe_layers( hidden_states, - self.num_moe_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index be38e15a9..8ff6c85ff 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -8,7 +8,7 @@ from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm -from tx.models.utils import create_stacked_layers, forward_layers +from tx.layers.stacked import StackedDecoderLayers from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -217,7 +217,7 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer: return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) - self.layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) + self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -237,10 +237,8 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - hidden_states, all_hidden_states, new_kv_cache = forward_layers( - self.layers, + hidden_states, all_hidden_states, new_kv_cache = self.layers( hidden_states, - self.num_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 303fb3137..a067e8245 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -8,9 +8,9 @@ from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm +from tx.layers.stacked import StackedDecoderLayers from tx.models.configs import Qwen3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput -from tx.models.utils import create_stacked_layers, forward_layers from tx.utils.generator import GeneratorMixin, KVCache from tx.utils.logits_processor import LogitsProcessorMixin, LMHead @@ -335,7 +335,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer: return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) - self.layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) + self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -355,10 +355,8 @@ def __call__( hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - hidden_states, all_hidden_states, new_kv_cache = forward_layers( - self.layers, + hidden_states, all_hidden_states, new_kv_cache = self.layers( hidden_states, - self.num_layers, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, diff --git a/skyrl-tx/tx/models/utils.py b/skyrl-tx/tx/models/utils.py deleted file mode 100644 index 49036e9e3..000000000 --- a/skyrl-tx/tx/models/utils.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Utility functions for model forward passes with stacked decoder layers. - -This module provides: -- create_stacked_layers: Create decoder layers with stacked weights -- forward_layers: Unified forward pass using scan (skips KV cache during training) - -Prerequisites: -- Layers must be created with create_stacked_layers (stacked weights) -- KVCache must use stacked format: (num_layers, batch, seq, heads, dim) -""" - -import functools -from typing import Callable - -from flax import nnx -import jax -import jax.numpy as jnp - -from tx.utils.generator import KVCache - - -def create_stacked_layers( - create_layer_fn: Callable[[nnx.Rngs], nnx.Module], - num_layers: int, - rngs: nnx.Rngs, -) -> nnx.Module: - """Create stacked decoder layers by creating layers individually and stacking. - - This creates a single module object where all parameters have shape (num_layers, ...). - This enables efficient scanning over layers without runtime stacking. - - Note: We avoid nnx.vmap because it breaks eager sharding, causing ~4x memory overhead. - We also avoid jnp.stack because it creates a temporary full replica before resharding. - - Args: - create_layer_fn: Function that takes rngs and returns a single layer module. - num_layers: Number of layers to create. - rngs: Random number generators for initialization. - - Returns: - A single module with stacked parameters. - - Example: - >>> def create_layer(rngs): - ... return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) - >>> layers = create_stacked_layers(create_layer, config.num_hidden_layers, rngs) - >>> # layers.self_attn.q_proj.kernel.shape == (num_layers, hidden, head_dim*num_heads) - """ - from jax.sharding import NamedSharding, PartitionSpec - - layer_keys = jax.random.split(rngs.params(), num_layers) - mesh = jax.sharding.get_mesh() - - # Create first layer to get structure and shapes - first_layer = create_layer_fn(nnx.Rngs(layer_keys[0])) - graphdef, first_state = nnx.split(first_layer) - flat_first, treedef = jax.tree_util.tree_flatten(first_state) - - # Pre-allocate stacked arrays with correct sharding - stacked_flat = [] - for arr in flat_first: - stacked_shape = (num_layers,) + arr.shape - original_sharding = arr.sharding - if hasattr(original_sharding, "spec"): - new_spec = PartitionSpec(None, *original_sharding.spec) - stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) - else: - stacked = jnp.zeros(stacked_shape, arr.dtype) - stacked_flat.append(stacked) - - # JIT with donate_argnums enables buffer reuse - @functools.partial(jax.jit, donate_argnums=(0,)) - def copy_to_slice(stacked, arr, idx): - return stacked.at[idx].set(arr) - - # Copy first layer's params to slot 0 - for i, arr in enumerate(flat_first): - stacked_flat[i] = copy_to_slice(stacked_flat[i], flat_first[i], 0) - - # Create remaining layers one at a time and copy params - for layer_idx in range(1, num_layers): - layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx])) - _, state = nnx.split(layer) - flat, _ = jax.tree_util.tree_flatten(state) - for i, arr in enumerate(flat): - stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) - - # Reconstruct state from stacked arrays. - # tree_unflatten creates new Variables with values from stacked_flat, - # but metadata (including sharding_names) comes from treedef (the original first layer). - stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) - - # Sync NNX sharding metadata with actual array sharding. - # The arrays have correct stacked sharding from device_put (line 66), but NNX APIs - # (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata instead. - def update_sharding_metadata(var): - if isinstance(var, nnx.Variable) and hasattr(var.value, "sharding"): - array_sharding = var.value.sharding - if hasattr(array_sharding, "spec"): - var.set_metadata("sharding_names", tuple(array_sharding.spec)) - return var - - jax.tree.map(update_sharding_metadata, stacked_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) - - return nnx.merge(graphdef, stacked_state) - - -def forward_layers( - layers: nnx.Module, - hidden_states: jax.Array, - num_layers: int, - *, - attention_mask: jax.Array, - positions: jax.Array, - adapter_indices: jax.Array | None, - kv_cache: KVCache | None, - output_hidden_states: bool, - gradient_checkpointing: bool, - is_training: bool = False, -) -> tuple[jax.Array, list[jax.Array], KVCache | None]: - """Unified forward pass through stacked decoder layers using scan. - - Args: - layers: Stacked decoder layers (created with create_stacked_layers/nnx.vmap). - hidden_states: Input hidden states of shape (batch, seq, hidden). - num_layers: Number of decoder layers. - attention_mask: Attention mask of shape (batch, seq). - positions: Position indices of shape (batch, seq). - adapter_indices: Optional LoRA adapter indices of shape (batch,). - kv_cache: Optional KV cache for decode mode (None for prefill). - output_hidden_states: Whether to return intermediate hidden states. - gradient_checkpointing: Whether to use gradient checkpointing (training only). - is_training: Whether in training mode. Skips KV cache to save memory. - - Returns: - Tuple of (final_hidden_states, all_hidden_states, kv_cache). - kv_cache is None when is_training=True. - """ - assert num_layers > 0, "num_layers must be positive" - - layer_graphdef, layer_state = nnx.split(layers) - is_decode = kv_cache is not None - - def body_fn(carry, layer_params): - hs, cache_keys, cache_values, layer_idx = carry - - # Extract layer's cache slice if available - if cache_keys is not None: - layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) - else: - layer_kv = None - - # Forward through layer - layer = nnx.merge(layer_graphdef, layer_params) - new_hs, (k, v) = layer( - hs, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=layer_kv, - ) - - hs_output = new_hs if output_hidden_states else None - - # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) - if cache_keys is not None: - cache_keys = cache_keys.at[layer_idx].set(k) - cache_values = cache_values.at[layer_idx].set(v) - k = v = None # Don't accumulate in output - cache is in carry - elif is_training: - k = v = None - - return (new_hs, cache_keys, cache_values, layer_idx + 1), (hs_output, k, v) - - if gradient_checkpointing: - body_fn = jax.checkpoint(body_fn) - - cache_keys = kv_cache.keys if kv_cache else None - cache_values = kv_cache.values if kv_cache else None - init_carry = (hidden_states, cache_keys, cache_values, 0) - - (final_hs, final_keys, final_values, _), (all_hs, all_keys, all_values) = jax.lax.scan( - body_fn, init_carry, layer_state - ) - - if is_decode: - new_kv_cache = KVCache( - keys=final_keys, - values=final_values, - cache_position=kv_cache.cache_position + positions.shape[1], - ) - else: - new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask) - - all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] - return final_hs, all_hidden_states, new_kv_cache diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index c976fe0b8..8739b15d6 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -88,10 +88,10 @@ def is_stacked_lora_path(path: tuple) -> bool: path: Parameter path tuple (can be nnx path objects or strings). Returns: - True if the path contains 'layers', 'dense_layers', or 'moe_layers'. + True if the path contains '_stacked' (from StackedDecoderLayers). """ path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] - return any(name in path_strs for name in ("layers", "dense_layers", "moe_layers")) + return "_stacked" in path_strs def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: @@ -105,76 +105,19 @@ def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: return (adapter_index,) -def _get_layer_group_info(path: tuple, config: ModelConfig) -> tuple[str, int]: - """Get layer group name and starting layer index for a stacked param path. - - Returns: - Tuple of (layer_name_for_hf_key, layer_offset) where: - - layer_name_for_hf_key is 'layers' (HF always uses 'layers') - - layer_offset is the starting layer index for this group - """ - path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] - if "dense_layers" in path_strs: - return "layers", 0 - elif "moe_layers" in path_strs: - return "layers", getattr(config, "first_k_dense_replace", 0) - else: - return "layers", 0 - - -def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: - """Convert param path to HuggingFace key. If layer_idx provided, insert it after 'layers'. - - Handles split stacked layer names (dense_layers, moe_layers) by converting them to 'layers'. - """ - parts = [] - for p in path: - key = p.key if hasattr(p, "key") else str(p) - # Handle split stacked layer names - convert to 'layers' with index - if key in ("layers", "dense_layers", "moe_layers") and layer_idx is not None: - parts.append(f"layers.{layer_idx}") - elif key in ("kernel", "embedding"): - parts.append("weight") - elif key in ("lora_A", "lora_B"): - parts.extend([key, "weight"]) - else: - parts.append(key) - return ".".join(parts) - - -def _load_hf_tensor(tensors: dict, key: str, target_shape: tuple, num_experts: int | None) -> np.ndarray: - """Load tensor from HF format, handling experts, transpose, and reshape.""" - # Handle MoE expert weights (HF stores each expert separately) - if ".experts." in key and num_experts: - tensor = np.stack([tensors[key.replace(".experts.", f".experts.{i}.")].T for i in range(num_experts)], axis=0) - else: - tensor = tensors[key] - if "embed_tokens" not in key: - tensor = tensor.T - - # Reshape attention projections to match model's grouped head format - if any(p in key for p in ("q_proj", "k_proj", "v_proj", "o_proj")): - tensor = tensor.reshape(target_shape) - - return tensor +def get_param_key(path: tuple, prefix: str = "") -> str: + "Get the safetensors key for a given model path." + if path[-1] in {"embedding", "kernel"}: + path = (*path[:-1], "weight") + elif path[-1] in {"lora_A", "lora_B"}: + path = (*path, "weight") + return prefix + ".".join(map(str, path)) -def _save_hf_tensor(tensors: dict, key: str, param: np.ndarray, num_experts: int | None) -> None: - """Save tensor to HF format, handling experts, transpose, and reshape.""" - # Handle MoE expert weights - if ".experts." in key and num_experts: - for i in range(num_experts): - tensors[key.replace(".experts.", f".experts.{i}.")] = param[i].T - return - - # Reshape attention projections back to 2D - if any(p in key for p in ("q_proj", "k_proj", "v_proj")): - param = param.reshape(param.shape[0], -1) - elif "o_proj" in key: - param = param.reshape(-1, param.shape[-1]) - - # Transpose to HF format - tensors[key] = param if "embed_tokens" in key else param.T +def get_expert_key(path: tuple, expert_idx: int) -> str: + "Get the safetensors key for an expert weight model path." + path = tuple(s if s != "experts" else f"experts.{expert_idx}" for s in path) + return ".".join(map(str, path)) def load_safetensors( @@ -186,41 +129,33 @@ def load_safetensors( filter_fn: Callable[[tuple], bool] | None = None, ) -> None: """Load safetensors weights into a model with stacked layers.""" + from tx.layers.stacked import unstack_state + tensors = {} for file in Path(checkpoint_dir).glob("*.safetensors"): tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} - num_experts = config.get_num_experts() - model_params = nnx.to_flat_state(nnx.state(model)) - updates = [] - - for path, param in model_params: + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) with ArrayRef write-through, matching checkpoint key format + for path, param in nnx.to_flat_state(unstack_state(model)): if filter_fn is not None and not filter_fn(path): continue - - path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] - if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): + key = get_param_key(path) + # Skip LoRA parameters if requested + if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue - - if is_stacked_lora_path(path): - # Stack per-layer weights from HF format - # Infer layer count from param shape and get offset for split stacked layers - stacked_layer_count = param.shape[0] - _, layer_offset = _get_layer_group_info(path, config) - stacked_tensor = np.empty(param.shape, dtype=param.dtype) - for i in range(stacked_layer_count): - key = _path_to_hf_key(path, layer_idx=layer_offset + i) - stacked_tensor[i] = _load_hf_tensor(tensors, key, param.shape[1:], num_experts) + if "experts" in path: + tensor = np.stack( + [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 + ) else: - # Non-stacked layers or non-layer params - key = _path_to_hf_key(path) - stacked_tensor = _load_hf_tensor(tensors, key, param.shape, num_experts) - - assert param.shape == stacked_tensor.shape, f"Shape mismatch for {path}" - updates.append((path, jax.device_put(stacked_tensor.astype(param.dtype), param.sharding))) - - nnx.update(model, nnx.from_flat_state(updates)) + tensor = tensors[key] if "embed_tokens" in key else tensors[key].T + if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: + tensor = tensor.reshape(param.shape) + assert param.shape == tensor.shape, f"shape mismatch for {key}" + # ArrayRef.set_raw_value writes through to the stacked parent variable + param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) def save_safetensors( @@ -231,29 +166,26 @@ def save_safetensors( filter_fn: Callable[[tuple], bool] | None = None, ) -> None: """Save model weights to safetensors, unstacking layer weights for HF compatibility.""" - num_experts = config.get_num_experts() - model_params = nnx.to_flat_state(nnx.state(model)) - tensors = {} + from tx.layers.stacked import unstack_state - for path, param in model_params: - path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] - if "rngs" in path_keys: + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) matching the checkpoint key format used by HuggingFace + tensors = {} + for path, param in nnx.to_flat_state(unstack_state(model)): + if "rngs" in path: continue if filter_fn is not None and not filter_fn(path): continue - - if is_stacked_lora_path(path): - # Unstack and save as individual layer weights - # Infer layer count from param shape and get offset for split stacked layers - stacked_layer_count = param.shape[0] - _, layer_offset = _get_layer_group_info(path, config) - for i in range(stacked_layer_count): - key = prefix + _path_to_hf_key(path, layer_idx=layer_offset + i) - _save_hf_tensor(tensors, key, param[i], num_experts) - else: - # Non-stacked layers or non-layer params - key = prefix + _path_to_hf_key(path) - _save_hf_tensor(tensors, key, param, num_experts) + key = get_param_key(path, prefix=prefix) + if "experts" in path: + for i in range(config.get_num_experts()): + tensors[get_expert_key(path, i)] = param[i, :, :].T + continue + if "q_proj" in path or "k_proj" in path or "v_proj" in path: + param = param.reshape(param.shape[0], -1) + elif "o_proj" in path: + param = param.reshape(-1, param.shape[-1]) + tensors[key] = param if "embed_tokens" in path else param.T # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: From 6b63f49abd83afd5f8122c591bb237ecb48da237 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 12:57:08 -0800 Subject: [PATCH 115/133] Improve unstack_state to support hybrid layer architectures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enhanced unstack_state() to handle models with multiple StackedDecoderLayers (e.g., DeepSeekV3's dense + MoE layers) by using model-provided ordering. Key changes: - Models can optionally provide get_stacked_layers_list() to specify layer ordering - unstack_state() assigns sequential checkpoint indices across all stacks - DeepSeekV3: dense_layers[0] → layers.0, moe_layers[0] → layers.1, etc. - Llama3/Qwen3: fallback to simple per-stack numbering (no method needed) - Added ArrayRef.__setitem__ for write-through support - Fixed test_qwen3_lora to access _stacked for LoRA parameters Results: 40/42 tests passing (95.2%) - 2 pre-existing failures: Qwen3 MoE numerical mismatch Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tests/models/test_qwen3.py | 3 +- skyrl-tx/tx/layers/stacked.py | 60 ++++++++++++++++++++++++----- skyrl-tx/tx/models/deepseekv3.py | 13 +++++++ 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index dcf2680b9..cf2316e2c 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -268,6 +268,7 @@ def test_qwen3_lora(): ) # Load layer LoRA weights (stacked format) + # Access _stacked to get the stacked module with LoRA parameters for i in range(config.num_hidden_layers): hf_layer = hf_model.base_model.model.model.layers[i] for module_name, projections in [ @@ -276,7 +277,7 @@ def test_qwen3_lora(): ]: for proj_name in projections: hf_proj = getattr(getattr(hf_layer, module_name), proj_name) - jax_proj = getattr(getattr(model.model.layers, module_name), proj_name) + jax_proj = getattr(getattr(model.model.layers._stacked, module_name), proj_name) load_stacked_lora_weights( jax_proj, layer_idx=i, diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index f54ba14a5..b4cbdfce2 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -23,6 +23,18 @@ def __getitem__(self, key): parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") return parent[idx] if key is Ellipsis else parent[idx][key] + def __setitem__(self, key, value): + """Write through to parent when value is set via indexing.""" + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + if key is Ellipsis: + # param[...] = value -> update entire slice + parent[...] = parent[...].at[idx].set(value) + else: + # param[key] = value -> update sub-slice + parent[...] = parent[...].at[idx][key].set(value) + # Also update our local value + super().__setitem__(key, value) + def set_raw_value(self, value, **kwargs): """Write through to parent when value is set.""" parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") @@ -218,10 +230,13 @@ def body_fn(carry, layer_params): def unstack_state(module: nnx.Module) -> nnx.GraphState: """Transform stacked layer state to unstacked ArrayRef views. - Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. - Each entry is an ArrayRef that writes through to the original stacked variable. + Converts paths like `dense_layers._stacked.xxx` or `layers._stacked.xxx` to + `layers.0.xxx`, `layers.1.xxx`, etc. Each entry is an ArrayRef that writes + through to the original stacked variable. - This is useful for checkpoint loading where weights are stored per-layer. + For models with multiple StackedDecoderLayers (e.g., DeepSeek with dense + MoE), + the model can provide get_stacked_layers_list() to specify ordering. Otherwise, + falls back to simple per-stack numbering. Args: module: Module containing StackedDecoderLayers. @@ -229,15 +244,42 @@ def unstack_state(module: nnx.Module) -> nnx.GraphState: Returns: GraphState with unstacked paths and ArrayRef views. """ + # Build mapping: StackedDecoderLayers object id → starting checkpoint index + checkpoint_mapping = {} + + if hasattr(module, "model") and hasattr(module.model, "get_stacked_layers_list"): + # Model provides explicit ordering - use sequential checkpoint indices + counter = 0 + for stacked_layers in module.model.get_stacked_layers_list(): + checkpoint_mapping[id(stacked_layers)] = counter + counter += len(stacked_layers) + expanded = [] - for path, var in nnx.to_flat_state(nnx.state(module)): + for path, param in nnx.to_flat_state(nnx.state(module)): if "_stacked" not in path: - expanded.append((path, var)) + expanded.append((path, param)) continue - idx = path.index("_stacked") - for i in range(var[...].shape[0]): - new_path = path[:idx] + (str(i),) + path[idx + 1 :] - expanded.append((new_path, ArrayRef(var, i))) + stacked_idx = path.index("_stacked") + + # Find the StackedDecoderLayers object this parameter belongs to + stacked_layers = module + for key in path[:stacked_idx]: + stacked_layers = getattr(stacked_layers, key) + + if id(stacked_layers) in checkpoint_mapping: + # Use checkpoint mapping - replace attribute name with "layers" + start_idx = checkpoint_mapping[id(stacked_layers)] + # Path: ("model", "dense_layers", "_stacked", ...) → ("model", "layers", "0", ...) + base_path = path[:stacked_idx-1] + ("layers",) + for layer_idx in range(stacked_layers.num_layers): + checkpoint_idx = start_idx + layer_idx + new_path = base_path + (str(checkpoint_idx),) + path[stacked_idx+1:] + expanded.append((new_path, ArrayRef(param, layer_idx))) + else: + # Fallback: simple numbering within the same attribute + for layer_idx in range(param[...].shape[0]): + new_path = path[:stacked_idx] + (str(layer_idx),) + path[stacked_idx+1:] + expanded.append((new_path, ArrayRef(param, layer_idx))) return nnx.from_flat_state(expanded) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 8d01855f2..693c8eb4c 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -505,6 +505,19 @@ def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + def get_stacked_layers_list(self): + """Return ordered list of StackedDecoderLayers for checkpoint loading. + + Returns dense layers first (checkpoint indices 0 to first_k-1), + then MoE layers (checkpoint indices first_k to num_layers-1). + """ + result = [] + if self.dense_layers is not None: + result.append(self.dense_layers) + if self.moe_layers is not None: + result.append(self.moe_layers) + return result + def __call__( self, input_ids: jax.Array, From ae2c8cb57c5150f506065b05c20897ba81f6494f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 13:05:12 -0800 Subject: [PATCH 116/133] minor updates --- skyrl-tx/tx/layers/stacked.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index b4cbdfce2..3a36f85be 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -230,9 +230,11 @@ def body_fn(carry, layer_params): def unstack_state(module: nnx.Module) -> nnx.GraphState: """Transform stacked layer state to unstacked ArrayRef views. - Converts paths like `dense_layers._stacked.xxx` or `layers._stacked.xxx` to - `layers.0.xxx`, `layers.1.xxx`, etc. Each entry is an ArrayRef that writes - through to the original stacked variable. + Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. + Each entry is an ArrayRef that writes through to the original stacked variable. + + This is useful for checkpoint loading where weights are stored per-layer. + For models with multiple StackedDecoderLayers (e.g., DeepSeek with dense + MoE), the model can provide get_stacked_layers_list() to specify ordering. Otherwise, @@ -252,7 +254,7 @@ def unstack_state(module: nnx.Module) -> nnx.GraphState: counter = 0 for stacked_layers in module.model.get_stacked_layers_list(): checkpoint_mapping[id(stacked_layers)] = counter - counter += len(stacked_layers) + counter += stacked_layers.num_layers expanded = [] for path, param in nnx.to_flat_state(nnx.state(module)): @@ -266,6 +268,7 @@ def unstack_state(module: nnx.Module) -> nnx.GraphState: stacked_layers = module for key in path[:stacked_idx]: stacked_layers = getattr(stacked_layers, key) + assert isinstance(stacked_layers, StackedDecoderLayers) if id(stacked_layers) in checkpoint_mapping: # Use checkpoint mapping - replace attribute name with "layers" From 426ad87b22fcfafd2bb11966d7ecc550ec091361 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 15:54:03 -0800 Subject: [PATCH 117/133] support 0 layers --- skyrl-tx/tx/layers/stacked.py | 11 ++++- skyrl-tx/tx/models/deepseekv3.py | 75 ++++++++++++-------------------- 2 files changed, 38 insertions(+), 48 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 3a36f85be..ec93583bb 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -69,11 +69,16 @@ def __init__( Args: create_layer_fn: Function that takes rngs and returns a single layer module. - num_layers: Number of layers to create. + num_layers: Number of layers to create. Can be 0 for empty layer stack. rngs: Random number generators for initialization. """ self.num_layers = num_layers + # Handle empty layer case + if num_layers == 0: + self._stacked = None + return + layer_keys = jax.random.split(rngs.params(), num_layers) mesh = jax.sharding.get_mesh() @@ -167,7 +172,9 @@ def __call__( Tuple of (final_hidden_states, all_hidden_states, kv_cache). kv_cache is None when is_training=True. """ - assert self.num_layers > 0, "num_layers must be positive" + # Handle empty layer case - pass through inputs unchanged + if self.num_layers == 0: + return hidden_states, [], kv_cache graphdef, state = nnx.split(self._stacked) is_decode = kv_cache is not None diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 693c8eb4c..d0692cfeb 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -484,24 +484,16 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs ) # Create stacked dense layers (layers 0 to first_k_dense_replace - 1) - if self.num_dense_layers > 0: + def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) - def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: - return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) - - self.dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) - else: - self.dense_layers = None + self.dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) # Create stacked MoE layers (layers first_k_dense_replace to num_hidden_layers - 1) - if self.num_moe_layers > 0: + def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) - def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: - return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) - - self.moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) - else: - self.moe_layers = None + self.moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) @@ -511,12 +503,7 @@ def get_stacked_layers_list(self): Returns dense layers first (checkpoint indices 0 to first_k-1), then MoE layers (checkpoint indices first_k to num_layers-1). """ - result = [] - if self.dense_layers is not None: - result.append(self.dense_layers) - if self.moe_layers is not None: - result.append(self.moe_layers) - return result + return [self.dense_layers, self.moe_layers] def __call__( self, @@ -543,34 +530,30 @@ def __call__( dense_kv_cache, moe_kv_cache = kv_cache.split(self.num_dense_layers) # Forward through dense layers - dense_kv_result = None - if self.dense_layers is not None: - hidden_states, dense_hidden_states, dense_kv_result = self.dense_layers( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=dense_kv_cache, - output_hidden_states=output_hidden_states, - gradient_checkpointing=self.config.gradient_checkpointing, - is_training=is_training, - ) - all_hidden_states.extend(dense_hidden_states) + hidden_states, dense_hidden_states, dense_kv_result = self.dense_layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=dense_kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) + all_hidden_states.extend(dense_hidden_states) # Forward through MoE layers - moe_kv_result = None - if self.moe_layers is not None: - hidden_states, moe_hidden_states, moe_kv_result = self.moe_layers( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=moe_kv_cache, - output_hidden_states=output_hidden_states, - gradient_checkpointing=self.config.gradient_checkpointing, - is_training=is_training, - ) - all_hidden_states.extend(moe_hidden_states) + hidden_states, moe_hidden_states, moe_kv_result = self.moe_layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=moe_kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) + all_hidden_states.extend(moe_hidden_states) hidden_states = self.norm(hidden_states) if output_hidden_states: From 2932c4b5562eba59c518da652e9871716757a23c Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 17:17:58 -0800 Subject: [PATCH 118/133] Add MultiStackedDecoderLayers for heterogeneous layer architectures Introduces MultiStackedDecoderLayers to cleanly handle models like DeepSeek that combine different layer types (dense MLP + MoE). Key improvements: - Eliminates conditional checks in DeepSeekV3 model code - Adds KV cache offset parameter to avoid split/concatenate in decode mode - Delegates unstack_state logic to layer objects for cleaner separation - Supports multiple KVCache.concatenate arguments for prefill mode This simplifies both the model code and checkpoint loading logic. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/layers/stacked.py | 234 +++++++++++++++++++++++++------ skyrl-tx/tx/models/deepseekv3.py | 51 ++----- skyrl-tx/tx/utils/generator.py | 27 ++-- 3 files changed, 218 insertions(+), 94 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index ec93583bb..53d0de2f9 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -141,6 +141,36 @@ def __iter__(self): for i in range(self.num_layers): yield self[i] + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + """Transform _stacked paths to per-layer paths with ArrayRef. + + Args: + state: GraphState containing this module's state. + base_path: Path prefix to this module (e.g., ("model", "layers")). + + Returns: + List of (path, ArrayRef) tuples for unstacked parameters. + """ + result = [] + for path, param in nnx.to_flat_state(state): + # Only process paths belonging to this module + if not path[:len(base_path)] == base_path: + continue + # Only process _stacked paths + if "_stacked" not in path[len(base_path):]: + continue + + # Find _stacked in the relative path + rel_path = path[len(base_path):] + stacked_idx = rel_path.index("_stacked") + + # Create per-layer paths: base_path + (layer_idx,) + rest + for layer_idx in range(self.num_layers): + new_path = base_path + (str(layer_idx),) + rel_path[stacked_idx+1:] + result.append((new_path, ArrayRef(param, layer_idx))) + + return result + def __call__( self, hidden_states: jax.Array, @@ -149,6 +179,7 @@ def __call__( positions: jax.Array, adapter_indices: jax.Array | None, kv_cache: KVCache | None, + kv_cache_offset: int = 0, output_hidden_states: bool, gradient_checkpointing: bool, is_training: bool = False, @@ -164,6 +195,8 @@ def __call__( positions: Position indices of shape (batch, seq). adapter_indices: Optional LoRA adapter indices of shape (batch,). kv_cache: Optional KV cache for decode mode (None for prefill). + kv_cache_offset: Layer offset into the KV cache. Used when multiple + StackedDecoderLayers share the same cache. output_hidden_states: Whether to return intermediate hidden states. gradient_checkpointing: Whether to use gradient checkpointing. is_training: Whether in training mode. Skips KV cache to save memory. @@ -183,8 +216,9 @@ def body_fn(carry, layer_params): hs, cache_keys, cache_values, layer_idx = carry # Extract layer's cache slice if available + cache_idx = kv_cache_offset + layer_idx if cache_keys is not None: - layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) + layer_kv = (cache_keys[cache_idx], cache_values[cache_idx]) else: layer_kv = None @@ -202,8 +236,8 @@ def body_fn(carry, layer_params): # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) if cache_keys is not None: - cache_keys = cache_keys.at[layer_idx].set(k) - cache_values = cache_values.at[layer_idx].set(v) + cache_keys = cache_keys.at[cache_idx].set(k) + cache_values = cache_values.at[cache_idx].set(v) k = v = None # Don't accumulate in output - cache is in carry elif is_training: k = v = None @@ -234,6 +268,152 @@ def body_fn(carry, layer_params): return final_hs, all_hidden_states, new_kv_cache +class MultiStackedDecoderLayers(nnx.Module): + """Multiple StackedDecoderLayers groups with unified interface. + + This allows models like DeepSeek to have different layer types (dense/MoE) + while presenting a unified interface for forward passes and checkpointing. + """ + + def __init__(self, *layer_groups: StackedDecoderLayers): + """Create multi-stacked decoder layers. + + Args: + *layer_groups: One or more StackedDecoderLayers to combine. + """ + self.layer_groups = nnx.List(layer_groups) + self.num_layers = sum(group.num_layers for group in self.layer_groups) + + def __len__(self) -> int: + """Return the total number of layers across all groups.""" + return self.num_layers + + def __getitem__(self, index: int) -> nnx.Module: + """Get view into layer at global index (across all groups).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + + # Find which group contains this index + offset = 0 + for group in self.layer_groups: + if index < offset + group.num_layers: + return group[index - offset] + offset += group.num_layers + + raise IndexError(f"Layer index {index} not found") + + def __iter__(self): + """Iterate over all layers across all groups.""" + for group in self.layer_groups: + yield from group + + def get_stacked_layers_list(self) -> list[StackedDecoderLayers]: + """Return list of StackedDecoderLayers for checkpoint loading.""" + return list(self.layer_groups) + + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + """Transform _stacked paths from all groups to unified per-layer paths. + + Args: + state: GraphState containing this module's state. + base_path: Path prefix to this module (e.g., ("model", "layers")). + + Returns: + List of (path, ArrayRef) tuples for unstacked parameters. + """ + result = [] + checkpoint_idx = 0 + + for i, group in enumerate(self.layer_groups): + # Path to this group: base_path + ("layer_groups", i) + group_path = base_path + ("layer_groups", i) + + # Get unstacked paths from the group + for path, array_ref in group.unstack_paths(state, group_path): + # Extract layer index from path: group_path + (layer_idx,) + rest + layer_idx = int(path[len(group_path)]) + # New path: base_path + (checkpoint_idx + layer_idx,) + rest + new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path)+1:] + result.append((new_path, array_ref)) + + checkpoint_idx += group.num_layers + + return result + + def _forward_group( + self, + group: StackedDecoderLayers, + hidden_states: jax.Array, + layer_offset: int, + kv_cache: KVCache | None, + is_decode: bool, + **kwargs, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward through a single layer group with appropriate cache handling.""" + return group( + hidden_states, + kv_cache=kv_cache if is_decode else None, + kv_cache_offset=layer_offset if is_decode else 0, + **kwargs, + ) + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + gradient_checkpointing: bool, + is_training: bool = False, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward pass through all layer groups. + + Args: + hidden_states: Input hidden states of shape (batch, seq, hidden). + attention_mask: Attention mask of shape (batch, seq). + positions: Position indices of shape (batch, seq). + adapter_indices: Optional LoRA adapter indices of shape (batch,). + kv_cache: Optional KV cache for decode mode (None for prefill). + output_hidden_states: Whether to return intermediate hidden states. + gradient_checkpointing: Whether to use gradient checkpointing. + is_training: Whether in training mode. Skips KV cache to save memory. + + Returns: + Tuple of (final_hidden_states, all_hidden_states, kv_cache). + """ + all_hidden_states: list[jax.Array] = [] + is_decode = kv_cache is not None + layer_offset = 0 + kv_results = [] + + for group in self.layer_groups: + hidden_states, layer_hidden_states, layer_kv_cache = self._forward_group( + group, + hidden_states, + layer_offset, + kv_cache, + is_decode, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + output_hidden_states=output_hidden_states, + gradient_checkpointing=gradient_checkpointing, + is_training=is_training, + ) + all_hidden_states.extend(layer_hidden_states) + kv_cache = layer_kv_cache if is_decode else kv_cache + kv_results.append(layer_kv_cache) + layer_offset += group.num_layers + + if not is_decode and kv_results: + kv_cache = KVCache.concatenate(*kv_results) + + return hidden_states, all_hidden_states, kv_cache + + def unstack_state(module: nnx.Module) -> nnx.GraphState: """Transform stacked layer state to unstacked ArrayRef views. @@ -242,54 +422,24 @@ def unstack_state(module: nnx.Module) -> nnx.GraphState: This is useful for checkpoint loading where weights are stored per-layer. - - For models with multiple StackedDecoderLayers (e.g., DeepSeek with dense + MoE), - the model can provide get_stacked_layers_list() to specify ordering. Otherwise, - falls back to simple per-stack numbering. - Args: module: Module containing StackedDecoderLayers. Returns: GraphState with unstacked paths and ArrayRef views. """ - # Build mapping: StackedDecoderLayers object id → starting checkpoint index - checkpoint_mapping = {} + state = nnx.state(module) + expanded = [] - if hasattr(module, "model") and hasattr(module.model, "get_stacked_layers_list"): - # Model provides explicit ordering - use sequential checkpoint indices - counter = 0 - for stacked_layers in module.model.get_stacked_layers_list(): - checkpoint_mapping[id(stacked_layers)] = counter - counter += stacked_layers.num_layers + # Delegate to layers if they support unstacking + if hasattr(module, "model") and hasattr(module.model, "layers"): + layers = module.model.layers + if isinstance(layers, (StackedDecoderLayers, MultiStackedDecoderLayers)): + expanded.extend(layers.unstack_paths(state, base_path=("model", "layers"))) - expanded = [] - for path, param in nnx.to_flat_state(nnx.state(module)): + # Keep all non-stacked paths as-is + for path, param in nnx.to_flat_state(state): if "_stacked" not in path: expanded.append((path, param)) - continue - - stacked_idx = path.index("_stacked") - - # Find the StackedDecoderLayers object this parameter belongs to - stacked_layers = module - for key in path[:stacked_idx]: - stacked_layers = getattr(stacked_layers, key) - assert isinstance(stacked_layers, StackedDecoderLayers) - - if id(stacked_layers) in checkpoint_mapping: - # Use checkpoint mapping - replace attribute name with "layers" - start_idx = checkpoint_mapping[id(stacked_layers)] - # Path: ("model", "dense_layers", "_stacked", ...) → ("model", "layers", "0", ...) - base_path = path[:stacked_idx-1] + ("layers",) - for layer_idx in range(stacked_layers.num_layers): - checkpoint_idx = start_idx + layer_idx - new_path = base_path + (str(checkpoint_idx),) + path[stacked_idx+1:] - expanded.append((new_path, ArrayRef(param, layer_idx))) - else: - # Fallback: simple numbering within the same attribute - for layer_idx in range(param[...].shape[0]): - new_path = path[:stacked_idx] + (str(layer_idx),) + path[stacked_idx+1:] - expanded.append((new_path, ArrayRef(param, layer_idx))) return nnx.from_flat_state(expanded) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index d0692cfeb..00b28ffe4 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -7,7 +7,7 @@ from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep from tx.layers.layernorm import RMSNorm -from tx.layers.stacked import StackedDecoderLayers +from tx.layers.stacked import MultiStackedDecoderLayers, StackedDecoderLayers from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -483,27 +483,22 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs rngs=rngs, ) - # Create stacked dense layers (layers 0 to first_k_dense_replace - 1) + # Create stacked layers: dense layers followed by MoE layers def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) - self.dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) - - # Create stacked MoE layers (layers first_k_dense_replace to num_hidden_layers - 1) def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) - self.moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) + dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) + moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) + self.layers = MultiStackedDecoderLayers(dense_layers, moe_layers) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def get_stacked_layers_list(self): - """Return ordered list of StackedDecoderLayers for checkpoint loading. - - Returns dense layers first (checkpoint indices 0 to first_k-1), - then MoE layers (checkpoint indices first_k to num_layers-1). - """ - return [self.dense_layers, self.moe_layers] + """Delegate to MultiStackedDecoderLayers for checkpoint loading.""" + return self.layers.get_stacked_layers_list() def __call__( self, @@ -521,50 +516,26 @@ def __call__( ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] - - # Split KV cache for dense and MoE layers - dense_kv_cache = None - moe_kv_cache = None - if kv_cache is not None: - dense_kv_cache, moe_kv_cache = kv_cache.split(self.num_dense_layers) - - # Forward through dense layers - hidden_states, dense_hidden_states, dense_kv_result = self.dense_layers( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=dense_kv_cache, - output_hidden_states=output_hidden_states, - gradient_checkpointing=self.config.gradient_checkpointing, - is_training=is_training, - ) - all_hidden_states.extend(dense_hidden_states) - # Forward through MoE layers - hidden_states, moe_hidden_states, moe_kv_result = self.moe_layers( + # Forward through all layers + hidden_states, all_hidden_states, kv_cache = self.layers( hidden_states, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=moe_kv_cache, + kv_cache=kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=self.config.gradient_checkpointing, is_training=is_training, ) - all_hidden_states.extend(moe_hidden_states) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states.append(hidden_states) - # Merge KV caches from dense and MoE layers - new_kv_cache = KVCache.concatenate(dense_kv_result, moe_kv_result) - return ModelOutput( last_hidden_state=hidden_states, - kv_cache=new_kv_cache, + kv_cache=kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 6c0651991..5539ab3f2 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -146,24 +146,27 @@ def split(self, layer_idx: int) -> tuple[KVCache | None, KVCache | None]: return first, second @staticmethod - def concatenate(first: KVCache | None, second: KVCache | None) -> KVCache | None: - """Concatenate two caches along the layer dimension. + def concatenate(*caches: KVCache | None) -> KVCache | None: + """Concatenate multiple caches along the layer dimension. Args: - first: First cache (earlier layers), or None. - second: Second cache (later layers), or None. + *caches: KVCache objects to concatenate, or None values to skip. Returns: - Combined KVCache, or the non-None input, or None if both are None. + Combined KVCache, or None if all inputs are None. """ - if first is None: - return second - if second is None: - return first + # Filter out None values + non_none_caches = [c for c in caches if c is not None] + + if len(non_none_caches) == 0: + return None + if len(non_none_caches) == 1: + return non_none_caches[0] + return KVCache( - keys=jnp.concatenate([first.keys, second.keys], axis=0), - values=jnp.concatenate([first.values, second.values], axis=0), - cache_position=second.cache_position, + keys=jnp.concatenate([c.keys for c in non_none_caches], axis=0), + values=jnp.concatenate([c.values for c in non_none_caches], axis=0), + cache_position=non_none_caches[-1].cache_position, ) From 5b368a93859333ae1bca1d44833689571d94cdc4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 17:21:37 -0800 Subject: [PATCH 119/133] Use split/concatenate for KV cache instead of offset approach Simplifies MultiStackedDecoderLayers by using KVCache.split() to divide the cache among layer groups, then concatenating results. This is cleaner than the offset approach since concatenation is inevitable anyway. - Removed kv_cache_offset parameter from StackedDecoderLayers - Updated KVCache.split() to support multiple split points - MultiStackedDecoderLayers now splits cache upfront and concatenates after Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/layers/stacked.py | 60 +++++++++++++--------------------- skyrl-tx/tx/utils/generator.py | 49 ++++++++++++++------------- 2 files changed, 46 insertions(+), 63 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 53d0de2f9..265f0f4d9 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -179,7 +179,6 @@ def __call__( positions: jax.Array, adapter_indices: jax.Array | None, kv_cache: KVCache | None, - kv_cache_offset: int = 0, output_hidden_states: bool, gradient_checkpointing: bool, is_training: bool = False, @@ -195,8 +194,6 @@ def __call__( positions: Position indices of shape (batch, seq). adapter_indices: Optional LoRA adapter indices of shape (batch,). kv_cache: Optional KV cache for decode mode (None for prefill). - kv_cache_offset: Layer offset into the KV cache. Used when multiple - StackedDecoderLayers share the same cache. output_hidden_states: Whether to return intermediate hidden states. gradient_checkpointing: Whether to use gradient checkpointing. is_training: Whether in training mode. Skips KV cache to save memory. @@ -216,9 +213,8 @@ def body_fn(carry, layer_params): hs, cache_keys, cache_values, layer_idx = carry # Extract layer's cache slice if available - cache_idx = kv_cache_offset + layer_idx if cache_keys is not None: - layer_kv = (cache_keys[cache_idx], cache_values[cache_idx]) + layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) else: layer_kv = None @@ -236,8 +232,8 @@ def body_fn(carry, layer_params): # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) if cache_keys is not None: - cache_keys = cache_keys.at[cache_idx].set(k) - cache_values = cache_values.at[cache_idx].set(v) + cache_keys = cache_keys.at[layer_idx].set(k) + cache_values = cache_values.at[layer_idx].set(v) k = v = None # Don't accumulate in output - cache is in carry elif is_training: k = v = None @@ -340,23 +336,6 @@ def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tu return result - def _forward_group( - self, - group: StackedDecoderLayers, - hidden_states: jax.Array, - layer_offset: int, - kv_cache: KVCache | None, - is_decode: bool, - **kwargs, - ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: - """Forward through a single layer group with appropriate cache handling.""" - return group( - hidden_states, - kv_cache=kv_cache if is_decode else None, - kv_cache_offset=layer_offset if is_decode else 0, - **kwargs, - ) - def __call__( self, hidden_states: jax.Array, @@ -385,33 +364,38 @@ def __call__( Tuple of (final_hidden_states, all_hidden_states, kv_cache). """ all_hidden_states: list[jax.Array] = [] - is_decode = kv_cache is not None - layer_offset = 0 - kv_results = [] - for group in self.layer_groups: - hidden_states, layer_hidden_states, layer_kv_cache = self._forward_group( - group, + # Split KV cache for each group + if kv_cache is not None: + split_points = [] + cumsum = 0 + for group in self.layer_groups[:-1]: + cumsum += group.num_layers + split_points.append(cumsum) + kv_caches = kv_cache.split(*split_points) + else: + kv_caches = (None,) * len(self.layer_groups) + + # Forward through each group + kv_results = [] + for group, group_kv_cache in zip(self.layer_groups, kv_caches): + hidden_states, layer_hidden_states, layer_kv_cache = group( hidden_states, - layer_offset, - kv_cache, - is_decode, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, + kv_cache=group_kv_cache, output_hidden_states=output_hidden_states, gradient_checkpointing=gradient_checkpointing, is_training=is_training, ) all_hidden_states.extend(layer_hidden_states) - kv_cache = layer_kv_cache if is_decode else kv_cache kv_results.append(layer_kv_cache) - layer_offset += group.num_layers - if not is_decode and kv_results: - kv_cache = KVCache.concatenate(*kv_results) + # Concatenate KV caches + new_kv_cache = KVCache.concatenate(*kv_results) if kv_results else None - return hidden_states, all_hidden_states, kv_cache + return hidden_states, all_hidden_states, new_kv_cache def unstack_state(module: nnx.Module) -> nnx.GraphState: diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 5539ab3f2..9ba753d91 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -114,36 +114,35 @@ def seq_len(self) -> int: """Current sequence length.""" return self.keys.shape[2] - def split(self, layer_idx: int) -> tuple[KVCache | None, KVCache | None]: - """Split the cache at a layer index. + def split(self, *layer_indices: int) -> tuple[KVCache | None, ...]: + """Split the cache at one or more layer indices. Args: - layer_idx: Layer index to split at. + *layer_indices: Layer indices to split at. For example, split(3, 7) + creates 3 caches: [0:3), [3:7), [7:end). Returns: - Tuple of (first_cache, second_cache) where first_cache contains - layers [0, layer_idx) and second_cache contains layers [layer_idx, num_layers). - Returns None for empty splits. + Tuple of KVCache objects, one for each segment. Returns None for empty segments. """ - first = ( - None - if layer_idx == 0 - else KVCache( - keys=self.keys[:layer_idx], - values=self.values[:layer_idx], - cache_position=self.cache_position, - ) - ) - second = ( - None - if layer_idx == self.num_layers - else KVCache( - keys=self.keys[layer_idx:], - values=self.values[layer_idx:], - cache_position=self.cache_position, - ) - ) - return first, second + if len(layer_indices) == 0: + return (self,) + + # Build split points: 0, idx1, idx2, ..., num_layers + split_points = [0] + list(layer_indices) + [self.num_layers] + + caches = [] + for start, end in zip(split_points[:-1], split_points[1:]): + if start == end: + caches.append(None) + else: + caches.append( + KVCache( + keys=self.keys[start:end], + values=self.values[start:end], + cache_position=self.cache_position, + ) + ) + return tuple(caches) @staticmethod def concatenate(*caches: KVCache | None) -> KVCache | None: From 85d57dc5426ea0c46ddc5194d50556ddd9e7e0d1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 17:28:47 -0800 Subject: [PATCH 120/133] Use list KV cache format for decode to enable buffer donation The stacked KV cache format caused XLA to copy the entire cache array on each layer during decode (~16MB per layer). This happened because carrying the cache through jax.lax.scan and updating with .at[idx].set() prevents buffer donation - XLA can't prove the slices are non-overlapping. Changes: - Revert KVCache to list[jax.Array] format (one array per layer) - Use Python loop for decode mode in StackedDecoderLayers (enables per-layer buffer donation) - Keep scan for prefill/training (efficient, no KV cache needed) Performance: decode time improved from ~31s to ~24s (vs ~21s on main). Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/layers/stacked.py | 91 ++++++++++++++++++++------------ skyrl-tx/tx/utils/generator.py | 96 +++++++++++++++------------------- 2 files changed, 99 insertions(+), 88 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 265f0f4d9..d557c8660 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -183,17 +183,17 @@ def __call__( gradient_checkpointing: bool, is_training: bool = False, ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: - """Forward pass through all layers using scan. + """Forward pass through all layers. - Uses jax.lax.scan for all modes (training/prefill/decode). For decode mode, - the KV cache is passed as scan carry for efficient buffer donation. + Uses scan for prefill/training (efficient, no KV cache needed). + Uses Python loop for decode (with list-format KV cache) to enable buffer donation. Args: hidden_states: Input hidden states of shape (batch, seq, hidden). attention_mask: Attention mask of shape (batch, seq). positions: Position indices of shape (batch, seq). adapter_indices: Optional LoRA adapter indices of shape (batch,). - kv_cache: Optional KV cache for decode mode (None for prefill). + kv_cache: Optional KV cache for decode mode (None for prefill). Uses list format. output_hidden_states: Whether to return intermediate hidden states. gradient_checkpointing: Whether to use gradient checkpointing. is_training: Whether in training mode. Skips KV cache to save memory. @@ -209,56 +209,81 @@ def __call__( graphdef, state = nnx.split(self._stacked) is_decode = kv_cache is not None - def body_fn(carry, layer_params): - hs, cache_keys, cache_values, layer_idx = carry + if is_decode: + # Decode mode: Use Python loop with list KV cache for buffer donation. + # We avoid jax.lax.scan here because carrying a stacked KV cache through scan + # and updating it with cache.at[layer_idx].set() causes XLA to copy the entire + # cache array on each layer (16MB per layer). XLA can't prove the buffer can be + # donated since it doesn't know the slices are non-overlapping. With a Python + # loop and list format, each layer's KV array is independent and can be donated. + flat_state, treedef = jax.tree_util.tree_flatten(state) + all_hidden_states: list[jax.Array] = [] + updated_keys: list[jax.Array] = [] + updated_values: list[jax.Array] = [] - # Extract layer's cache slice if available - if cache_keys is not None: - layer_kv = (cache_keys[layer_idx], cache_values[layer_idx]) - else: - layer_kv = None + for layer_idx in range(self.num_layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + # Extract this layer's parameters + layer_params_flat = [p[layer_idx] for p in flat_state] + layer_params = jax.tree_util.tree_unflatten(treedef, layer_params_flat) + layer = nnx.merge(graphdef, layer_params) + + # Get this layer's KV cache + layer_kv = (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]) + + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + updated_keys.append(k) + updated_values.append(v) + + new_kv_cache = KVCache.update( + kv_cache, updated_keys, updated_values, positions, attention_mask + ) + return hidden_states, all_hidden_states, new_kv_cache - # Forward through layer + # Prefill/training mode: use scan for efficiency + def body_fn(carry, layer_params): + hs = carry + + # Forward through layer (no KV cache input for prefill) layer = nnx.merge(graphdef, layer_params) new_hs, (k, v) = layer( hs, attention_mask=attention_mask, positions=positions, adapter_indices=adapter_indices, - kv_cache=layer_kv, + kv_cache=None, ) hs_output = new_hs if output_hidden_states else None - # Update cache in carry if present (decode), otherwise accumulate outputs (prefill) - if cache_keys is not None: - cache_keys = cache_keys.at[layer_idx].set(k) - cache_values = cache_values.at[layer_idx].set(v) - k = v = None # Don't accumulate in output - cache is in carry - elif is_training: + # Skip KV accumulation in training mode to save memory + if is_training: k = v = None - return (new_hs, cache_keys, cache_values, layer_idx + 1), (hs_output, k, v) + return new_hs, (hs_output, k, v) if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - cache_keys = kv_cache.keys if kv_cache else None - cache_values = kv_cache.values if kv_cache else None - init_carry = (hidden_states, cache_keys, cache_values, 0) - - (final_hs, final_keys, final_values, _), (all_hs, all_keys, all_values) = jax.lax.scan( - body_fn, init_carry, state + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan( + body_fn, hidden_states, state ) - if is_decode: - new_kv_cache = KVCache( - keys=final_keys, - values=final_values, - cache_position=kv_cache.cache_position + positions.shape[1], - ) + if is_training: + new_kv_cache = None else: - new_kv_cache = None if is_training else KVCache.from_layer_outputs(all_keys, all_values, attention_mask) + # Convert stacked scan outputs to list format + keys_list = [all_keys[i] for i in range(self.num_layers)] + values_list = [all_values[i] for i in range(self.num_layers)] + new_kv_cache = KVCache.update(None, keys_list, values_list, positions, attention_mask) all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] return final_hs, all_hidden_states, new_kv_cache diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index 9ba753d91..37b6f7b54 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -1,4 +1,4 @@ -"""Generator mixin for autoregressive text generation with stacked KV caching.""" +"""Generator mixin for autoregressive text generation with KV caching.""" from __future__ import annotations from dataclasses import dataclass @@ -14,59 +14,49 @@ @jax.tree_util.register_dataclass @dataclass class KVCache: - """Key-value cache for all layers in stacked format. + """Key-value cache for all layers, each entry in the list corresponds to one layer.""" - Attributes: - keys: Stacked key cache of shape (num_layers, batch, seq, num_kv_heads, head_dim). - values: Stacked value cache of shape (num_layers, batch, seq, num_kv_heads, head_dim). - cache_position: Per-sequence positions of shape (batch,) for left-aligned decoding. - """ - - keys: jax.Array # (num_layers, batch, seq, num_kv_heads, head_dim) - values: jax.Array # (num_layers, batch, seq, num_kv_heads, head_dim) - cache_position: jax.Array # (batch,) + keys: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer + values: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer + cache_position: jax.Array # Per-sequence positions of shape (batch,) @staticmethod - def from_layer_outputs( - keys: jax.Array, - values: jax.Array, + def update( + kv_cache: KVCache | None, + keys: list[jax.Array], + values: list[jax.Array], + positions: jax.Array, attention_mask: jax.Array, ) -> KVCache: - """Create KVCache from stacked layer outputs after prefill. + """Create an updated KVCache with computed cache positions for left-aligned decoding. Args: - keys: Stacked keys of shape (num_layers, batch, seq, num_kv_heads, head_dim). - values: Stacked values of shape (num_layers, batch, seq, num_kv_heads, head_dim). - attention_mask: Attention mask of shape (batch, seq). + kv_cache: Existing KVCache (None during prefill). + keys: List of key arrays per layer. + values: List of value arrays per layer. + positions: Position indices with shape [B, seq_len]. + attention_mask: Attention mask with shape [B, seq_len]. Returns: New KVCache with computed cache_position. """ - # Prefill: next position is the sequence length (number of real tokens) - cache_position = attention_mask.sum(axis=1).astype(jnp.int32) + if kv_cache is not None: + # Decode: next position is current position + 1 + cache_position = positions[:, 0] + 1 + else: + # Prefill: next position is the sequence length (number of real tokens) + cache_position = attention_mask.sum(axis=1) return KVCache(keys=keys, values=values, cache_position=cache_position) @staticmethod - def update_layer( - kv_cache: tuple[jax.Array, jax.Array], - k: jax.Array, - v: jax.Array, - positions: jax.Array, - ) -> tuple[jax.Array, jax.Array]: - """Update a single layer's KV cache at the given positions. - - This is called from within the scan body to update a single layer's cache. - The layer index is handled by the caller (indexing into stacked cache). + def update_layer(kv_cache, k, v, positions): + """Update a single layer's KV cache at the given positions (for left-aligned decoding). Args: - kv_cache: Tuple of (k_cache, v_cache) for this layer. - Each has shape (batch, seq, num_kv_heads, head_dim). - k: New key values of shape (batch, seq_len, num_kv_heads, head_dim). - v: New value values of shape (batch, seq_len, num_kv_heads, head_dim). - positions: Position indices of shape (batch, seq_len). - - Returns: - Updated (k_cache, v_cache) tuple with new values at positions. + kv_cache: Tuple of (k_cache, v_cache) arrays for this layer. + k: New key values with shape [B, seq_len, num_heads, head_dim]. + v: New value values with shape [B, seq_len, num_heads, head_dim]. + positions: Position indices with shape [B, seq_len]. """ k_cache, v_cache = kv_cache @@ -78,41 +68,37 @@ def update_at_pos(cache_slice, new_val_slice, pos): return k, v def pad_to_length(self, max_length: int) -> KVCache: - """Pad KV cache to a specified maximum sequence length. + """Pad KV cache to a specified maximum length. Args: - max_length: Target sequence length to pad to. + max_length: Target length to pad the cache to. Returns: New KVCache with padded keys and values. """ - current_length = self.keys.shape[2] # (num_layers, batch, seq, heads, dim) - if current_length >= max_length: - return self - - pad_length = max_length - current_length - # Pad only the sequence dimension (axis 2) - pad_spec = ((0, 0), (0, 0), (0, pad_length), (0, 0), (0, 0)) + # k and v have shape [B, T, num_heads, head_dim] + cache_pad_length = max_length - self.keys[0].shape[1] + pad_spec = ((0, 0), (0, cache_pad_length), (0, 0), (0, 0)) return KVCache( - keys=jnp.pad(self.keys, pad_spec), - values=jnp.pad(self.values, pad_spec), + keys=[jnp.pad(k, pad_spec) for k in self.keys], + values=[jnp.pad(v, pad_spec) for v in self.values], cache_position=self.cache_position, ) @property def num_layers(self) -> int: """Number of layers in the cache.""" - return self.keys.shape[0] + return len(self.keys) @property def batch_size(self) -> int: """Batch size.""" - return self.keys.shape[1] + return self.keys[0].shape[0] @property def seq_len(self) -> int: """Current sequence length.""" - return self.keys.shape[2] + return self.keys[0].shape[1] def split(self, *layer_indices: int) -> tuple[KVCache | None, ...]: """Split the cache at one or more layer indices. @@ -163,8 +149,8 @@ def concatenate(*caches: KVCache | None) -> KVCache | None: return non_none_caches[0] return KVCache( - keys=jnp.concatenate([c.keys for c in non_none_caches], axis=0), - values=jnp.concatenate([c.values for c in non_none_caches], axis=0), + keys=sum([c.keys for c in non_none_caches], []), + values=sum([c.values for c in non_none_caches], []), cache_position=non_none_caches[-1].cache_position, ) @@ -280,7 +266,7 @@ def _prefill_and_decode( last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :] prompt_logprobs_array = None - # Pad KV cache to max_length (cache_position is already set by from_layer_outputs) + # Pad KV cache and attention mask to max_length kv_cache = outputs.kv_cache.pad_to_length(max_length) decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1]))) From 11a9a1c9b014d9b1e85ffac667d4eb730fd34658 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 17:31:14 -0800 Subject: [PATCH 121/133] lint --- skyrl-tx/tx/layers/stacked.py | 18 +++++++----------- skyrl-tx/tx/utils/models.py | 4 +--- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index d557c8660..4756a5bd7 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -154,19 +154,19 @@ def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tu result = [] for path, param in nnx.to_flat_state(state): # Only process paths belonging to this module - if not path[:len(base_path)] == base_path: + if not path[: len(base_path)] == base_path: continue # Only process _stacked paths - if "_stacked" not in path[len(base_path):]: + if "_stacked" not in path[len(base_path) :]: continue # Find _stacked in the relative path - rel_path = path[len(base_path):] + rel_path = path[len(base_path) :] stacked_idx = rel_path.index("_stacked") # Create per-layer paths: base_path + (layer_idx,) + rest for layer_idx in range(self.num_layers): - new_path = base_path + (str(layer_idx),) + rel_path[stacked_idx+1:] + new_path = base_path + (str(layer_idx),) + rel_path[stacked_idx + 1 :] result.append((new_path, ArrayRef(param, layer_idx))) return result @@ -243,9 +243,7 @@ def __call__( updated_keys.append(k) updated_values.append(v) - new_kv_cache = KVCache.update( - kv_cache, updated_keys, updated_values, positions, attention_mask - ) + new_kv_cache = KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask) return hidden_states, all_hidden_states, new_kv_cache # Prefill/training mode: use scan for efficiency @@ -273,9 +271,7 @@ def body_fn(carry, layer_params): if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - final_hs, (all_hs, all_keys, all_values) = jax.lax.scan( - body_fn, hidden_states, state - ) + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) if is_training: new_kv_cache = None @@ -354,7 +350,7 @@ def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tu # Extract layer index from path: group_path + (layer_idx,) + rest layer_idx = int(path[len(group_path)]) # New path: base_path + (checkpoint_idx + layer_idx,) + rest - new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path)+1:] + new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path) + 1 :] result.append((new_path, array_ref)) checkpoint_idx += group.num_layers diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 8739b15d6..6ad3f8ea6 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -146,9 +146,7 @@ def load_safetensors( if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue if "experts" in path: - tensor = np.stack( - [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 - ) + tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) else: tensor = tensors[key] if "embed_tokens" in key else tensors[key].T if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: From 90f7173d6cb4be52acc158e071faa642b55e5eb9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 17:50:10 -0800 Subject: [PATCH 122/133] Sync NNX sharding metadata after stacking layers The arrays have correct stacked sharding from device_put, but NNX APIs (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata. Without this sync, optimizer initialization fails with sharding errors. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tx/layers/stacked.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 4756a5bd7..f49322a83 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -116,8 +116,21 @@ def copy_to_slice(stacked, arr, idx): for i, arr in enumerate(flat): stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) - # Reconstruct and merge + # Reconstruct state from stacked arrays stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) + + # Sync NNX sharding metadata with actual array sharding. + # The arrays have correct stacked sharding from device_put, but NNX APIs + # (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata. + def update_sharding_metadata(var): + if isinstance(var, nnx.Variable) and hasattr(var.value, "sharding"): + array_sharding = var.value.sharding + if hasattr(array_sharding, "spec"): + var.set_metadata("sharding_names", tuple(array_sharding.spec)) + return var + + jax.tree.map(update_sharding_metadata, stacked_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + self._stacked = nnx.merge(graphdef, stacked_state) def __len__(self) -> int: From 5a5a48aace0c8097585fe727c0a9dd5ddb6d8727 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 5 Feb 2026 18:23:04 -0800 Subject: [PATCH 123/133] Fix KVCache format inconsistencies in tests Address Gemini code review feedback: KVCache uses list[jax.Array] format, not stacked arrays. Update tests to: - Use list format in DummyModel initialization - Check list length instead of array shape - Update comments to reflect list format Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tests/models/test_deepseekv3.py | 6 +++--- skyrl-tx/tests/models/test_models_common.py | 4 ++-- skyrl-tx/tests/utils/test_generator.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py index 2b18a4b83..ca168606e 100644 --- a/skyrl-tx/tests/models/test_deepseekv3.py +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -221,7 +221,7 @@ def test_deepseekv3_gradient_checkpointing(): results[use_checkpointing] = { "logits": np.array(logits), "hidden_states": [np.array(hs) for hs in out.hidden_states], - "kv_cache_shape": out.kv_cache.keys.shape, + "kv_cache_len": len(out.kv_cache.keys), } # Verify outputs match @@ -232,5 +232,5 @@ def test_deepseekv3_gradient_checkpointing(): for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])): np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") - # Verify KV cache shape is correct (num_layers, batch, seq, heads, dim) - assert results[True]["kv_cache_shape"][0] == config.num_hidden_layers + # Verify KV cache has correct number of layers + assert results[True]["kv_cache_len"] == config.num_hidden_layers diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index 0eb283ed0..737eafca4 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -149,8 +149,8 @@ def test_kv_cache_with_checkpointing( out = model(input_ids, attention_mask=attention_mask) - # keys is a stacked array with shape (num_layers, batch, seq, heads, dim) - assert out.kv_cache.keys.shape[0] == config.num_hidden_layers + # keys is a list with one entry per layer + assert len(out.kv_cache.keys) == config.num_hidden_layers @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index f4cbe3421..598a0b12e 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -52,9 +52,9 @@ def __call__( if kv_cache is None: # Prefill: deterministic hidden_states (which equal logits) hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) - # Stacked format: (num_layers, batch, seq, heads, dim) - use 1 layer for this dummy model - keys = jnp.zeros((1, batch_size, seq_len, 1, 1), dtype=jnp.float32) - values = jnp.zeros((1, batch_size, seq_len, 1, 1), dtype=jnp.float32) + # List format: one entry per layer (use 1 layer for this dummy model) + keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] + values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] # Per-sequence cache_position (all same length in this test) cache_position = ( attention_mask.sum(axis=1) if attention_mask is not None else jnp.full((batch_size,), seq_len) From 66d2ac8756207529349152f0c08586c09c17ebf4 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 10:03:19 -0800 Subject: [PATCH 124/133] Address PR review feedback: simplify and clean up stacked layers - Simplify ArrayRef __getitem__/__setitem__ by removing Ellipsis special cases - Remove unreachable IndexError in MultiStackedDecoderLayers.__getitem__ - Remove unused get_stacked_layers_list from DeepseekV3Model and MultiStackedDecoderLayers - Remove unused self.num_layers from DeepseekV3Model, Llama3Model, Qwen3Model - Use extended unpacking syntax (p[*idx, ..., :rank]) for readability - Rewrite make_ep_spec to use is_stacked_path + tree_map_with_path - Rename is_stacked_lora_path to is_stacked_path (not LoRA-specific) - Remove load_stacked_lora_weights in test_qwen3, use unstacked view instead Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tests/models/test_qwen3.py | 31 +++------------- skyrl-tx/tests/utils/test_models.py | 10 +++--- skyrl-tx/tx/layers/stacked.py | 55 +++++++++++++++++------------ skyrl-tx/tx/layers/util.py | 22 ++++++------ skyrl-tx/tx/models/deepseekv3.py | 5 --- skyrl-tx/tx/models/llama3.py | 1 - skyrl-tx/tx/models/qwen3.py | 1 - skyrl-tx/tx/utils/models.py | 17 +++++---- 8 files changed, 60 insertions(+), 82 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index cf2316e2c..fda1d409e 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -105,28 +105,6 @@ def load_lora_weights( jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank) -def load_stacked_lora_weights( - jax_module: LoRAMixin, - layer_idx: int, - adapter_idx: int, - lora_A_weights: np.ndarray, - lora_B_weights: np.ndarray, - scaling: float, - rank: int, -) -> None: - """Load LoRA weights for a specific layer in stacked format (decoder layers).""" - assert ( - jax_module.lora_A is not None - and jax_module.lora_B is not None - and jax_module.lora_scaling is not None - and jax_module.lora_ranks is not None - ) - jax_module.lora_A[...] = jax_module.lora_A[...].at[layer_idx, adapter_idx].set(jnp.array(lora_A_weights)) - jax_module.lora_B[...] = jax_module.lora_B[...].at[layer_idx, adapter_idx].set(jnp.array(lora_B_weights)) - jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[layer_idx, adapter_idx].set(scaling) - jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[layer_idx, adapter_idx].set(rank) - - @pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)]) def test_qwen3_moe_layer_lora(ep: int, tp: int): """Test MoE LoRA by merging adapter into base weights and comparing outputs.""" @@ -267,20 +245,19 @@ def test_qwen3_lora(): rank=lora_config.r, ) - # Load layer LoRA weights (stacked format) - # Access _stacked to get the stacked module with LoRA parameters + # Load layer LoRA weights via unstacked view for i in range(config.num_hidden_layers): hf_layer = hf_model.base_model.model.model.layers[i] + jax_layer = model.model.layers[i] for module_name, projections in [ ("mlp", ["gate_proj", "up_proj", "down_proj"]), ("self_attn", ["q_proj", "k_proj", "v_proj", "o_proj"]), ]: for proj_name in projections: hf_proj = getattr(getattr(hf_layer, module_name), proj_name) - jax_proj = getattr(getattr(model.model.layers._stacked, module_name), proj_name) - load_stacked_lora_weights( + jax_proj = getattr(getattr(jax_layer, module_name), proj_name) + load_lora_weights( jax_proj, - layer_idx=i, adapter_idx=adapter_idx, lora_A_weights=hf_proj.lora_A["default"].weight.detach().numpy().T, lora_B_weights=hf_proj.lora_B["default"].weight.detach().numpy().T, diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index 747ab0e66..e41baabe3 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -18,7 +18,7 @@ from tx.models.qwen3 import Qwen3ForCausalLM from tx.tinker.types import LoraConfig from tx.utils import models -from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_lora_path +from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_path from tx.utils.storage import download_and_unpack @@ -107,9 +107,9 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat ], ids=["layers", "dense_layers", "moe_layers", "embed_tokens", "lm_head", "str_layers", "str_embed"], ) -def test_is_stacked_lora_path(path, expected): - """Test is_stacked_lora_path correctly identifies stacked vs non-stacked paths.""" - assert is_stacked_lora_path(path) is expected +def test_is_stacked_path(path, expected): + """Test is_stacked_path correctly identifies stacked vs non-stacked paths.""" + assert is_stacked_path(path) is expected def test_extract_insert_adapter_state_roundtrip(): @@ -139,7 +139,7 @@ def test_extract_insert_adapter_state_roundtrip(): key = path[-2].key if hasattr(path[-2], "key") else str(path[-2]) if key in {"lora_A", "lora_B"}: # Stacked: should have (num_layers, ...) not (num_layers, num_adapters, ...) - if is_stacked_lora_path(path): + if is_stacked_path(path): assert leaf.shape[0] == 1 # num_layers assert leaf.ndim == 3 # (layers, in_dim, rank) or (layers, rank, out_dim) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index f49322a83..34cfab908 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -21,17 +21,12 @@ def __init__(self, parent: nnx.Variable, idx: int): def __getitem__(self, key): parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") - return parent[idx] if key is Ellipsis else parent[idx][key] + return parent[idx][key] def __setitem__(self, key, value): """Write through to parent when value is set via indexing.""" parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") - if key is Ellipsis: - # param[...] = value -> update entire slice - parent[...] = parent[...].at[idx].set(value) - else: - # param[key] = value -> update sub-slice - parent[...] = parent[...].at[idx][key].set(value) + parent[...] = parent[...].at[idx][key].set(value) # Also update our local value super().__setitem__(key, value) @@ -94,7 +89,9 @@ def __init__( original_sharding = arr.sharding if hasattr(original_sharding, "spec"): new_spec = PartitionSpec(None, *original_sharding.spec) - stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) + stacked = jax.device_put( + jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec) + ) else: stacked = jnp.zeros(stacked_shape, arr.dtype) stacked_flat.append(stacked) @@ -129,7 +126,11 @@ def update_sharding_metadata(var): var.set_metadata("sharding_names", tuple(array_sharding.spec)) return var - jax.tree.map(update_sharding_metadata, stacked_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + jax.tree.map( + update_sharding_metadata, + stacked_state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) self._stacked = nnx.merge(graphdef, stacked_state) @@ -154,7 +155,9 @@ def __iter__(self): for i in range(self.num_layers): yield self[i] - def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + def unstack_paths( + self, state: nnx.GraphState, base_path: tuple = () + ) -> list[tuple[tuple, ArrayRef]]: """Transform _stacked paths to per-layer paths with ArrayRef. Args: @@ -256,7 +259,9 @@ def __call__( updated_keys.append(k) updated_values.append(v) - new_kv_cache = KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask) + new_kv_cache = KVCache.update( + kv_cache, updated_keys, updated_values, positions, attention_mask + ) return hidden_states, all_hidden_states, new_kv_cache # Prefill/training mode: use scan for efficiency @@ -284,7 +289,9 @@ def body_fn(carry, layer_params): if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan( + body_fn, hidden_states, state + ) if is_training: new_kv_cache = None @@ -292,9 +299,13 @@ def body_fn(carry, layer_params): # Convert stacked scan outputs to list format keys_list = [all_keys[i] for i in range(self.num_layers)] values_list = [all_values[i] for i in range(self.num_layers)] - new_kv_cache = KVCache.update(None, keys_list, values_list, positions, attention_mask) + new_kv_cache = KVCache.update( + None, keys_list, values_list, positions, attention_mask + ) - all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + all_hidden_states = ( + [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + ) return final_hs, all_hidden_states, new_kv_cache @@ -330,18 +341,14 @@ def __getitem__(self, index: int) -> nnx.Module: return group[index - offset] offset += group.num_layers - raise IndexError(f"Layer index {index} not found") - def __iter__(self): """Iterate over all layers across all groups.""" for group in self.layer_groups: yield from group - def get_stacked_layers_list(self) -> list[StackedDecoderLayers]: - """Return list of StackedDecoderLayers for checkpoint loading.""" - return list(self.layer_groups) - - def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + def unstack_paths( + self, state: nnx.GraphState, base_path: tuple = () + ) -> list[tuple[tuple, ArrayRef]]: """Transform _stacked paths from all groups to unified per-layer paths. Args: @@ -363,7 +370,11 @@ def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tu # Extract layer index from path: group_path + (layer_idx,) + rest layer_idx = int(path[len(group_path)]) # New path: base_path + (checkpoint_idx + layer_idx,) + rest - new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path) + 1 :] + new_path = ( + base_path + + (str(checkpoint_idx + layer_idx),) + + path[len(group_path) + 1 :] + ) result.append((new_path, array_ref)) checkpoint_idx += group.num_layers diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index e0f596d94..04f401f8b 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -4,6 +4,8 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh, PartitionSpec +from tx.utils.models import is_stacked_path + def ragged_dot( lhs: jax.Array, @@ -97,20 +99,16 @@ def shard_map_ep(module: nnx.Module, func, *args): """ graphdef, state = nnx.split(module) - def make_ep_spec(spec, value): - """Create a PartitionSpec with only 'ep' dims, truncated to match tensor rank.""" - if not isinstance(spec, PartitionSpec): - return spec - # When a layer is extracted from stacked layers via x[layer_idx], the tensor - # loses a dimension but the PartitionSpec metadata still has the extra leading None. - # Truncate the spec to match the actual tensor rank. - arr = value.value if hasattr(value, "value") else value - rank = len(arr.shape) if hasattr(arr, "shape") else 0 - truncated = tuple(spec)[-rank:] if rank > 0 else () - return PartitionSpec(*(p if p == "ep" else None for p in truncated)) + def make_ep_spec(path, s): + if not isinstance(s, PartitionSpec): + return s + # Strip leading stacking dimension if path is stacked + dims = s[1:] if is_stacked_path(path) else s + # Extract only 'ep' dims from PartitionSpecs, replacing others with None + return PartitionSpec(*(p if p == "ep" else None for p in dims)) partition_specs = nnx.get_partition_spec(state) - state_specs = jax.tree.map(make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec)) + state_specs = jax.tree_util.tree_map_with_path(make_ep_spec, partition_specs, is_leaf=lambda x: isinstance(x, PartitionSpec)) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 00b28ffe4..25699cfe4 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -469,7 +469,6 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs self.config = config self.num_dense_layers = config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace - self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -496,10 +495,6 @@ def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) - def get_stacked_layers_list(self): - """Delegate to MultiStackedDecoderLayers for checkpoint loading.""" - return self.layers.get_stacked_layers_list() - def __call__( self, input_ids: jax.Array, diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 8ff6c85ff..124bc7a09 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -200,7 +200,6 @@ class Llama3Model(nnx.Module): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config - self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index a067e8245..6c786bb39 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -318,7 +318,6 @@ class Qwen3Model(nnx.Module): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config - self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6ad3f8ea6..2138ec6e7 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -78,11 +78,10 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") -def is_stacked_lora_path(path: tuple) -> bool: - """Check if a parameter path is for stacked layer weights (for LoRA indexing). +def is_stacked_path(path: tuple) -> bool: + """Check if a parameter path is for stacked layer weights. - Stacked layer params have the adapter dimension at axis 1: (num_layers, num_adapters, ...). - Non-stacked params (e.g., embed_tokens) have adapter dimension at axis 0: (num_adapters, ...). + Stacked layer params have an extra leading dimension: (num_layers, ...). Args: path: Parameter path tuple (can be nnx path objects or strings). @@ -100,7 +99,7 @@ def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: Stacked layer params have shape (num_layers, num_adapters, ...) -> index as [:, adapter_index]. Non-stacked params (embed_tokens) have shape (num_adapters, ...) -> index as [adapter_index]. """ - if is_stacked_lora_path(path): + if is_stacked_path(path): return (slice(None), adapter_index) return (adapter_index,) @@ -291,8 +290,8 @@ def extract_state(path: tuple, p: jnp.ndarray): assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" idx = get_adapter_idx(path, adapter_index) if key == "lora_A": - return p[idx + (..., slice(None, rank))] - return p[idx + (..., slice(None, rank), slice(None))] + return p[*idx, ..., :rank] + return p[*idx, ..., :rank, :] return jax.tree.map_with_path(extract_state, lora_params) @@ -311,8 +310,8 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" idx = get_adapter_idx(path, adapter_index) if key == "lora_A": - return p.at[idx + (..., slice(None, rank))].set(new) - return p.at[idx + (..., slice(None, rank), slice(None))].set(new) + return p.at[*idx, ..., :rank].set(new) + return p.at[*idx, ..., :rank, :].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From d9f9c3d4f3688c459819af78479357acf7eeabff Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 10:19:31 -0800 Subject: [PATCH 125/133] Simplify test_jax_backend to use unstacked layer views Use model.model.layers[0] instead of accessing stacked format directly. ArrayRef provides transparent unstacked views, so tests don't need explicit stacked indexing. Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tests/tinker/test_jax_backend.py | 28 +++++++++++------------ 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index 5ffa5d60c..3543c7378 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -108,18 +108,17 @@ def test_clear_lora_adapter(): # Verify adapter has non-zero rank after creation model = backend.model - # With stacked layers, lora_ranks has shape (num_layers, num_adapters) - lora_layer: LoRALinear = model.model.layers.self_attn.q_proj - assert lora_layer.lora_ranks[0, adapter_idx] > 0 + lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj + assert lora_layer.lora_ranks[adapter_idx] > 0 # Delete the model (calls clear_lora_adapter internally) backend.delete_model(model_id) - # Verify adapter state is zeroed (check layer 0) - assert lora_layer.lora_ranks[0, adapter_idx] == 0 - assert lora_layer.lora_scaling[0, adapter_idx] == 0.0 - assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all() - assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all() + # Verify adapter state is zeroed + assert lora_layer.lora_ranks[adapter_idx] == 0 + assert lora_layer.lora_scaling[adapter_idx] == 0.0 + assert (lora_layer.lora_A[adapter_idx] == 0.0).all() + assert (lora_layer.lora_B[adapter_idx] == 0.0).all() def make_fwd_bwd_input(token_lists: list[list[int]]) -> types.ForwardBackwardInput: @@ -535,21 +534,20 @@ def test_adapter_reuse_initializes_lora_adapter(): # (slot 0 is reserved for base model) backend = create_backend(max_lora_adapters=2) model = backend.model - # With stacked layers, lora_A has shape (num_layers, num_adapters, in_features, max_rank) - lora_layer: LoRALinear = model.model.layers.self_attn.q_proj + lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj # Create first model model_id_1 = "model_1" adapter_idx = create_model(backend, model_id_1) - # Verify lora_A is non-zero after creation (check layer 0) + # Verify lora_A is non-zero after creation assert not ( - lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform (non-zero)" # Delete the model (clears both lora_A and lora_B to zeros) backend.delete_model(model_id_1) - assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" + assert (lora_layer.lora_A[adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" # Create a new model that reuses the same adapter slot model_id_2 = "model_2" @@ -558,11 +556,11 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_A is initialized (non-zero) assert not ( - lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform after adapter reuse" # Verify lora_B is zeros - assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all(), "lora_B should be zeros" + assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" def test_mixed_train_unembed_adapters(): From 084551ec45aebe8710d02ed34ec4cee7fc3a5733 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 10:26:28 -0800 Subject: [PATCH 126/133] revert some comments --- skyrl-tx/tests/models/test_qwen3.py | 4 ++-- skyrl-tx/tests/utils/test_generator.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index fda1d409e..ab64c5a82 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -92,7 +92,7 @@ def load_lora_weights( scaling: float, rank: int, ) -> None: - """Load LoRA weights from numpy arrays to JAX module (non-stacked modules like embed_tokens).""" + """Load LoRA weights from numpy arrays to JAX module.""" assert ( jax_module.lora_A is not None and jax_module.lora_B is not None @@ -245,7 +245,7 @@ def test_qwen3_lora(): rank=lora_config.r, ) - # Load layer LoRA weights via unstacked view + # Load layer LoRA weights for i in range(config.num_hidden_layers): hf_layer = hf_model.base_model.model.model.layers[i] jax_layer = model.model.layers[i] diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 598a0b12e..89bc637be 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -52,7 +52,6 @@ def __call__( if kv_cache is None: # Prefill: deterministic hidden_states (which equal logits) hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) - # List format: one entry per layer (use 1 layer for this dummy model) keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] # Per-sequence cache_position (all same length in this test) From 119c1d350249ad69524a2749cd2f2e279c3632ac Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 10:54:23 -0800 Subject: [PATCH 127/133] lint --- skyrl-tx/tx/layers/stacked.py | 34 ++++++++-------------------------- skyrl-tx/tx/layers/util.py | 4 +++- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 34cfab908..4699a791d 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -89,9 +89,7 @@ def __init__( original_sharding = arr.sharding if hasattr(original_sharding, "spec"): new_spec = PartitionSpec(None, *original_sharding.spec) - stacked = jax.device_put( - jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec) - ) + stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) else: stacked = jnp.zeros(stacked_shape, arr.dtype) stacked_flat.append(stacked) @@ -155,9 +153,7 @@ def __iter__(self): for i in range(self.num_layers): yield self[i] - def unstack_paths( - self, state: nnx.GraphState, base_path: tuple = () - ) -> list[tuple[tuple, ArrayRef]]: + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: """Transform _stacked paths to per-layer paths with ArrayRef. Args: @@ -259,9 +255,7 @@ def __call__( updated_keys.append(k) updated_values.append(v) - new_kv_cache = KVCache.update( - kv_cache, updated_keys, updated_values, positions, attention_mask - ) + new_kv_cache = KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask) return hidden_states, all_hidden_states, new_kv_cache # Prefill/training mode: use scan for efficiency @@ -289,9 +283,7 @@ def body_fn(carry, layer_params): if gradient_checkpointing: body_fn = jax.checkpoint(body_fn) - final_hs, (all_hs, all_keys, all_values) = jax.lax.scan( - body_fn, hidden_states, state - ) + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) if is_training: new_kv_cache = None @@ -299,13 +291,9 @@ def body_fn(carry, layer_params): # Convert stacked scan outputs to list format keys_list = [all_keys[i] for i in range(self.num_layers)] values_list = [all_values[i] for i in range(self.num_layers)] - new_kv_cache = KVCache.update( - None, keys_list, values_list, positions, attention_mask - ) + new_kv_cache = KVCache.update(None, keys_list, values_list, positions, attention_mask) - all_hidden_states = ( - [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] - ) + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] return final_hs, all_hidden_states, new_kv_cache @@ -346,9 +334,7 @@ def __iter__(self): for group in self.layer_groups: yield from group - def unstack_paths( - self, state: nnx.GraphState, base_path: tuple = () - ) -> list[tuple[tuple, ArrayRef]]: + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: """Transform _stacked paths from all groups to unified per-layer paths. Args: @@ -370,11 +356,7 @@ def unstack_paths( # Extract layer index from path: group_path + (layer_idx,) + rest layer_idx = int(path[len(group_path)]) # New path: base_path + (checkpoint_idx + layer_idx,) + rest - new_path = ( - base_path - + (str(checkpoint_idx + layer_idx),) - + path[len(group_path) + 1 :] - ) + new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path) + 1 :] result.append((new_path, array_ref)) checkpoint_idx += group.num_layers diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 04f401f8b..0a3f545ef 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -108,7 +108,9 @@ def make_ep_spec(path, s): return PartitionSpec(*(p if p == "ep" else None for p in dims)) partition_specs = nnx.get_partition_spec(state) - state_specs = jax.tree_util.tree_map_with_path(make_ep_spec, partition_specs, is_leaf=lambda x: isinstance(x, PartitionSpec)) + state_specs = jax.tree_util.tree_map_with_path( + make_ep_spec, partition_specs, is_leaf=lambda x: isinstance(x, PartitionSpec) + ) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) From 5230a1b071bbd00d8fd660b74ac3b1838a6ddc9d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 13:09:36 -0800 Subject: [PATCH 128/133] fix --- skyrl-tx/tx/layers/stacked.py | 13 ++++++++----- skyrl-tx/tx/layers/util.py | 24 +++++++++++++----------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 4699a791d..4ab8c03a3 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -24,11 +24,14 @@ def __getitem__(self, key): return parent[idx][key] def __setitem__(self, key, value): - """Write through to parent when value is set via indexing.""" - parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") - parent[...] = parent[...].at[idx][key].set(value) - # Also update our local value - super().__setitem__(key, value) + """Write through to parent when value is set via indexing. + + Only supports Ellipsis key (param[...] = value) because JAX's .at[idx] + returns _IndexUpdateRef which doesn't support further subscripting. + """ + if key is not Ellipsis: + raise NotImplementedError("ArrayRef only supports `ref[...] = value`") + self.set_raw_value(value) def set_raw_value(self, value, **kwargs): """Write through to parent when value is set.""" diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0a3f545ef..d1fd80cac 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -4,8 +4,6 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh, PartitionSpec -from tx.utils.models import is_stacked_path - def ragged_dot( lhs: jax.Array, @@ -99,17 +97,21 @@ def shard_map_ep(module: nnx.Module, func, *args): """ graphdef, state = nnx.split(module) - def make_ep_spec(path, s): - if not isinstance(s, PartitionSpec): - return s - # Strip leading stacking dimension if path is stacked - dims = s[1:] if is_stacked_path(path) else s - # Extract only 'ep' dims from PartitionSpecs, replacing others with None - return PartitionSpec(*(p if p == "ep" else None for p in dims)) + def make_ep_spec(spec, value): + """Create a PartitionSpec with only 'ep' dims, truncated to match tensor rank.""" + if not isinstance(spec, PartitionSpec): + return spec + # When a layer is extracted from stacked layers via x[layer_idx], the tensor + # loses a dimension but the PartitionSpec metadata still has the extra leading None. + # Truncate the spec to match the actual tensor rank. + arr = value[...] if isinstance(value, nnx.Variable) else value + rank = len(arr.shape) if hasattr(arr, "shape") else 0 + truncated = tuple(spec)[-rank:] if rank > 0 else () + return PartitionSpec(*(p if p == "ep" else None for p in truncated)) partition_specs = nnx.get_partition_spec(state) - state_specs = jax.tree_util.tree_map_with_path( - make_ep_spec, partition_specs, is_leaf=lambda x: isinstance(x, PartitionSpec) + state_specs = jax.tree.map( + make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec) ) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) From bbf0bd55d4c2688d3162c9449746794db9a27322 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 13:55:41 -0800 Subject: [PATCH 129/133] fix test_models --- skyrl-tx/tests/utils/test_models.py | 35 +++++++++++++---------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index e41baabe3..ce1c7004a 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -58,18 +58,14 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat adapter_config = LoraConfig(rank=rank, alpha=alpha, seed=0) # Set LoRA weights to random values for testing (to catch transpose bugs) - # layers is now stacked, so access directly (not subscriptable) - # LoRA weights have shape (num_layers, num_adapters, ...) for stacked layers - q_proj = model.model.layers.self_attn.q_proj + q_proj = model.model.layers[0].self_attn.q_proj rng1, rng2 = jax.random.split(jax.random.PRNGKey(42)) q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) # Store expected values (trimmed to rank and transposed) - # For stacked layers: shape is (num_layers, num_adapters, in_dim, rank) for lora_A - # We have 1 layer, so index [0] for layer, then adapter_index - expected_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank].T) - expected_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :].T) + expected_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank].T) + expected_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :].T) # Save and verify checkpoint exists models.save_lora_checkpoint(model, base_model_name, adapter_config, adapter_index, output_path) @@ -94,18 +90,17 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat @pytest.mark.parametrize( "path,expected", [ - # Stacked paths (DictKey) - ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), - ((DictKey(key="model"), DictKey(key="dense_layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), - ((DictKey(key="model"), DictKey(key="moe_layers"), DictKey(key="mlp"), DictKey(key="lora_A")), True), + # Stacked paths (DictKey) — real NNX paths include _stacked + ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="_stacked"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="layer_groups"), DictKey(key="_stacked"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), # Non-stacked paths (DictKey) ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False), ((DictKey(key="lm_head"), DictKey(key="lora_A")), False), # String paths - (("model", "layers", "self_attn", "lora_A"), True), + (("model", "layers", "_stacked", "self_attn", "lora_A"), True), (("model", "embed_tokens", "lora_A"), False), ], - ids=["layers", "dense_layers", "moe_layers", "embed_tokens", "lm_head", "str_layers", "str_embed"], + ids=["stacked_layers", "multi_stacked_layers", "embed_tokens", "lm_head", "str_stacked", "str_embed"], ) def test_is_stacked_path(path, expected): """Test is_stacked_path correctly identifies stacked vs non-stacked paths.""" @@ -119,7 +114,7 @@ def test_extract_insert_adapter_state_roundtrip(): _, _, model = create_test_model(base_model_name, rank, alpha, adapter_index) # Set LoRA weights to random values - q_proj = model.model.layers.self_attn.q_proj + q_proj = model.model.layers[0].self_attn.q_proj rng1, rng2 = jax.random.split(jax.random.PRNGKey(123)) q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) @@ -128,8 +123,8 @@ def test_extract_insert_adapter_state_roundtrip(): _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) # Store original values for comparison - original_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank]) - original_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :]) + original_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank]) + original_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :]) # Extract adapter state extracted = extract_adapter_state(adapter_index, lora_params, rank) @@ -144,12 +139,12 @@ def test_extract_insert_adapter_state_roundtrip(): assert leaf.ndim == 3 # (layers, in_dim, rank) or (layers, rank, out_dim) # Zero out the adapter's weights - q_proj.lora_A[...] = q_proj.lora_A[...].at[0, adapter_index].set(0) - q_proj.lora_B[...] = q_proj.lora_B[...].at[0, adapter_index].set(0) + q_proj.lora_A[...] = q_proj.lora_A[...].at[adapter_index].set(0) + q_proj.lora_B[...] = q_proj.lora_B[...].at[adapter_index].set(0) # Verify weights are zeroed - assert np.allclose(q_proj.lora_A[...][0, adapter_index], 0) - assert np.allclose(q_proj.lora_B[...][0, adapter_index], 0) + assert np.allclose(q_proj.lora_A[...][adapter_index], 0) + assert np.allclose(q_proj.lora_B[...][adapter_index], 0) # Re-split to get updated lora_params _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) From 9bf68ebd3b82c87521af25a5fb5ea7dd3961a519 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 14:55:56 -0800 Subject: [PATCH 130/133] Pass full model to save/load_safetensors for LoRA checkpoint I/O Instead of extracting a GraphState and passing it to save/load_safetensors (which expect an nnx.Module for unstack_state), pass the full model with adapter_index/rank params to slice LoRA weights inline. Extract shared adapter slicing logic into get_lora_adapter_slice helper. Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tx/utils/models.py | 72 +++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 2138ec6e7..f9bd3ec25 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -104,6 +104,20 @@ def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: return (adapter_index,) +def get_lora_adapter_slice(path: tuple, adapter_index: int, rank: int) -> tuple | None: + """Return index tuple for accessing a single adapter's LoRA weight in an unstacked param. + + After unstack_state, LoRA params have shape (num_adapters, ..., max_rank, ...). + Returns the slice to extract/insert one adapter's trimmed-rank weights, or None + for non-LoRA params. + """ + if "lora_A" in path: + return (adapter_index, slice(None), slice(None, rank)) + if "lora_B" in path: + return (adapter_index, slice(None, rank), slice(None)) + return None + + def get_param_key(path: tuple, prefix: str = "") -> str: "Get the safetensors key for a given model path." if path[-1] in {"embedding", "kernel"}: @@ -126,8 +140,14 @@ def load_safetensors( skip_lora: bool = True, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, + adapter_index: int | None = None, + rank: int | None = None, ) -> None: - """Load safetensors weights into a model with stacked layers.""" + """Load safetensors weights into a model with stacked layers. + + When adapter_index and rank are provided, loads LoRA weights into a specific + adapter slot instead of replacing the full parameter. + """ from tx.layers.stacked import unstack_state tensors = {} @@ -144,15 +164,23 @@ def load_safetensors( # Skip LoRA parameters if requested if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue + if key not in tensors: + continue if "experts" in path: tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) else: tensor = tensors[key] if "embed_tokens" in key else tensors[key].T - if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: - tensor = tensor.reshape(param.shape) - assert param.shape == tensor.shape, f"shape mismatch for {key}" - # ArrayRef.set_raw_value writes through to the stacked parent variable - param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) + lora_idx = get_lora_adapter_slice(path, adapter_index, rank) if adapter_index is not None else None + if lora_idx is not None: + # Load into specific adapter slot via ArrayRef write-through + arr = param[...] + param[...] = arr.at[lora_idx].set(jnp.array(tensor, dtype=arr.dtype)) + else: + if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: + tensor = tensor.reshape(param.shape) + assert param.shape == tensor.shape, f"shape mismatch for {key}" + # ArrayRef.set_raw_value writes through to the stacked parent variable + param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) def save_safetensors( @@ -161,8 +189,14 @@ def save_safetensors( filename: Path, prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, + adapter_index: int | None = None, + rank: int | None = None, ) -> None: - """Save model weights to safetensors, unstacking layer weights for HF compatibility.""" + """Save model weights to safetensors, unstacking layer weights for HF compatibility. + + When adapter_index and rank are provided, extracts a single adapter's LoRA + weights instead of saving the full parameter. + """ from tx.layers.stacked import unstack_state # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths @@ -174,6 +208,10 @@ def save_safetensors( if filter_fn is not None and not filter_fn(path): continue key = get_param_key(path, prefix=prefix) + # Extract specific adapter's LoRA weights when adapter_index is provided + lora_idx = get_lora_adapter_slice(path, adapter_index, rank) if adapter_index is not None else None + if lora_idx is not None: + param = param[lora_idx] if "experts" in path: for i in range(config.get_num_experts()): tensors[get_expert_key(path, i)] = param[i, :, :].T @@ -195,6 +233,9 @@ def save_safetensors( def filter_lora(adapter_config: LoraConfig, path: tuple[str, ...]) -> bool: + """Check if a LoRA weight path matches the adapter config's training targets.""" + if "lora_A" not in path and "lora_B" not in path: + return False if not adapter_config.train_attn and "self_attn" in path: return False if not adapter_config.train_mlp and ("mlp" in path or "experts" in path): @@ -215,20 +256,17 @@ def load_lora_checkpoint( adapter_index: Index of the adapter to load into checkpoint_path: Path to the checkpoint tar.gz file """ - _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) - - adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank) - with download_and_unpack(checkpoint_path) as temp_dir: load_safetensors( temp_dir, model.config, - adapter_lora_params, + model, skip_lora=False, prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), + adapter_index=adapter_index, + rank=adapter_config.rank, ) - insert_adapter_state(adapter_index, lora_params, adapter_lora_params, adapter_config.rank) def save_lora_checkpoint( @@ -246,10 +284,6 @@ def save_lora_checkpoint( adapter_index: Index of the adapter to save output_path: Path to save the checkpoint tar.gz file """ - _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) - - adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank) - peft_config = peft.LoraConfig( base_model_name_or_path=base_model_name, r=adapter_config.rank, lora_alpha=adapter_config.alpha ) @@ -257,10 +291,12 @@ def save_lora_checkpoint( with pack_and_upload(output_path) as temp_dir: save_safetensors( model.config, - adapter_lora_params, + model, temp_dir / "adapter_model.safetensors", prefix="base_model.model.", filter_fn=lambda path: filter_lora(adapter_config, path), + adapter_index=adapter_index, + rank=adapter_config.rank, ) peft_config.save_pretrained(temp_dir) From 6634d639615a010752619342e370019aab2e1fb9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 15:06:09 -0800 Subject: [PATCH 131/133] lint --- skyrl-tx/tests/utils/test_models.py | 23 +++++++++++++++++++++-- skyrl-tx/tx/layers/util.py | 4 +--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index ce1c7004a..f230332b7 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -91,8 +91,27 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat "path,expected", [ # Stacked paths (DictKey) — real NNX paths include _stacked - ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="_stacked"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), - ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="layer_groups"), DictKey(key="_stacked"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ( + ( + DictKey(key="model"), + DictKey(key="layers"), + DictKey(key="_stacked"), + DictKey(key="self_attn"), + DictKey(key="lora_A"), + ), + True, + ), + ( + ( + DictKey(key="model"), + DictKey(key="layers"), + DictKey(key="layer_groups"), + DictKey(key="_stacked"), + DictKey(key="self_attn"), + DictKey(key="lora_A"), + ), + True, + ), # Non-stacked paths (DictKey) ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False), ((DictKey(key="lm_head"), DictKey(key="lora_A")), False), diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index d1fd80cac..d024f7cd7 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -110,9 +110,7 @@ def make_ep_spec(spec, value): return PartitionSpec(*(p if p == "ep" else None for p in truncated)) partition_specs = nnx.get_partition_spec(state) - state_specs = jax.tree.map( - make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec) - ) + state_specs = jax.tree.map(make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec)) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) From 4bd84c309b9779980cb02104d83bee610a669c73 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 16:43:46 -0800 Subject: [PATCH 132/133] Fix filter_lora breaking init_lora_adapter The lora_A/lora_B guard in filter_lora caused init_lora_adapter to set effective_rank=0 for lora_ranks and lora_scaling paths (which don't contain lora_A/lora_B). Move the guard to the save/load call sites. Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tx/utils/models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index f9bd3ec25..a4c9d9d5c 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -233,9 +233,7 @@ def save_safetensors( def filter_lora(adapter_config: LoraConfig, path: tuple[str, ...]) -> bool: - """Check if a LoRA weight path matches the adapter config's training targets.""" - if "lora_A" not in path and "lora_B" not in path: - return False + """Check if a path's module matches the adapter config's training targets.""" if not adapter_config.train_attn and "self_attn" in path: return False if not adapter_config.train_mlp and ("mlp" in path or "experts" in path): @@ -263,7 +261,7 @@ def load_lora_checkpoint( model, skip_lora=False, prefix="base_model.model.", - filter_fn=lambda path: filter_lora(adapter_config, path), + filter_fn=lambda path: ("lora_A" in path or "lora_B" in path) and filter_lora(adapter_config, path), adapter_index=adapter_index, rank=adapter_config.rank, ) @@ -294,7 +292,7 @@ def save_lora_checkpoint( model, temp_dir / "adapter_model.safetensors", prefix="base_model.model.", - filter_fn=lambda path: filter_lora(adapter_config, path), + filter_fn=lambda path: ("lora_A" in path or "lora_B" in path) and filter_lora(adapter_config, path), adapter_index=adapter_index, rank=adapter_config.rank, ) From 8380f37e5665e1c5ceb68c3f11998620d405d764 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 6 Feb 2026 17:09:06 -0800 Subject: [PATCH 133/133] Increase learning rate in deepseekv3 LoRA training tests Stacked layers use scan which changes LoRA initialization order, producing slightly different he_uniform values. With bf16 precision, 1e-4 learning rate needs >10 steps to register a loss change. Bump to 1e-3 for reliable convergence within 10 steps. Co-Authored-By: Claude Opus 4.6 --- skyrl-tx/tests/models/test_deepseekv3_lora_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py index 3d46c56ec..f1fad1549 100644 --- a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -39,7 +39,7 @@ def test_lora_training_moe_rank_normalized(): init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) - optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=model.is_lora_param) batch = jnp.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]], dtype=jnp.int32) target_ids = batch[:, 1:] @@ -117,7 +117,7 @@ def test_lora_training_high_rank(): init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) - optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=model.is_lora_param) batch = jnp.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]], dtype=jnp.int32) target_ids = batch[:, 1:]