diff --git a/skyrl-tx/tests/models/lora_test_utils.py b/skyrl-tx/tests/models/lora_test_utils.py new file mode 100644 index 000000000..b83d583d6 --- /dev/null +++ b/skyrl-tx/tests/models/lora_test_utils.py @@ -0,0 +1,81 @@ +"""Shared test utilities for LoRA training tests.""" + +import jax +import jax.numpy as jnp + +from tx.utils.models import get_adapter_idx + + +def get_adapter_params(params, adapter_idx: int): + """Extract adapter params at a specific index. + + Decoder layer LoRA params have shape (num_layers, num_adapters, ...). + Embed tokens LoRA params have shape (num_adapters, ...). + """ + + def extract(path, p): + idx = get_adapter_idx(path, adapter_idx) + return p[idx].copy() + + return jax.tree.map_with_path(extract, params) + + +def _slice_out_of_rank(params, adapter_idx: int, get_rank): + """Extract out-of-rank params using a rank function. + + Args: + params: LoRA parameters tree. + adapter_idx: Adapter index to extract. + get_rank: Function (path) -> int returning effective rank for that path. + """ + + def slice_param(path, p): + path_str = str(path) + if "lora_A" not in path_str and "lora_B" not in path_str: + return p + rank = get_rank(path) + idx = get_adapter_idx(path, adapter_idx) + if "lora_A" in path_str: + return p[idx + (..., slice(rank, None))].copy() + return p[idx + (..., slice(rank, None), slice(None))].copy() + + return jax.tree.map_with_path(slice_param, params) + + +def get_out_of_rank_params(params, adapter_idx: int, rank: int): + """Extract out-of-rank params for an adapter.""" + return _slice_out_of_rank(params, adapter_idx, lambda _: rank) + + +def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str): + """Verify that params haven't changed between initial and final state.""" + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + +def _is_routed_expert_path(path) -> bool: + """Check if path is for routed experts (not shared_experts).""" + keys = [] + for p in path: + if hasattr(p, "key"): + keys.append(str(p.key)) + elif hasattr(p, "name"): + keys.append(str(p.name)) + for i, key in enumerate(keys): + if key == "experts" and i > 0 and keys[i - 1] == "mlp": + return True + return False + + +def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): + """Extract out-of-rank params for MoE models. + + For routed experts, uses effective rank = max(1, rank // num_experts). + """ + + def get_rank(path): + return max(1, rank // num_experts) if _is_routed_expert_path(path) else rank + + return _slice_out_of_rank(params, adapter_idx, get_rank) diff --git a/skyrl-tx/tests/models/test_deepseekv3.py b/skyrl-tx/tests/models/test_deepseekv3.py index 23a15a639..ca168606e 100644 --- a/skyrl-tx/tests/models/test_deepseekv3.py +++ b/skyrl-tx/tests/models/test_deepseekv3.py @@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int): output_merged = moe_layer_merged(x_sample) assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) + + +def test_deepseekv3_gradient_checkpointing(): + """Test that gradient checkpointing produces identical outputs for DeepSeekV3. + + DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests + that gradient checkpointing works correctly with heterogeneous layer types. 
+ """ + model_name = "yujiepan/deepseek-v3-tiny-random" + base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) + + batch_size, seq_len = 2, 8 + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + results = {} + for use_checkpointing in [False, True]: + config = DeepseekV3Config( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + gradient_checkpointing=use_checkpointing, + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) + logits = model.compute_logits(out.last_hidden_state) + + results[use_checkpointing] = { + "logits": np.array(logits), + "hidden_states": [np.array(hs) for hs in out.hidden_states], + "kv_cache_len": len(out.kv_cache.keys), + } + + # Verify outputs match + np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6) + + # Verify hidden states match + assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"]) + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])): + np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") + + # Verify KV cache has correct number of layers + assert results[True]["kv_cache_len"] == config.num_hidden_layers diff --git a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py index a767d0366..3d46c56ec 100644 --- a/skyrl-tx/tests/models/test_deepseekv3_lora_training.py +++ b/skyrl-tx/tests/models/test_deepseekv3_lora_training.py @@ -11,42 +11,11 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig - -def _is_routed_expert_path(path) -> bool: - """Disambiguate shared_experts and experts""" - keys = [] - for p in path: - if hasattr(p, "key"): - keys.append(str(p.key)) - elif hasattr(p, "name"): - keys.append(str(p.name)) - - for i, key in enumerate(keys): - if key == "experts" and i > 0 and keys[i - 1] == "mlp": - return True - return False - - -def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): - """Extract out-of-rank params, using effective rank for routed expert layers.""" - - def slice_param(path, p): - path_str = str(path) - - if _is_routed_expert_path(path): - effective_rank = max(1, rank // num_experts) - else: - effective_rank = rank - - if "lora_A" in path_str: - # lora_A shape: [adapters, ..., max_rank] - slice last dim - return p[adapter_idx, ..., effective_rank:].copy() - elif "lora_B" in path_str: - # lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim - return p[adapter_idx, ..., effective_rank:, :].copy() - return p - - return jax.tree.map_with_path(slice_param, params) +from tests.models.lora_test_utils import ( + get_adapter_params, + get_moe_out_of_rank_params, + verify_params_unchanged, +) def test_lora_training_moe_rank_normalized(): @@ -85,15 +54,12 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) 
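# Standalone sketch (not part of the diff) of what get_out_of_rank_params checks: for an
# adapter trained at rank r, the columns of lora_A beyond r and the rows of lora_B beyond r
# must never change during training. Shapes below are illustrative; the real stacked params
# carry an extra leading (num_layers,) dimension that get_adapter_idx handles.
import jax.numpy as jnp

num_adapters, in_dim, out_dim, max_rank, rank = 3, 4, 5, 8, 2
lora_A = jnp.zeros((num_adapters, in_dim, max_rank))
lora_B = jnp.zeros((num_adapters, max_rank, out_dim))

adapter_idx = 0
out_of_rank_A = lora_A[adapter_idx, ..., rank:]  # columns past the adapter's rank
out_of_rank_B = lora_B[adapter_idx, rank:, :]    # rows past the adapter's rank
assert out_of_rank_A.shape == (in_dim, max_rank - rank)
assert out_of_rank_B.shape == (max_rank - rank, out_dim)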
- def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) - num_experts = config.n_routed_experts # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) - initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) - initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) initial_loss = None @@ -116,12 +82,6 @@ def loss_for_lora(lora_params): final_loss = float(loss) - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - assert final_loss < initial_loss, f"Loss did not decrease: {initial_loss} -> {final_loss}" # Verify unused adapter was not modified @@ -129,11 +89,11 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix): verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") # Verify out-of-rank params were not modified - final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) verify_params_unchanged( initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" ) - final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) verify_params_unchanged( initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" ) @@ -172,9 +132,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) 
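# Quick numeric illustration (not from the diff) of the effective-rank rule that
# get_moe_out_of_rank_params applies to routed experts: effective_rank = max(1, rank // num_experts).
# The adapter ranks 16 and 8 mirror the tests above; num_experts = 8 is only an example value,
# not necessarily what config.n_routed_experts is for the tiny test model.
num_experts = 8
for rank in (16, 8, 4):
    effective_rank = max(1, rank // num_experts)
    print(f"rank={rank:2d} -> routed-expert effective rank = {effective_rank}")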
- def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) - num_experts = config.n_routed_experts # Save initial states for all unused adapters @@ -183,8 +140,8 @@ def get_adapter_params(params, adapter_idx): initial_adapter_4_params = get_adapter_params(lora_params, 4) # Save out-of-rank params for adapters 0 and 1 - initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) - initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) # Training loop for step in range(10): @@ -200,12 +157,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify unused adapters (2, 3, 4) were not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") @@ -217,11 +168,11 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix): verify_params_unchanged(initial_adapter_4_params, final_adapter_4_params, "Adapter 4 was modified") # Verify out-of-rank params were not modified - final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts) verify_params_unchanged( initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" ) - final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts) verify_params_unchanged( initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" ) diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py index 61fa029c6..27b154650 100644 --- a/skyrl-tx/tests/models/test_llama3_lora_training.py +++ b/skyrl-tx/tests/models/test_llama3_lora_training.py @@ -11,6 +11,8 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig +from tests.models.lora_test_utils import get_adapter_params, get_out_of_rank_params, verify_params_unchanged + def test_lora_training(): base_model = "unsloth/Llama-3.2-1B" @@ -45,21 +47,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): # that we want to compute gradients for graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) 
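# The shared helpers above rely on tx.utils.models.get_adapter_idx, which is not shown in this
# diff. A hypothetical sketch of its behavior, inferred only from the docstrings in
# lora_test_utils.py: stacked decoder-layer params of shape (num_layers, num_adapters, ...)
# are indexed across all layers for one adapter, while non-stacked params (e.g. embed_tokens)
# are indexed by adapter alone. This is an assumption, not the actual implementation.
def get_adapter_idx_sketch(path, adapter_idx: int) -> tuple:
    stacked = any(name in str(path) for name in ("layers", "dense_layers", "moe_layers"))
    # (slice(None), adapter_idx) selects one adapter across every stacked layer;
    # non-stacked params only carry the adapter dimension in front.
    return (slice(None), adapter_idx) if stacked else (adapter_idx,)

print(get_adapter_idx_sketch(("model", "layers", "self_attn", "lora_A"), 1))  # (slice(None, None, None), 1)
print(get_adapter_idx_sketch(("model", "embed_tokens", "lora_A"), 1))         # (1,)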
- # Helper to extract adapter params at specific index - def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) - - # Helper to extract out-of-rank params for an adapter - def get_out_of_rank_params(params, adapter_idx, rank): - def slice_param(path, p): - if "lora_A" in str(path): - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in str(path): - return p[adapter_idx, rank:, :].copy() - return p - - return jax.tree.map_with_path(slice_param, params) - # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) initial_adapter_0_out_of_rank = get_out_of_rank_params(lora_params, 0, 16) @@ -79,12 +66,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify adapter 2 (unused) was not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py index a90973371..737eafca4 100644 --- a/skyrl-tx/tests/models/test_models_common.py +++ b/skyrl-tx/tests/models/test_models_common.py @@ -1,4 +1,5 @@ import tempfile +from typing import Any from flax import nnx import jax @@ -7,9 +8,10 @@ import pytest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from tx.models.configs import Llama3Config, Qwen3Config +from tx.models.configs import Llama3Config, ModelConfig, Qwen3Config from tx.models.llama3 import Llama3ForCausalLM from tx.models.qwen3 import Qwen3ForCausalLM +from tx.models.types import CausalLMOutput, ModelForCausalLM from tx.utils.models import load_safetensors MODEL_PARAMS = [ @@ -19,26 +21,145 @@ MODEL_IDS = ["llama3", "qwen3"] -def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_chunk_size=0): - """Load model from pre-saved weights directory.""" +def create_model( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + *, + mesh_axis_types: tuple[jax.sharding.AxisType, ...] 
| None = None, + seed: int = 0, + **config_kwargs: Any, +) -> tuple[ModelForCausalLM, ModelConfig]: + """Create model with random weights for testing.""" base_config = AutoConfig.from_pretrained(model_name) - config = config_cls( - base_config, - max_lora_adapters=1, - max_lora_rank=1, - shard_attention_heads=True, + config = config_cls(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True, **config_kwargs) + # Default to Auto axis types to avoid sharding resolution errors + if mesh_axis_types is None: + mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes) + mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=mesh_axis_types) + with jax.set_mesh(mesh): + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed)) + return model, config + + +def load_model( + tmp_dir: str, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + *, + loss_chunk_size: int = 0, +) -> ModelForCausalLM: + """Load model from pre-saved weights directory.""" + model, config = create_model( + model_name, + config_cls, + model_cls, + mesh_axes, loss_chunk_size=loss_chunk_size, gradient_checkpointing=False, ) - mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2) - with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) load_safetensors(tmp_dir, config, model) return model @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) -def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): +class TestGradientCheckpointing: + + def _forward( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + gradient_checkpointing: bool, + **forward_kwargs: Any, + ) -> tuple[ModelForCausalLM, ModelConfig, CausalLMOutput]: + """Create model, run forward pass, and return (model, config, out).""" + batch_size, seq_len = 2, 8 + model, config = create_model( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=gradient_checkpointing + ) + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) + return model, config, out + + def test_output_matches_non_checkpointed( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: + """Forward pass should produce identical outputs with/without checkpointing.""" + model, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False) + logits_no_ckpt = model.compute_logits(out.last_hidden_state) + del model, out + + model, _, out = self._forward(model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True) + logits_ckpt = model.compute_logits(out.last_hidden_state) + del model, out + + np.testing.assert_allclose(logits_no_ckpt, logits_ckpt, rtol=1e-4, atol=1e-6) + + def test_hidden_states_length_matches( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: + """Both paths should return same number of hidden states.""" + _, config, out = self._forward( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=False, output_hidden_states=True + ) + hidden_states_no_ckpt = 
out.hidden_states + num_hidden_layers = config.num_hidden_layers + del out + + _, _, out = self._forward( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True, output_hidden_states=True + ) + hidden_states_ckpt = out.hidden_states + del out + + assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == num_hidden_layers + 1 + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): + np.testing.assert_allclose( + hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" + ) + + def test_kv_cache_with_checkpointing( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: + """KV cache should be populated even with gradient checkpointing enabled.""" + model, config = create_model(model_name, config_cls, model_cls, mesh_axes) + config.gradient_checkpointing = True + + batch_size, seq_len = 2, 8 + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + out = model(input_ids, attention_mask=attention_mask) + + # keys is a list with one entry per layer + assert len(out.kv_cache.keys) == config.num_hidden_layers + + +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) +def test_compute_logits( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], +) -> None: """Test that model.compute_logits matches HuggingFace logits.""" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -65,7 +186,13 @@ def test_compute_logits(model_name, config_cls, model_cls, mesh_axes): @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) @pytest.mark.parametrize("chunk_size", [8, 16, 32]) -def test_chunked_logprobs(model_name, config_cls, model_cls, mesh_axes, chunk_size): +def test_chunked_logprobs( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + chunk_size: int, +) -> None: """Test that chunked and non-chunked compute_logprobs produce identical results.""" tokenizer = AutoTokenizer.from_pretrained(model_name) inputs = ["The capital of France is", "Hello world"] diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 55a779c9e..cf2316e2c 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -92,7 +92,7 @@ def load_lora_weights( scaling: float, rank: int, ) -> None: - """Load LoRA weights from numpy arrays to JAX module.""" + """Load LoRA weights from numpy arrays to JAX module (non-stacked modules like embed_tokens).""" assert ( jax_module.lora_A is not None and jax_module.lora_B is not None @@ -105,6 +105,28 @@ def load_lora_weights( jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank) +def load_stacked_lora_weights( + jax_module: LoRAMixin, + layer_idx: int, + adapter_idx: int, + lora_A_weights: np.ndarray, + lora_B_weights: np.ndarray, + scaling: float, + rank: int, +) -> None: + """Load LoRA weights for a specific layer in stacked format (decoder layers).""" + assert ( + jax_module.lora_A is not None + and jax_module.lora_B is not None + and jax_module.lora_scaling is not None + and jax_module.lora_ranks is not None + ) + jax_module.lora_A[...] 
= jax_module.lora_A[...].at[layer_idx, adapter_idx].set(jnp.array(lora_A_weights)) + jax_module.lora_B[...] = jax_module.lora_B[...].at[layer_idx, adapter_idx].set(jnp.array(lora_B_weights)) + jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[layer_idx, adapter_idx].set(scaling) + jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[layer_idx, adapter_idx].set(rank) + + @pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)]) def test_qwen3_moe_layer_lora(ep: int, tp: int): """Test MoE LoRA by merging adapter into base weights and comparing outputs.""" @@ -245,17 +267,20 @@ def test_qwen3_lora(): rank=lora_config.r, ) - # Load layer LoRA weights - for i, layer in enumerate(model.model.layers): + # Load layer LoRA weights (stacked format) + # Access _stacked to get the stacked module with LoRA parameters + for i in range(config.num_hidden_layers): hf_layer = hf_model.base_model.model.model.layers[i] - for module, projections in [ + for module_name, projections in [ ("mlp", ["gate_proj", "up_proj", "down_proj"]), ("self_attn", ["q_proj", "k_proj", "v_proj", "o_proj"]), ]: for proj_name in projections: - hf_proj = getattr(getattr(hf_layer, module), proj_name) - load_lora_weights( - getattr(getattr(layer, module), proj_name), + hf_proj = getattr(getattr(hf_layer, module_name), proj_name) + jax_proj = getattr(getattr(model.model.layers._stacked, module_name), proj_name) + load_stacked_lora_weights( + jax_proj, + layer_idx=i, adapter_idx=adapter_idx, lora_A_weights=hf_proj.lora_A["default"].weight.detach().numpy().T, lora_B_weights=hf_proj.lora_B["default"].weight.detach().numpy().T, diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index 85f5f3bda..a5873f506 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -11,6 +11,8 @@ from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig +from tests.models.lora_test_utils import get_adapter_params, get_out_of_rank_params, verify_params_unchanged + def test_lora_training(): base_model = "Qwen/Qwen3-0.6B" @@ -45,21 +47,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask): # that we want to compute gradients for graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) 
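# Minimal sketch (outside the diff) of the stacked-update pattern used by
# load_stacked_lora_weights: with stacked decoder layers, lora_A carries a leading
# (num_layers, num_adapters) index, so a single layer/adapter slot is written with a
# two-axis .at[...] update. Shapes are illustrative only.
import jax.numpy as jnp

num_layers, num_adapters, in_dim, max_rank = 2, 3, 4, 8
lora_A = jnp.zeros((num_layers, num_adapters, in_dim, max_rank))
update = jnp.ones((in_dim, max_rank))

layer_idx, adapter_idx = 1, 2
lora_A = lora_A.at[layer_idx, adapter_idx].set(update)
assert bool((lora_A[layer_idx, adapter_idx] == 1.0).all())  # target slot written
assert bool((lora_A[0] == 0.0).all())                       # other layers untouched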
- # Helper to extract adapter params at specific index - def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) - - # Helper to extract out-of-rank params for an adapter - def get_out_of_rank_params(params, adapter_idx, rank): - def slice_param(path, p): - if "lora_A" in str(path): - return p[adapter_idx, :, rank:].copy() - elif "lora_B" in str(path): - return p[adapter_idx, rank:, :].copy() - return p - - return jax.tree.map_with_path(slice_param, params) - # Save initial states initial_adapter_2_params = get_adapter_params(lora_params, 2) initial_adapter_0_out_of_rank = get_out_of_rank_params(lora_params, 0, 16) @@ -79,12 +66,6 @@ def loss_for_lora(lora_params): print(f"Step {step}: loss = {float(loss):.4f}") - def verify_params_unchanged(initial_params, final_params, error_msg_prefix): - for (path, initial), (_, final) in zip( - jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) - ): - assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" - # Verify adapter 2 (unused) was not modified final_adapter_2_params = get_adapter_params(lora_params, 2) verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index c5242737b..5ffa5d60c 100644 --- a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -108,17 +108,18 @@ def test_clear_lora_adapter(): # Verify adapter has non-zero rank after creation model = backend.model - lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj - assert lora_layer.lora_ranks[adapter_idx] > 0 + # With stacked layers, lora_ranks has shape (num_layers, num_adapters) + lora_layer: LoRALinear = model.model.layers.self_attn.q_proj + assert lora_layer.lora_ranks[0, adapter_idx] > 0 # Delete the model (calls clear_lora_adapter internally) backend.delete_model(model_id) - # Verify adapter state is zeroed - assert lora_layer.lora_ranks[adapter_idx] == 0 - assert lora_layer.lora_scaling[adapter_idx] == 0.0 - assert (lora_layer.lora_A[adapter_idx] == 0.0).all() - assert (lora_layer.lora_B[adapter_idx] == 0.0).all() + # Verify adapter state is zeroed (check layer 0) + assert lora_layer.lora_ranks[0, adapter_idx] == 0 + assert lora_layer.lora_scaling[0, adapter_idx] == 0.0 + assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all() + assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all() def make_fwd_bwd_input(token_lists: list[list[int]]) -> types.ForwardBackwardInput: @@ -330,9 +331,10 @@ def apply_step(request_id: int, model_id: str, optim_input: OptimStepInput) -> f def test_gradient_checkpointing(): """ - Verify gradient checkpointing doesn't affect loss values. + Verify gradient checkpointing doesn't affect loss values or gradients. 
""" losses = [] + grads_list = [] for use_gradient_checkpointing in (False, True): config = JaxBackendConfig( max_lora_adapters=1, @@ -354,8 +356,8 @@ def test_gradient_checkpointing(): sampling_logprobs = jnp.zeros((B, T), dtype=jnp.float32) advantages = jnp.zeros((B, T), dtype=jnp.float32) - # Compute loss, using gradient checkpointing if enabled - _, per_token_losses, _ = backend._forward_backward_and_accumulate( + # Compute loss and gradients, using gradient checkpointing if enabled + accumulated_grads, per_token_losses, _ = backend._forward_backward_and_accumulate( backend.accumulated_grads, backend.lora_params, backend.non_lora_params, @@ -369,10 +371,14 @@ def test_gradient_checkpointing(): advantages, ) losses.append(float(per_token_losses.mean())) + grads_list.append(accumulated_grads.grad_sum) # Check relative difference between losses is small assert abs(losses[0] - losses[1]) / abs(losses[0]) < 5e-3 + # Check gradients match + _assert_tree_allclose(grads_list[0], grads_list[1], rtol=1e-3, atol=1e-3, min_match_pct=99.0) + def make_sample_input(tokens: list[int], prompt_logprobs: bool = False, max_tokens: int = 16) -> types.SampleInput: """Build a SampleInput for testing.""" @@ -529,20 +535,21 @@ def test_adapter_reuse_initializes_lora_adapter(): # (slot 0 is reserved for base model) backend = create_backend(max_lora_adapters=2) model = backend.model - lora_layer: LoRALinear = model.model.layers[0].self_attn.q_proj + # With stacked layers, lora_A has shape (num_layers, num_adapters, in_features, max_rank) + lora_layer: LoRALinear = model.model.layers.self_attn.q_proj # Create first model model_id_1 = "model_1" adapter_idx = create_model(backend, model_id_1) - # Verify lora_A is non-zero after creation + # Verify lora_A is non-zero after creation (check layer 0) assert not ( - lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform (non-zero)" # Delete the model (clears both lora_A and lora_B to zeros) backend.delete_model(model_id_1) - assert (lora_layer.lora_A[adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" + assert (lora_layer.lora_A[0, adapter_idx] == 0.0).all(), "lora_A should be zeroed after clear_lora_adapter" # Create a new model that reuses the same adapter slot model_id_2 = "model_2" @@ -551,11 +558,11 @@ def test_adapter_reuse_initializes_lora_adapter(): # Verify lora_A is initialized (non-zero) assert not ( - lora_layer.lora_A[adapter_idx, ..., :LORA_RANK] == 0.0 + lora_layer.lora_A[0, adapter_idx, ..., :LORA_RANK] == 0.0 ).all(), "lora_A should be initialized with he_uniform after adapter reuse" # Verify lora_B is zeros - assert (lora_layer.lora_B[adapter_idx] == 0.0).all(), "lora_B should be zeros" + assert (lora_layer.lora_B[0, adapter_idx] == 0.0).all(), "lora_B should be zeros" def test_mixed_train_unembed_adapters(): diff --git a/skyrl-tx/tests/utils/test_generator.py b/skyrl-tx/tests/utils/test_generator.py index 89bc637be..598a0b12e 100644 --- a/skyrl-tx/tests/utils/test_generator.py +++ b/skyrl-tx/tests/utils/test_generator.py @@ -52,6 +52,7 @@ def __call__( if kv_cache is None: # Prefill: deterministic hidden_states (which equal logits) hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1)) + # List format: one entry per layer (use 1 layer for this dummy model) keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)] # 
Per-sequence cache_position (all same length in this test) diff --git a/skyrl-tx/tests/utils/test_models.py b/skyrl-tx/tests/utils/test_models.py index 70c177fe3..747ab0e66 100644 --- a/skyrl-tx/tests/utils/test_models.py +++ b/skyrl-tx/tests/utils/test_models.py @@ -11,11 +11,14 @@ from peft import PeftModel from transformers import AutoConfig, AutoModelForCausalLM +from jax.tree_util import DictKey + from tx.layers.lora import init_lora_adapter from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM from tx.tinker.types import LoraConfig from tx.utils import models +from tx.utils.models import extract_adapter_state, insert_adapter_state, is_stacked_lora_path from tx.utils.storage import download_and_unpack @@ -55,14 +58,18 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat adapter_config = LoraConfig(rank=rank, alpha=alpha, seed=0) # Set LoRA weights to random values for testing (to catch transpose bugs) - q_proj = model.model.layers[0].self_attn.q_proj + # layers is now stacked, so access directly (not subscriptable) + # LoRA weights have shape (num_layers, num_adapters, ...) for stacked layers + q_proj = model.model.layers.self_attn.q_proj rng1, rng2 = jax.random.split(jax.random.PRNGKey(42)) q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) # Store expected values (trimmed to rank and transposed) - expected_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank].T) - expected_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :].T) + # For stacked layers: shape is (num_layers, num_adapters, in_dim, rank) for lora_A + # We have 1 layer, so index [0] for layer, then adapter_index + expected_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank].T) + expected_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :].T) # Save and verify checkpoint exists models.save_lora_checkpoint(model, base_model_name, adapter_config, adapter_index, output_path) @@ -82,3 +89,83 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat assert torch.allclose(lora_A, torch.from_numpy(expected_lora_A), atol=1e-6) assert torch.allclose(lora_B, torch.from_numpy(expected_lora_B), atol=1e-6) + + +@pytest.mark.parametrize( + "path,expected", + [ + # Stacked paths (DictKey) + ((DictKey(key="model"), DictKey(key="layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ((DictKey(key="model"), DictKey(key="dense_layers"), DictKey(key="self_attn"), DictKey(key="lora_A")), True), + ((DictKey(key="model"), DictKey(key="moe_layers"), DictKey(key="mlp"), DictKey(key="lora_A")), True), + # Non-stacked paths (DictKey) + ((DictKey(key="model"), DictKey(key="embed_tokens"), DictKey(key="lora_A")), False), + ((DictKey(key="lm_head"), DictKey(key="lora_A")), False), + # String paths + (("model", "layers", "self_attn", "lora_A"), True), + (("model", "embed_tokens", "lora_A"), False), + ], + ids=["layers", "dense_layers", "moe_layers", "embed_tokens", "lm_head", "str_layers", "str_embed"], +) +def test_is_stacked_lora_path(path, expected): + """Test is_stacked_lora_path correctly identifies stacked vs non-stacked paths.""" + assert is_stacked_lora_path(path) is expected + + +def test_extract_insert_adapter_state_roundtrip(): + """Test that extract_adapter_state and insert_adapter_state are inverses.""" + base_model_name = "Qwen/Qwen3-0.6B" + rank, alpha, adapter_index = 8, 16, 2 + _, _, model = 
create_test_model(base_model_name, rank, alpha, adapter_index) + + # Set LoRA weights to random values + q_proj = model.model.layers.self_attn.q_proj + rng1, rng2 = jax.random.split(jax.random.PRNGKey(123)) + q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape) + q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape) + + # Split model to get lora_params + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) + + # Store original values for comparison + original_lora_A = np.array(q_proj.lora_A[...][0, adapter_index, :, :rank]) + original_lora_B = np.array(q_proj.lora_B[...][0, adapter_index, :rank, :]) + + # Extract adapter state + extracted = extract_adapter_state(adapter_index, lora_params, rank) + + # Verify extracted shape is correct (no adapter dimension) + for path, leaf in jax.tree.leaves_with_path(extracted): + key = path[-2].key if hasattr(path[-2], "key") else str(path[-2]) + if key in {"lora_A", "lora_B"}: + # Stacked: should have (num_layers, ...) not (num_layers, num_adapters, ...) + if is_stacked_lora_path(path): + assert leaf.shape[0] == 1 # num_layers + assert leaf.ndim == 3 # (layers, in_dim, rank) or (layers, rank, out_dim) + + # Zero out the adapter's weights + q_proj.lora_A[...] = q_proj.lora_A[...].at[0, adapter_index].set(0) + q_proj.lora_B[...] = q_proj.lora_B[...].at[0, adapter_index].set(0) + + # Verify weights are zeroed + assert np.allclose(q_proj.lora_A[...][0, adapter_index], 0) + assert np.allclose(q_proj.lora_B[...][0, adapter_index], 0) + + # Re-split to get updated lora_params + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) + + # Insert extracted state back (modifies lora_params in-place via nnx.update) + insert_adapter_state(adapter_index, lora_params, extracted, rank) + + # Verify weights are restored by checking lora_params directly + for path, leaf in jax.tree.leaves_with_path(lora_params): + key = path[-2].key if hasattr(path[-2], "key") else str(path[-2]) + # leaf is a state wrapper with .value, or can be an array directly + arr = leaf.value if hasattr(leaf, "value") else leaf + if "q_proj" in str(path) and key == "lora_A": + restored_lora_A = np.array(arr[0, adapter_index, :, :rank]) + elif "q_proj" in str(path) and key == "lora_B": + restored_lora_B = np.array(arr[0, adapter_index, :rank, :]) + + assert np.allclose(original_lora_A, restored_lora_A), "lora_A not restored correctly" + assert np.allclose(original_lora_B, restored_lora_B), "lora_B not restored correctly" diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 776b7af59..3a32387c4 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -2,7 +2,7 @@ import jax from jax import numpy as jnp -from tx.utils.models import filter_lora +from tx.utils.models import filter_lora, get_adapter_idx from tx.layers.util import Param, prepare_routing, ragged_dot from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig @@ -345,21 +345,22 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 + idx = get_adapter_idx(path, adapter_index) + key_name = path[-2].key if key_name == "lora_ranks": - return value.at[adapter_index].set(effective_rank) + return value.at[idx].set(effective_rank) if key_name == "lora_scaling": - # Set scaling to 0.0 if rank is 0 - return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0) + scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0 + 
return value.at[idx].set(scaling) if key_name == "lora_A": # Reinitialize with he_uniform, then zero columns beyond rank - shape = value[adapter_index].shape - new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) + new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype) new_A = new_A.at[..., effective_rank:].set(0.0) - return value.at[adapter_index].set(new_A) + return value.at[idx].set(new_A) if key_name == "lora_B": # Explicitly zero lora_B - return value.at[adapter_index].set(0.0) + return value.at[idx].set(0.0) return value updated_state = jax.tree.map_with_path(init_adapter, state) @@ -376,11 +377,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): def clear_adapter(path, value): key = path[-2].key - if key == "lora_ranks": - return value.at[adapter_index].set(0) - if key in ("lora_scaling", "lora_A", "lora_B"): - return value.at[adapter_index].set(0.0) - return value + if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): + return value + idx = get_adapter_idx(path, adapter_index) + return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) nnx.update(model, updated_state) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py new file mode 100644 index 000000000..f49322a83 --- /dev/null +++ b/skyrl-tx/tx/layers/stacked.py @@ -0,0 +1,463 @@ +"""StackedDecoderLayers module for efficient transformer layer stacking.""" + +import functools +from typing import Callable + +from flax import nnx +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec + +from tx.utils.generator import KVCache + + +class ArrayRef(nnx.Variable): + """A Variable providing a view into an indexed slice of a parent Variable.""" + + def __init__(self, parent: nnx.Variable, idx: int): + super().__init__(parent[idx]) + self.set_metadata("_parent", parent) + self.set_metadata("_idx", idx) + + def __getitem__(self, key): + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + return parent[idx] if key is Ellipsis else parent[idx][key] + + def __setitem__(self, key, value): + """Write through to parent when value is set via indexing.""" + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + if key is Ellipsis: + # param[...] = value -> update entire slice + parent[...] = parent[...].at[idx].set(value) + else: + # param[key] = value -> update sub-slice + parent[...] = parent[...].at[idx][key].set(value) + # Also update our local value + super().__setitem__(key, value) + + def set_raw_value(self, value, **kwargs): + """Write through to parent when value is set.""" + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + parent[...] = parent[...].at[idx].set(value) + super().set_raw_value(value, **kwargs) + + @property + def shape(self): + return self.get_metadata("_parent")[self.get_metadata("_idx")].shape + + +class StackedDecoderLayers(nnx.Module): + """Decoder layers with stacked weights for efficient scan-based forward pass. + + Parameters are stored in stacked format (num_layers, ...). The forward pass + uses jax.lax.scan for all modes (training/prefill/decode) with KV cache as + scan carry for efficient buffer donation. + + This class encapsulates both layer creation and forward pass logic. + """ + + def __init__( + self, + create_layer_fn: Callable[[nnx.Rngs], nnx.Module], + num_layers: int, + rngs: nnx.Rngs, + ): + """Create stacked decoder layers. 
+ + This creates a single _stacked module where all parameters have shape (num_layers, ...). + Layers are created individually and stacked to avoid nnx.vmap memory overhead. + + Args: + create_layer_fn: Function that takes rngs and returns a single layer module. + num_layers: Number of layers to create. Can be 0 for empty layer stack. + rngs: Random number generators for initialization. + """ + self.num_layers = num_layers + + # Handle empty layer case + if num_layers == 0: + self._stacked = None + return + + layer_keys = jax.random.split(rngs.params(), num_layers) + mesh = jax.sharding.get_mesh() + + # Create first layer to get structure and shapes + first_layer = create_layer_fn(nnx.Rngs(layer_keys[0])) + graphdef, first_state = nnx.split(first_layer) + flat_first, treedef = jax.tree_util.tree_flatten(first_state) + + # Pre-allocate stacked arrays with correct sharding + stacked_flat = [] + for arr in flat_first: + stacked_shape = (num_layers,) + arr.shape + original_sharding = arr.sharding + if hasattr(original_sharding, "spec"): + new_spec = PartitionSpec(None, *original_sharding.spec) + stacked = jax.device_put(jnp.zeros(stacked_shape, arr.dtype), NamedSharding(mesh, new_spec)) + else: + stacked = jnp.zeros(stacked_shape, arr.dtype) + stacked_flat.append(stacked) + + # JIT with donate_argnums enables buffer reuse + @functools.partial(jax.jit, donate_argnums=(0,)) + def copy_to_slice(stacked, arr, idx): + return stacked.at[idx].set(arr) + + # Copy first layer's params to slot 0 + for i, arr in enumerate(flat_first): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat_first[i], 0) + + # Create remaining layers one at a time and copy params + for layer_idx in range(1, num_layers): + layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx])) + _, state = nnx.split(layer) + flat, _ = jax.tree_util.tree_flatten(state) + for i, arr in enumerate(flat): + stacked_flat[i] = copy_to_slice(stacked_flat[i], flat[i], layer_idx) + + # Reconstruct state from stacked arrays + stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat) + + # Sync NNX sharding metadata with actual array sharding. + # The arrays have correct stacked sharding from device_put, but NNX APIs + # (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata. 
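# Self-contained sketch of the donate_argnums pattern used just above in
# StackedDecoderLayers.__init__: preallocate the stacked buffer once, then write each
# layer's weights into its slot through a jitted update whose first argument is donated,
# so XLA can reuse the buffer instead of copying it per layer. (Donation is a no-op on
# CPU; the result is still correct.) The toy shapes here are illustrative.
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=(0,))
def copy_to_slice(stacked, arr, idx):
    return stacked.at[idx].set(arr)

num_layers, dim = 4, 8
stacked = jnp.zeros((num_layers, dim))
for layer_idx in range(num_layers):
    layer_param = jnp.full((dim,), float(layer_idx))
    stacked = copy_to_slice(stacked, layer_param, layer_idx)

assert bool((stacked[2] == 2.0).all())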
+ def update_sharding_metadata(var): + if isinstance(var, nnx.Variable) and hasattr(var.value, "sharding"): + array_sharding = var.value.sharding + if hasattr(array_sharding, "spec"): + var.set_metadata("sharding_names", tuple(array_sharding.spec)) + return var + + jax.tree.map(update_sharding_metadata, stacked_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + self._stacked = nnx.merge(graphdef, stacked_state) + + def __len__(self) -> int: + """Return the number of layers.""" + return self.num_layers + + def __getitem__(self, index: int) -> nnx.Module: + """Get view into layer at index (stays synced with stacked state).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + graphdef, state = nnx.split(self._stacked) + layer_state = jax.tree.map( + lambda x: ArrayRef(x, index), + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + return nnx.merge(graphdef, layer_state) + + def __iter__(self): + """Iterate over individual layers (for testing/weight loading).""" + for i in range(self.num_layers): + yield self[i] + + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + """Transform _stacked paths to per-layer paths with ArrayRef. + + Args: + state: GraphState containing this module's state. + base_path: Path prefix to this module (e.g., ("model", "layers")). + + Returns: + List of (path, ArrayRef) tuples for unstacked parameters. + """ + result = [] + for path, param in nnx.to_flat_state(state): + # Only process paths belonging to this module + if not path[: len(base_path)] == base_path: + continue + # Only process _stacked paths + if "_stacked" not in path[len(base_path) :]: + continue + + # Find _stacked in the relative path + rel_path = path[len(base_path) :] + stacked_idx = rel_path.index("_stacked") + + # Create per-layer paths: base_path + (layer_idx,) + rest + for layer_idx in range(self.num_layers): + new_path = base_path + (str(layer_idx),) + rel_path[stacked_idx + 1 :] + result.append((new_path, ArrayRef(param, layer_idx))) + + return result + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + gradient_checkpointing: bool, + is_training: bool = False, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward pass through all layers. + + Uses scan for prefill/training (efficient, no KV cache needed). + Uses Python loop for decode (with list-format KV cache) to enable buffer donation. + + Args: + hidden_states: Input hidden states of shape (batch, seq, hidden). + attention_mask: Attention mask of shape (batch, seq). + positions: Position indices of shape (batch, seq). + adapter_indices: Optional LoRA adapter indices of shape (batch,). + kv_cache: Optional KV cache for decode mode (None for prefill). Uses list format. + output_hidden_states: Whether to return intermediate hidden states. + gradient_checkpointing: Whether to use gradient checkpointing. + is_training: Whether in training mode. Skips KV cache to save memory. + + Returns: + Tuple of (final_hidden_states, all_hidden_states, kv_cache). + kv_cache is None when is_training=True. 
+ """ + # Handle empty layer case - pass through inputs unchanged + if self.num_layers == 0: + return hidden_states, [], kv_cache + + graphdef, state = nnx.split(self._stacked) + is_decode = kv_cache is not None + + if is_decode: + # Decode mode: Use Python loop with list KV cache for buffer donation. + # We avoid jax.lax.scan here because carrying a stacked KV cache through scan + # and updating it with cache.at[layer_idx].set() causes XLA to copy the entire + # cache array on each layer (16MB per layer). XLA can't prove the buffer can be + # donated since it doesn't know the slices are non-overlapping. With a Python + # loop and list format, each layer's KV array is independent and can be donated. + flat_state, treedef = jax.tree_util.tree_flatten(state) + all_hidden_states: list[jax.Array] = [] + updated_keys: list[jax.Array] = [] + updated_values: list[jax.Array] = [] + + for layer_idx in range(self.num_layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + # Extract this layer's parameters + layer_params_flat = [p[layer_idx] for p in flat_state] + layer_params = jax.tree_util.tree_unflatten(treedef, layer_params_flat) + layer = nnx.merge(graphdef, layer_params) + + # Get this layer's KV cache + layer_kv = (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]) + + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=layer_kv, + ) + updated_keys.append(k) + updated_values.append(v) + + new_kv_cache = KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask) + return hidden_states, all_hidden_states, new_kv_cache + + # Prefill/training mode: use scan for efficiency + def body_fn(carry, layer_params): + hs = carry + + # Forward through layer (no KV cache input for prefill) + layer = nnx.merge(graphdef, layer_params) + new_hs, (k, v) = layer( + hs, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=None, + ) + + hs_output = new_hs if output_hidden_states else None + + # Skip KV accumulation in training mode to save memory + if is_training: + k = v = None + + return new_hs, (hs_output, k, v) + + if gradient_checkpointing: + body_fn = jax.checkpoint(body_fn) + + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) + + if is_training: + new_kv_cache = None + else: + # Convert stacked scan outputs to list format + keys_list = [all_keys[i] for i in range(self.num_layers)] + values_list = [all_values[i] for i in range(self.num_layers)] + new_kv_cache = KVCache.update(None, keys_list, values_list, positions, attention_mask) + + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + return final_hs, all_hidden_states, new_kv_cache + + +class MultiStackedDecoderLayers(nnx.Module): + """Multiple StackedDecoderLayers groups with unified interface. + + This allows models like DeepSeek to have different layer types (dense/MoE) + while presenting a unified interface for forward passes and checkpointing. + """ + + def __init__(self, *layer_groups: StackedDecoderLayers): + """Create multi-stacked decoder layers. + + Args: + *layer_groups: One or more StackedDecoderLayers to combine. 
+ """ + self.layer_groups = nnx.List(layer_groups) + self.num_layers = sum(group.num_layers for group in self.layer_groups) + + def __len__(self) -> int: + """Return the total number of layers across all groups.""" + return self.num_layers + + def __getitem__(self, index: int) -> nnx.Module: + """Get view into layer at global index (across all groups).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + + # Find which group contains this index + offset = 0 + for group in self.layer_groups: + if index < offset + group.num_layers: + return group[index - offset] + offset += group.num_layers + + raise IndexError(f"Layer index {index} not found") + + def __iter__(self): + """Iterate over all layers across all groups.""" + for group in self.layer_groups: + yield from group + + def get_stacked_layers_list(self) -> list[StackedDecoderLayers]: + """Return list of StackedDecoderLayers for checkpoint loading.""" + return list(self.layer_groups) + + def unstack_paths(self, state: nnx.GraphState, base_path: tuple = ()) -> list[tuple[tuple, ArrayRef]]: + """Transform _stacked paths from all groups to unified per-layer paths. + + Args: + state: GraphState containing this module's state. + base_path: Path prefix to this module (e.g., ("model", "layers")). + + Returns: + List of (path, ArrayRef) tuples for unstacked parameters. + """ + result = [] + checkpoint_idx = 0 + + for i, group in enumerate(self.layer_groups): + # Path to this group: base_path + ("layer_groups", i) + group_path = base_path + ("layer_groups", i) + + # Get unstacked paths from the group + for path, array_ref in group.unstack_paths(state, group_path): + # Extract layer index from path: group_path + (layer_idx,) + rest + layer_idx = int(path[len(group_path)]) + # New path: base_path + (checkpoint_idx + layer_idx,) + rest + new_path = base_path + (str(checkpoint_idx + layer_idx),) + path[len(group_path) + 1 :] + result.append((new_path, array_ref)) + + checkpoint_idx += group.num_layers + + return result + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + gradient_checkpointing: bool, + is_training: bool = False, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward pass through all layer groups. + + Args: + hidden_states: Input hidden states of shape (batch, seq, hidden). + attention_mask: Attention mask of shape (batch, seq). + positions: Position indices of shape (batch, seq). + adapter_indices: Optional LoRA adapter indices of shape (batch,). + kv_cache: Optional KV cache for decode mode (None for prefill). + output_hidden_states: Whether to return intermediate hidden states. + gradient_checkpointing: Whether to use gradient checkpointing. + is_training: Whether in training mode. Skips KV cache to save memory. + + Returns: + Tuple of (final_hidden_states, all_hidden_states, kv_cache). 
+ """ + all_hidden_states: list[jax.Array] = [] + + # Split KV cache for each group + if kv_cache is not None: + split_points = [] + cumsum = 0 + for group in self.layer_groups[:-1]: + cumsum += group.num_layers + split_points.append(cumsum) + kv_caches = kv_cache.split(*split_points) + else: + kv_caches = (None,) * len(self.layer_groups) + + # Forward through each group + kv_results = [] + for group, group_kv_cache in zip(self.layer_groups, kv_caches): + hidden_states, layer_hidden_states, layer_kv_cache = group( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=group_kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=gradient_checkpointing, + is_training=is_training, + ) + all_hidden_states.extend(layer_hidden_states) + kv_results.append(layer_kv_cache) + + # Concatenate KV caches + new_kv_cache = KVCache.concatenate(*kv_results) if kv_results else None + + return hidden_states, all_hidden_states, new_kv_cache + + +def unstack_state(module: nnx.Module) -> nnx.GraphState: + """Transform stacked layer state to unstacked ArrayRef views. + + Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. + Each entry is an ArrayRef that writes through to the original stacked variable. + + This is useful for checkpoint loading where weights are stored per-layer. + + Args: + module: Module containing StackedDecoderLayers. + + Returns: + GraphState with unstacked paths and ArrayRef views. + """ + state = nnx.state(module) + expanded = [] + + # Delegate to layers if they support unstacking + if hasattr(module, "model") and hasattr(module.model, "layers"): + layers = module.model.layers + if isinstance(layers, (StackedDecoderLayers, MultiStackedDecoderLayers)): + expanded.extend(layers.unstack_paths(state, base_path=("model", "layers"))) + + # Keep all non-stacked paths as-is + for path, param in nnx.to_flat_state(state): + if "_stacked" not in path: + expanded.append((path, param)) + + return nnx.from_flat_state(expanded) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..e0f596d94 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -96,12 +96,21 @@ def shard_map_ep(module: nnx.Module, func, *args): *args: Arguments to pass to func (replicated across shards). """ graphdef, state = nnx.split(module) - # Extract only 'ep' dims from PartitionSpecs, replacing others with None - state_specs = jax.tree.map( - lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s, - nnx.get_partition_spec(state), - is_leaf=lambda x: isinstance(x, PartitionSpec), - ) + + def make_ep_spec(spec, value): + """Create a PartitionSpec with only 'ep' dims, truncated to match tensor rank.""" + if not isinstance(spec, PartitionSpec): + return spec + # When a layer is extracted from stacked layers via x[layer_idx], the tensor + # loses a dimension but the PartitionSpec metadata still has the extra leading None. + # Truncate the spec to match the actual tensor rank. 
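# Toy illustration (separate from the diff) of the spec-truncation logic described in the
# comment above: a stacked parameter may carry a PartitionSpec such as (None, 'ep', None),
# while the per-layer view extracted from it has one less dimension, so the spec is cut to
# the tensor's rank and all non-'ep' axes are dropped.
from jax.sharding import PartitionSpec
import jax.numpy as jnp

spec = PartitionSpec(None, "ep", None)  # spec recorded for the stacked (3-D) parameter
arr = jnp.zeros((4, 8))                 # the extracted per-layer view is only 2-D

truncated = tuple(spec)[-arr.ndim:]     # keep only the trailing axes that still exist
ep_only = PartitionSpec(*(p if p == "ep" else None for p in truncated))
print(ep_only)                          # PartitionSpec('ep', None)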
+ arr = value.value if hasattr(value, "value") else value + rank = len(arr.shape) if hasattr(arr, "shape") else 0 + truncated = tuple(spec)[-rank:] if rank > 0 else () + return PartitionSpec(*(p if p == "ep" else None for p in truncated)) + + partition_specs = nnx.get_partition_spec(state) + state_specs = jax.tree.map(make_ep_spec, partition_specs, state, is_leaf=lambda x: isinstance(x, PartitionSpec)) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index 15e011388..398d8c042 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -15,7 +15,7 @@ class ModelConfig(PretrainedConfig): max_lora_rank: Maximum rank for LoRA adapters shard_attention_heads: Whether to shard attention across tensor parallel devices loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking) - gradient_checkpointing: Whether to use gradient checkpointing for chunked loss + gradient_checkpointing: Recompute activations during backward to save memory """ # Type hints for config attributes diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 360cd2ca5..00b28ffe4 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -7,6 +7,7 @@ from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep from tx.layers.layernorm import RMSNorm +from tx.layers.stacked import MultiStackedDecoderLayers, StackedDecoderLayers from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -405,6 +406,9 @@ def __call__( router_logits = self.gate(hidden_states_flat) top_k_weights, top_k_index = self._compute_routing(router_logits) + # _compute_routing uses float32 for softmax stability; cast back to model dtype + # to maintain consistent dtypes through jax.lax.scan in forward_layers + top_k_weights = top_k_weights.astype(hidden_states.dtype) expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat) shared_output = self.shared_experts( @@ -416,17 +420,20 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): + """Decoder layer supporting both dense MLP and sparse MoE.""" - def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + def __init__( + self, + config: DeepseekV3Config, + *, + mlp_cls: type[DeepseekV3MLP] | type[DeepseekV3MoE], + dtype: jnp.dtype, + rngs: nnx.Rngs, + ) -> None: self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) - - # Use dense MLP for initial layers, MoE for the rest - if layer_idx >= config.first_k_dense_replace: - self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) - else: - self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) + self.mlp = mlp_cls(config, dtype=dtype, rngs=rngs) def __call__( self, @@ -460,6 +467,9 @@ class DeepseekV3Model(nnx.Module): def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_dense_layers = config.first_k_dense_replace + 
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -472,14 +482,24 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs embedding_init=nnx.initializers.normal(), rngs=rngs, ) - self.layers = nnx.List( - [ - DeepseekV3DecoderLayer(config, layer_idx=i, dtype=dtype, rngs=rngs) - for i in range(config.num_hidden_layers) - ] - ) + + # Create stacked layers: dense layers followed by MoE layers + def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MLP, dtype=dtype, rngs=rngs) + + def create_moe_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer: + return DeepseekV3DecoderLayer(config, mlp_cls=DeepseekV3MoE, dtype=dtype, rngs=rngs) + + dense_layers = StackedDecoderLayers(create_dense_layer, self.num_dense_layers, rngs) + moe_layers = StackedDecoderLayers(create_moe_layer, self.num_moe_layers, rngs) + self.layers = MultiStackedDecoderLayers(dense_layers, moe_layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + def get_stacked_layers_list(self): + """Delegate to MultiStackedDecoderLayers for checkpoint loading.""" + return self.layers.get_stacked_layers_list() + def __call__( self, input_ids: jax.Array, @@ -489,28 +509,25 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), - ) - updated_keys.append(k) - updated_values.append(v) + + # Forward through all layers + hidden_states, all_hidden_states, kv_cache = self.layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -518,7 +535,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -561,6 +578,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -572,6 +590,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index b1ae1027b..8ff6c85ff 100644 --- 
a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -4,10 +4,11 @@ from jax.sharding import get_abstract_mesh from transformers import LlamaConfig +from tx.layers.attention import dot_product_attention from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm -from tx.layers.attention import dot_product_attention +from tx.layers.stacked import StackedDecoderLayers from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -199,6 +200,7 @@ class Llama3Model(nnx.Module): def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -211,9 +213,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> embedding_init=nnx.initializers.normal(), rngs=rngs, ) - self.layers = nnx.List( - [Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)] - ) + + def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer: + return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) + + self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -225,28 +229,24 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), - ) - updated_keys.append(k) - updated_values.append(v) + + hidden_states, all_hidden_states, new_kv_cache = self.layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -254,7 +254,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -299,6 +299,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -310,6 +311,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git 
a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 1348cac09..a067e8245 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -3,15 +3,16 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh +from tx.layers.attention import dot_product_attention from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope -from tx.utils.logits_processor import LogitsProcessorMixin, LMHead from tx.layers.layernorm import RMSNorm -from tx.layers.attention import dot_product_attention +from tx.layers.stacked import StackedDecoderLayers from tx.models.configs import Qwen3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead class Qwen3Attention(nnx.Module): @@ -317,6 +318,7 @@ class Qwen3Model(nnx.Module): def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.config = config + self.num_layers = config.num_hidden_layers self.embed_tokens = LoRAEmbed( num_embeddings=config.vocab_size, @@ -329,9 +331,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> embedding_init=nnx.initializers.normal(), rngs=rngs, ) - self.layers = nnx.List( - [Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)] - ) + + def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer: + return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) + + self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -343,28 +347,24 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), - ) - updated_keys.append(k) - updated_values.append(v) + + hidden_states, all_hidden_states, new_kv_cache = self.layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -372,7 +372,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -417,6 +417,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) 
-> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -428,6 +429,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/models/types.py b/skyrl-tx/tx/models/types.py index 8067c9f8a..16d0241d5 100644 --- a/skyrl-tx/tx/models/types.py +++ b/skyrl-tx/tx/models/types.py @@ -24,12 +24,12 @@ class ModelOutput: Attributes: last_hidden_state: The last hidden state from the model. - kv_cache: The updated key-value cache. + kv_cache: The updated key-value cache (None during training). hidden_states: All hidden states if output_hidden_states=True. """ last_hidden_state: jax.Array - kv_cache: KVCache + kv_cache: KVCache | None hidden_states: list[jax.Array] | None = None @@ -40,10 +40,10 @@ class CausalLMOutput: Attributes: last_hidden_state: The last hidden state from the model. - kv_cache: The updated key-value cache. + kv_cache: The updated key-value cache (None during training). hidden_states: All hidden states, if output_hidden_states=True. """ last_hidden_state: jax.Array - kv_cache: KVCache + kv_cache: KVCache | None hidden_states: list[jax.Array] | None = None diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 99ae33327..744c70d98 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -52,6 +52,7 @@ insert_adapter_state, round_up_seq_len, resolve_model_path, + get_adapter_idx, ) from tx.utils.log import logger @@ -81,7 +82,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): ) gradient_checkpointing: bool = Field( default=False, - description="Whether to use gradient checkpointing (full recomputation strategy)", + description="Per-layer activation checkpointing: recompute activations during backward to save memory", ) loss_chunk_size: int = Field( default=1024, @@ -126,15 +127,22 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated def get_mean(self, adapter_index: jax.Array) -> nnx.State: """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" count = self.counts[adapter_index] - return jax.tree.map( - lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), - self.grad_sum, - ) + + def compute_mean(path, g): + idx = get_adapter_idx(path, adapter_index) + return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype)) + + return jax.tree.map_with_path(compute_mean, self.grad_sum) def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": """Reset gradients and count for a specific adapter.""" + + def reset_grad(path, g): + idx = get_adapter_idx(path, adapter_index) + return g.at[idx].set(0.0) + return AccumulatedGradients( - grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), + grad_sum=jax.tree.map_with_path(reset_grad, self.grad_sum), counts=self.counts.at[adapter_index].set(0), ) @@ -251,14 +259,10 @@ def _model_forward( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, + is_training=True, ) return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) - if self.config.gradient_checkpointing: - # Wrap the model forward call to use jax.checkpoint for gradient checkpointing - # policy=None corresponds to full activation recomputation - _model_forward = jax.checkpoint(_model_forward, policy=None) - def loss_for_lora( lora_params: nnx.State, 
non_lora_params: nnx.State, diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index f461a5613..37b6f7b54 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -16,9 +16,9 @@ class KVCache: """Key-value cache for all layers, each entry in the list corresponds to one layer.""" - keys: list[jax.Array] - values: list[jax.Array] - cache_position: jax.Array # Per-sequence positions of shape [B] for left-aligned decoding + keys: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer + values: list[jax.Array] # list of (batch, seq, num_kv_heads, head_dim) per layer + cache_position: jax.Array # Per-sequence positions of shape (batch,) @staticmethod def update( @@ -85,6 +85,75 @@ def pad_to_length(self, max_length: int) -> KVCache: cache_position=self.cache_position, ) + @property + def num_layers(self) -> int: + """Number of layers in the cache.""" + return len(self.keys) + + @property + def batch_size(self) -> int: + """Batch size.""" + return self.keys[0].shape[0] + + @property + def seq_len(self) -> int: + """Current sequence length.""" + return self.keys[0].shape[1] + + def split(self, *layer_indices: int) -> tuple[KVCache | None, ...]: + """Split the cache at one or more layer indices. + + Args: + *layer_indices: Layer indices to split at. For example, split(3, 7) + creates 3 caches: [0:3), [3:7), [7:end). + + Returns: + Tuple of KVCache objects, one for each segment. Returns None for empty segments. + """ + if len(layer_indices) == 0: + return (self,) + + # Build split points: 0, idx1, idx2, ..., num_layers + split_points = [0] + list(layer_indices) + [self.num_layers] + + caches = [] + for start, end in zip(split_points[:-1], split_points[1:]): + if start == end: + caches.append(None) + else: + caches.append( + KVCache( + keys=self.keys[start:end], + values=self.values[start:end], + cache_position=self.cache_position, + ) + ) + return tuple(caches) + + @staticmethod + def concatenate(*caches: KVCache | None) -> KVCache | None: + """Concatenate multiple caches along the layer dimension. + + Args: + *caches: KVCache objects to concatenate, or None values to skip. + + Returns: + Combined KVCache, or None if all inputs are None. 
+ """ + # Filter out None values + non_none_caches = [c for c in caches if c is not None] + + if len(non_none_caches) == 0: + return None + if len(non_none_caches) == 1: + return non_none_caches[0] + + return KVCache( + keys=sum([c.keys for c in non_none_caches], []), + values=sum([c.values for c in non_none_caches], []), + cache_position=non_none_caches[-1].cache_position, + ) + @jax.tree_util.register_dataclass @dataclass @@ -197,11 +266,9 @@ def _prefill_and_decode( last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :] prompt_logprobs_array = None - # Pad KV cache and attention mask + # Pad KV cache and attention mask to max_length kv_cache = outputs.kv_cache.pad_to_length(max_length) - # Pad KV cache and attention mask to max_length - kv_cache = kv_cache.pad_to_length(max_length) decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1]))) def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]: diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6e840febf..6ad3f8ea6 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -1,3 +1,5 @@ +"""Weight loading and saving utilities for stacked layer models.""" + from __future__ import annotations from enum import Enum @@ -76,6 +78,33 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") +def is_stacked_lora_path(path: tuple) -> bool: + """Check if a parameter path is for stacked layer weights (for LoRA indexing). + + Stacked layer params have the adapter dimension at axis 1: (num_layers, num_adapters, ...). + Non-stacked params (e.g., embed_tokens) have adapter dimension at axis 0: (num_adapters, ...). + + Args: + path: Parameter path tuple (can be nnx path objects or strings). + + Returns: + True if the path contains '_stacked' (from StackedDecoderLayers). + """ + path_strs = [p.key if hasattr(p, "key") else str(p) for p in path] + return "_stacked" in path_strs + + +def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: + """Return index tuple for accessing an adapter at the given path. + + Stacked layer params have shape (num_layers, num_adapters, ...) -> index as [:, adapter_index]. + Non-stacked params (embed_tokens) have shape (num_adapters, ...) -> index as [adapter_index]. + """ + if is_stacked_lora_path(path): + return (slice(None), adapter_index) + return (adapter_index,) + + def get_param_key(path: tuple, prefix: str = "") -> str: "Get the safetensors key for a given model path." 
if path[-1] in {"embedding", "kernel"}: @@ -99,14 +128,17 @@ def load_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: + """Load safetensors weights into a model with stacked layers.""" + from tx.layers.stacked import unstack_state + tensors = {} for file in Path(checkpoint_dir).glob("*.safetensors"): tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} - model_params = nnx.to_flat_state(nnx.state(model)) - updates = [] - for path, param in model_params: + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) with ArrayRef write-through, matching checkpoint key format + for path, param in nnx.to_flat_state(unstack_state(model)): if filter_fn is not None and not filter_fn(path): continue key = get_param_key(path) @@ -114,17 +146,14 @@ def load_safetensors( if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue if "experts" in path: - tensors[key] = np.stack( - [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 - ) + tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) else: - tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T + tensor = tensors[key] if "embed_tokens" in key else tensors[key].T if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: - tensors[key] = tensors[key].reshape(param.shape) - assert param.shape == tensors[key].shape, f"shape mismatch for {key}" - sharded_tensor = jax.device_put(tensors[key].astype(param.dtype), param.sharding) - updates.append((path, sharded_tensor)) - nnx.update(model, nnx.from_flat_state(updates)) + tensor = tensor.reshape(param.shape) + assert param.shape == tensor.shape, f"shape mismatch for {key}" + # ArrayRef.set_raw_value writes through to the stacked parent variable + param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) def save_safetensors( @@ -134,9 +163,13 @@ def save_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: - model_params = nnx.to_flat_state(nnx.state(model)) + """Save model weights to safetensors, unstacking layer weights for HF compatibility.""" + from tx.layers.stacked import unstack_state + + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) matching the checkpoint key format used by HuggingFace tensors = {} - for path, param in model_params: + for path, param in nnx.to_flat_state(unstack_state(model)): if "rngs" in path: continue if filter_fn is not None and not filter_fn(path): @@ -252,13 +285,14 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: "Helper function to extract the adapter parameters for a specific adapter index." 
def extract_state(path: tuple, p: jnp.ndarray): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return p - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p[adapter_index, ..., :, :rank] - if path[-2].key == "lora_B": - return p[adapter_index, ..., :rank, :] + assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p[idx + (..., slice(None, rank))] + return p[idx + (..., slice(None, rank), slice(None))] return jax.tree.map_with_path(extract_state, lora_params) @@ -271,13 +305,14 @@ def insert_adapter_state( "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." def insert_state(path: tuple, p: jax.Array, new: jax.Array): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return new - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p.at[adapter_index, ..., :, :rank].set(new) - elif path[-2].key == "lora_B": - return p.at[adapter_index, ..., :rank, :].set(new) + assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}" + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p.at[idx + (..., slice(None, rank))].set(new) + return p.at[idx + (..., slice(None, rank), slice(None))].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated)
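
Note: to make the layer-group handling concrete, here is a minimal sketch of the KVCache.split / KVCache.concatenate round trip that MultiStackedDecoderLayers relies on. The shapes and group sizes are illustrative only (a toy 5-layer cache split into a 3-layer dense group and a 2-layer MoE group); the cache contents are zeros.

import jax.numpy as jnp
from tx.utils.generator import KVCache

# Toy cache: 5 layers of (batch, seq, num_kv_heads, head_dim) keys/values.
num_layers, batch, seq, kv_heads, head_dim = 5, 2, 8, 4, 16
cache = KVCache(
    keys=[jnp.zeros((batch, seq, kv_heads, head_dim)) for _ in range(num_layers)],
    values=[jnp.zeros((batch, seq, kv_heads, head_dim)) for _ in range(num_layers)],
    cache_position=jnp.zeros((batch,), dtype=jnp.int32),
)

# A model with 3 dense layers followed by 2 MoE layers splits at index 3,
# mirroring the split_points computed from the layer groups.
dense_cache, moe_cache = cache.split(3)
assert dense_cache.num_layers == 3 and moe_cache.num_layers == 2

# Each group returns its own per-layer cache; concatenate stitches them back
# together in layer order (None segments are skipped).
merged = KVCache.concatenate(dense_cache, moe_cache)
assert merged.num_layers == num_layers and merged.seq_len == seq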
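
Note: the adapter-indexing changes in tx/utils/models.py and the JAX backend all hinge on get_adapter_idx returning a different index tuple for stacked versus non-stacked LoRA parameters. The sketch below illustrates the two layouts; the paths and shapes are hypothetical stand-ins, not taken from a real checkpoint.

import jax.numpy as jnp
from tx.utils.models import get_adapter_idx

num_layers, num_adapters, num_experts, in_dim, max_rank = 4, 3, 8, 16, 6

# Stacked decoder-layer lora_A: (num_layers, num_adapters, in_dim, max_rank).
# The "_stacked" path component triggers the (slice(None), adapter_index) form.
stacked_path = ("model", "layers", "_stacked", "self_attn", "q_proj", "lora_A", "value")
stacked = jnp.zeros((num_layers, num_adapters, in_dim, max_rank))
idx = get_adapter_idx(stacked_path, 1)
assert stacked[idx].shape == (num_layers, in_dim, max_rank)  # adapter axis removed

# Non-stacked embed_tokens lora_A: (num_adapters, in_dim, max_rank).
embed_path = ("model", "embed_tokens", "lora_A", "value")
embed = jnp.zeros((num_adapters, in_dim, max_rank))
idx = get_adapter_idx(embed_path, 1)
assert embed[idx].shape == (in_dim, max_rank)

# extract_adapter_state builds on the same index tuple, e.g. for a stacked 5-D
# expert lora_A assumed here to be (num_layers, num_adapters, num_experts, in_dim, max_rank):
rank = 2
experts = jnp.zeros((num_layers, num_adapters, num_experts, in_dim, max_rank))
idx = get_adapter_idx(stacked_path, 1)
assert experts[idx + (..., slice(None, rank))].shape == (num_layers, num_experts, in_dim, rank)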
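
Note: the rank-truncation rule added to shard_map_ep can also be shown in isolation. ep_only_spec below is a hypothetical standalone helper mirroring the inline make_ep_spec closure: after a per-layer tensor is sliced out of a stacked parameter it has one fewer dimension than its PartitionSpec, so the spec is truncated from the left and every non-'ep' axis is cleared to None. The example spec and ranks are illustrative.

from jax.sharding import PartitionSpec


def ep_only_spec(spec: PartitionSpec, tensor_rank: int) -> PartitionSpec:
    """Keep only 'ep' axes, truncating the spec to the tensor's actual rank."""
    if not isinstance(spec, PartitionSpec):
        return spec
    truncated = tuple(spec)[-tensor_rank:] if tensor_rank > 0 else ()
    return PartitionSpec(*(p if p == "ep" else None for p in truncated))


# Stacked expert weight annotated as (num_layers, num_experts, in_dim, out_dim)
# with spec (None, 'ep', None, 'tp'); after x[layer_idx] the tensor is rank 3,
# so the leading None is dropped and 'tp' is replaced by None.
spec = PartitionSpec(None, "ep", None, "tp")
assert tuple(ep_only_spec(spec, 3)) == ("ep", None, None)

# A tensor whose rank already matches the spec only has its non-'ep' axes cleared.
assert tuple(ep_only_spec(spec, 4)) == (None, "ep", None, None)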