diff --git a/skyrl-tx/tests/models/test_glm4.py b/skyrl-tx/tests/models/test_glm4.py new file mode 100644 index 000000000..5aa85c701 --- /dev/null +++ b/skyrl-tx/tests/models/test_glm4.py @@ -0,0 +1,210 @@ +import os +import tempfile + +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from tx.layers.lora import LoRAMixin +from tx.models.configs import Glm4Config +from tx.models.glm4 import Glm4ForCausalLM, Glm4MoE +from tx.utils.models import load_safetensors + + +@pytest.mark.parametrize("tp", [1, 2]) +def test_glm4_moe(tp: int): + """Test GLM4-MoE model against HuggingFace implementation.""" + if not jax._src.xla_bridge.backends_are_initialized(): + jax.config.update("jax_num_cpu_devices", 2) + + if tp > 1 and os.getenv("CI"): + pytest.skip("TP > 1 currently runs out of memory in the CI") + + model_name = "yujiepan/glm-4-moe-tiny-random" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, attn_implementation="eager", use_safetensors=True, trust_remote_code=True + ) + + inputs = ["The capital of France is", "The most popular programming language is"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + with torch.no_grad(): + hf_outputs = hf_model( + batch.input_ids, + attention_mask=batch.attention_mask, + output_hidden_states=True, + return_dict=True, + use_cache=False, + ) + + # Save the HF model checkpoint so we can load our model from it + with tempfile.TemporaryDirectory() as tmp: + hf_model.save_pretrained(tmp, safe_serialization=True) + + base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) + config = Glm4Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True) + mesh = jax.make_mesh( + (1, tp), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = Glm4ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_safetensors(tmp, config, model) + + outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy(), output_hidden_states=True) + + assert outputs.hidden_states is not None + assert np.allclose(hf_outputs.hidden_states[0], outputs.hidden_states[0], rtol=1e-6) + # Higher tolerance due to cross-platform BLAS differences + assert np.allclose(hf_outputs.hidden_states[1], outputs.hidden_states[1], rtol=6e-3, atol=6e-3) + assert np.allclose(hf_outputs.hidden_states[-1], outputs.hidden_states[-1], rtol=3e-2, atol=6e-2) + + +def load_moe_base_weights(jax_moe_layer: Glm4MoE, hf_moe_layer) -> None: + """Load base weights from HF MoE layer to JAX MoE layer. + + The tiny random model uses separate experts (ModuleList), matching our implementation. 
+ """ + # Router weights + jax_moe_layer.gate.weight[:] = hf_moe_layer.gate.weight.detach().numpy().T + jax_moe_layer.gate.e_score_correction_bias[:] = hf_moe_layer.gate.e_score_correction_bias.detach().numpy() + + # Expert weights - The tiny model uses ModuleList with separate gate_proj, up_proj, down_proj + hf_experts = hf_moe_layer.experts + + for i, expert in enumerate(hf_experts): + jax_moe_layer.experts.gate_proj.weight[i, :, :] = expert.gate_proj.weight.detach().numpy().T + jax_moe_layer.experts.up_proj.weight[i, :, :] = expert.up_proj.weight.detach().numpy().T + jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T + + # Shared experts + jax_moe_layer.shared_experts.gate_proj.kernel[:] = hf_moe_layer.shared_experts.gate_proj.weight.detach().numpy().T + jax_moe_layer.shared_experts.up_proj.kernel[:] = hf_moe_layer.shared_experts.up_proj.weight.detach().numpy().T + jax_moe_layer.shared_experts.down_proj.kernel[:] = hf_moe_layer.shared_experts.down_proj.weight.detach().numpy().T + + +def test_glm4_moe_layer(): + """Test GLM4 MoE layer against HuggingFace implementation.""" + model_name = "yujiepan/glm-4-moe-tiny-random" + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + base_config = PretrainedConfig.from_pretrained(model_name) + config = Glm4Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True) + + # First layer uses dense MLP (first_k_dense_replace=1), so we test layer 1 which has MoE + hf_moe_layer = hf_model.model.layers[1].mlp + torch.manual_seed(42) + x = torch.randn(4, 2, config.hidden_size) + with torch.no_grad(): + hf_expert_output = hf_moe_layer.forward(x) + + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + moe_layer = Glm4MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_moe_base_weights(moe_layer, hf_moe_layer) + + jax_expert_output = moe_layer(x.numpy()) + + # Higher tolerance due to cross-platform BLAS differences + assert np.allclose(hf_expert_output.detach().numpy(), jax_expert_output, rtol=6e-3, atol=6e-3) + + +def load_lora_weights( + jax_module: LoRAMixin, + adapter_idx: int, + lora_A_weights: np.ndarray, + lora_B_weights: np.ndarray, + scaling: float, + rank: int, +) -> None: + """Load LoRA weights from numpy arrays to JAX module.""" + assert ( + jax_module.lora_A is not None + and jax_module.lora_B is not None + and jax_module.lora_scaling is not None + and jax_module.lora_ranks is not None + ) + jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(lora_A_weights)) + jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(lora_B_weights)) + jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[adapter_idx].set(scaling) + jax_module.lora_ranks[...] 
= jax_module.lora_ranks[...].at[adapter_idx].set(rank) + + +def test_glm4_moe_layer_lora(): + """Test MoE LoRA by merging adapter into base weights and comparing outputs.""" + model_name = "yujiepan/glm-4-moe-tiny-random" + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + base_config = PretrainedConfig.from_pretrained(model_name) + config = Glm4Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True) + + hf_moe_layer = hf_model.model.layers[1].mlp + x = torch.randn(3, 4, config.hidden_size) + + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + moe_layer = Glm4MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + load_moe_base_weights(moe_layer, hf_moe_layer) + + # Set LoRA weights for all adapters + rng = np.random.default_rng(42) + scaling = 2.0 + rank = config.max_lora_rank + for adapter_idx in range(config.max_lora_adapters): + for proj in [moe_layer.experts.gate_proj, moe_layer.experts.up_proj, moe_layer.experts.down_proj]: + assert proj.lora_A is not None and proj.lora_B is not None + lora_A = rng.normal(0, 1.0, proj.lora_A[...].shape[1:]) + lora_B = rng.normal(0, 1.0, proj.lora_B[...].shape[1:]) + load_lora_weights(proj, adapter_idx, lora_A, lora_B, scaling, rank) + + # Test with different adapters per sample + adapter_indices = jnp.array([0, 2, 1]) + output_with_lora = moe_layer(x.numpy(), adapter_indices=adapter_indices) + + # Test each sample by comparing with merged weights for its adapter + for sample_idx in range(len(adapter_indices)): + adapter_idx = int(adapter_indices[sample_idx]) + + # Create merged model by adding LoRA weights to base weights + moe_layer_merged = Glm4MoE(config, dtype=jnp.float32, rngs=nnx.Rngs(1 + adapter_idx)) + + # Copy router weights + moe_layer_merged.gate.weight[:] = moe_layer.gate.weight[:] + moe_layer_merged.gate.e_score_correction_bias[:] = moe_layer.gate.e_score_correction_bias[:] + + # Copy shared experts weights + moe_layer_merged.shared_experts.gate_proj.kernel[:] = moe_layer.shared_experts.gate_proj.kernel[:] + moe_layer_merged.shared_experts.up_proj.kernel[:] = moe_layer.shared_experts.up_proj.kernel[:] + moe_layer_merged.shared_experts.down_proj.kernel[:] = moe_layer.shared_experts.down_proj.kernel[:] + + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + proj = getattr(moe_layer.experts, proj_name) + proj_merged = getattr(moe_layer_merged.experts, proj_name) + + # For each expert, merge: base + scaling * (lora_A @ lora_B) + for expert_idx in range(config.n_routed_experts): + lora_A = proj.lora_A[adapter_idx, expert_idx, :, :] + lora_B = proj.lora_B[adapter_idx, expert_idx, :, :] + lora_delta = scaling * (lora_A @ lora_B) + + # Copy base weight AND add LoRA delta + base_weight = proj.weight[expert_idx, :, :] + merged_weight = base_weight + lora_delta + proj_merged.weight[...] 
= proj_merged.weight[...].at[expert_idx, :, :].set(merged_weight) + + # Run merged model on this sample + x_sample = x[sample_idx : sample_idx + 1].numpy() + output_merged = moe_layer_merged(x_sample) + + assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) diff --git a/skyrl-tx/tests/models/test_glm4_lora_training.py b/skyrl-tx/tests/models/test_glm4_lora_training.py new file mode 100644 index 000000000..5ceabfb13 --- /dev/null +++ b/skyrl-tx/tests/models/test_glm4_lora_training.py @@ -0,0 +1,233 @@ +from flax import nnx +import jax +import jax.numpy as jnp +import optax +from huggingface_hub import snapshot_download +from transformers import PretrainedConfig + +from tx.models.configs import Glm4Config +from tx.models.glm4 import Glm4ForCausalLM +from tx.utils.models import get_dtype, load_safetensors +from tx.layers.lora import init_lora_adapter +from tx.tinker.types import LoraConfig + + +def _is_routed_expert_path(path) -> bool: + """Disambiguate shared_experts and experts""" + keys = [] + for p in path: + if hasattr(p, "key"): + keys.append(str(p.key)) + elif hasattr(p, "name"): + keys.append(str(p.name)) + + for i, key in enumerate(keys): + if key == "experts" and i > 0 and keys[i - 1] == "mlp": + return True + return False + + +def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int): + """Extract out-of-rank params, using effective rank for routed expert layers.""" + + def slice_param(path, p): + path_str = str(path) + + if _is_routed_expert_path(path): + effective_rank = max(1, rank // num_experts) + else: + effective_rank = rank + + if "lora_A" in path_str: + # lora_A shape: [adapters, ..., max_rank] - slice last dim + return p[adapter_idx, ..., effective_rank:].copy() + elif "lora_B" in path_str: + # lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim + return p[adapter_idx, ..., effective_rank:, :].copy() + return p + + return jax.tree.map_with_path(slice_param, params) + + +def test_lora_training_moe_rank_normalized(): + base_model = "yujiepan/glm-4-moe-tiny-random" + base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True) + config = Glm4Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True) + + checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"]) + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = Glm4ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + load_safetensors(checkpoint_path, config, model) + + # Set different ranks for each adapter (0: rank 16, 1: rank 8) + # For routed experts: effective rank = max(1, rank // num_experts) + # For other layers: effective rank = configured rank + init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) + init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) + + optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + + # Use 11 tokens so input_ids has 10 (even) - cuDNN flash attention requires even seq length with bias + batch = jnp.array( + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]], dtype=jnp.int32 + ) + target_ids = batch[:, 1:] + input_ids = batch[:, :-1] + adapter_indices = jnp.array([0, 1], dtype=jnp.int32) + attention_mask = jnp.ones_like(input_ids) + + def loss_fn(model, input_ids, 
target_ids, attention_mask): + outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() + + graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) + + def get_adapter_params(params, adapter_idx): + return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + + num_experts = config.n_routed_experts + + # Save initial states + initial_adapter_2_params = get_adapter_params(lora_params, 2) + initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + + initial_loss = None + + # Training loop + for step in range(10): + + def loss_for_lora(lora_params): + merged_model = nnx.merge(graphdef, lora_params, non_lora_params) + return loss_fn(merged_model, input_ids, target_ids, attention_mask) + + loss_and_grad_fn = nnx.value_and_grad(loss_for_lora) + loss, lora_grads = loss_and_grad_fn(lora_params) + + if initial_loss is None: + initial_loss = float(loss) + + optimizer.update(lora_params, lora_grads) + + print(f"Step {step}: loss = {float(loss):.4f}") + + final_loss = float(loss) + + def verify_params_unchanged(initial_params, final_params, error_msg_prefix): + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + assert final_loss < initial_loss, f"Loss did not decrease: {initial_loss} -> {final_loss}" + + # Verify unused adapter was not modified + final_adapter_2_params = get_adapter_params(lora_params, 2) + verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") + + # Verify out-of-rank params were not modified + final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + verify_params_unchanged( + initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" + ) + final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + verify_params_unchanged( + initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" + ) + + +def test_lora_training_high_rank(): + base_model = "yujiepan/glm-4-moe-tiny-random" + base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True) + config = Glm4Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True) + + checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"]) + mesh = jax.make_mesh( + (1, 1), + ("fsdp", "tp"), + axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = Glm4ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0)) + load_safetensors(checkpoint_path, config, model) + + init_lora_adapter(model, adapter_index=0, lora_config=LoraConfig(rank=16, alpha=16, seed=0)) + init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1)) + + optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param) + + # Use 11 tokens so input_ids has 10 (even) - cuDNN flash attention requires even seq length with bias + batch = jnp.array( + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 
21]], dtype=jnp.int32 + ) + target_ids = batch[:, 1:] + input_ids = batch[:, :-1] + adapter_indices = jnp.array([0, 1], dtype=jnp.int32) + attention_mask = jnp.ones_like(input_ids) + + def loss_fn(model, input_ids, target_ids, attention_mask): + outputs = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target_ids).mean() + + graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...) + + def get_adapter_params(params, adapter_idx): + return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + + num_experts = config.n_routed_experts + + # Save initial states for all unused adapters + initial_adapter_2_params = get_adapter_params(lora_params, 2) + initial_adapter_3_params = get_adapter_params(lora_params, 3) + initial_adapter_4_params = get_adapter_params(lora_params, 4) + + # Save out-of-rank params for adapters 0 and 1 + initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + + # Training loop + for step in range(10): + + def loss_for_lora(lora_params): + merged_model = nnx.merge(graphdef, lora_params, non_lora_params) + return loss_fn(merged_model, input_ids, target_ids, attention_mask) + + loss_and_grad_fn = nnx.value_and_grad(loss_for_lora) + loss, lora_grads = loss_and_grad_fn(lora_params) + + optimizer.update(lora_params, lora_grads) + + print(f"Step {step}: loss = {float(loss):.4f}") + + def verify_params_unchanged(initial_params, final_params, error_msg_prefix): + for (path, initial), (_, final) in zip( + jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params) + ): + assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}" + + # Verify unused adapters (2, 3, 4) were not modified + final_adapter_2_params = get_adapter_params(lora_params, 2) + verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified") + + final_adapter_3_params = get_adapter_params(lora_params, 3) + verify_params_unchanged(initial_adapter_3_params, final_adapter_3_params, "Adapter 3 was modified") + + final_adapter_4_params = get_adapter_params(lora_params, 4) + verify_params_unchanged(initial_adapter_4_params, final_adapter_4_params, "Adapter 4 was modified") + + # Verify out-of-rank params were not modified + final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts) + verify_params_unchanged( + initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified" + ) + final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts) + verify_params_unchanged( + initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified" + ) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index 15e011388..a5acfdda7 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -53,3 +53,4 @@ def get_num_experts(self): Llama3Config = ModelConfig Qwen3Config = ModelConfig DeepseekV3Config = ModelConfig +Glm4Config = ModelConfig diff --git a/skyrl-tx/tx/models/glm4.py b/skyrl-tx/tx/models/glm4.py new file mode 100644 index 000000000..f486c9207 --- /dev/null +++ b/skyrl-tx/tx/models/glm4.py @@ -0,0 +1,513 @@ +from flax import nnx +import jax +from jax import 
numpy as jnp +from jax.sharding import get_abstract_mesh + +from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear +from tx.layers.rotary_embedding import get_rope +from tx.layers.util import Param, prepare_routing +from tx.layers.layernorm import RMSNorm +from tx.layers.attention import dot_product_attention +from tx.models.configs import Glm4Config +from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput +from tx.utils.generator import GeneratorMixin, KVCache +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead + + +class Glm4Attention(nnx.Module): + """Multi-head attention with Grouped Query Attention (GQA) support.""" + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + + tp = get_abstract_mesh().shape.get("tp", 1) + shard_attention_heads = config.shard_attention_heads + if shard_attention_heads: + assert self.num_heads % tp == 0, f"num_heads={self.num_heads} must be divisible by tp={tp}" + assert self.num_kv_heads % tp == 0, f"num_kv_heads={self.num_kv_heads} must be divisible by tp={tp}" + tp_shard = "tp" if shard_attention_heads else None + + self.head_dim = config.head_dim or config.hidden_size // self.num_heads + + self.q_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_heads * self.head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)), + rngs=rngs, + ) + self.k_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_kv_heads * self.head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)), + rngs=rngs, + ) + self.v_proj = LoRALinear( + in_features=config.hidden_size, + out_features=self.num_kv_heads * self.head_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=config.attention_bias, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)), + rngs=rngs, + ) + self.o_proj = LoRALinear( + in_features=self.num_heads * self.head_dim, + out_features=config.hidden_size, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + param_dtype=dtype, + use_bias=False, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")), + rngs=rngs, + ) + + if self.config.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + self.rotary_dim = int(self.head_dim * self.config.partial_rotary_factor) + self.rotary_emb, _ = get_rope(self.rotary_dim, config.rope_theta, config.rope_scaling) + + def __call__( + self, + x: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + B, T, _ = x.shape + + q = self.q_proj(x, adapter_indices=adapter_indices).reshape(B, T, self.num_heads, self.head_dim) + 
k = self.k_proj(x, adapter_indices=adapter_indices).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.v_proj(x, adapter_indices=adapter_indices).reshape(B, T, self.num_kv_heads, self.head_dim) + + if self.config.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Partial RoPE + q_rot, q_pass = q[..., : self.rotary_dim], q[..., self.rotary_dim :] + k_rot, k_pass = k[..., : self.rotary_dim], k[..., self.rotary_dim :] + + q_rot = self.rotary_emb(q_rot, positions) + k_rot = self.rotary_emb(k_rot, positions) + + q = jnp.concatenate([q_rot, q_pass], axis=-1) + k = jnp.concatenate([k_rot, k_pass], axis=-1) + + # Handle KV cache + if kv_cache is not None: + k, v = KVCache.update_layer(kv_cache, k, v, positions) + + updated_cache = (k, v) + + is_causal = kv_cache is None + attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim) + + output = attn_output.reshape(B, T, self.num_heads * self.head_dim) + return self.o_proj(output, adapter_indices=adapter_indices), updated_cache + + +class Glm4MLP(nnx.Module): + + def __init__( + self, + config: Glm4Config, + *, + dtype: jnp.dtype, + rngs: nnx.Rngs, + override_intermediate_size: int | None = None, + ) -> None: + self.config = config + intermediate_size = override_intermediate_size or config.intermediate_size + self.gate_proj = LoRALinear( + config.hidden_size, + intermediate_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.up_proj = LoRALinear( + config.hidden_size, + intermediate_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + self.down_proj = LoRALinear( + intermediate_size, + config.hidden_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + + def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array: + gate_out = self.gate_proj(x, adapter_indices) + up_out = self.up_proj(x, adapter_indices) + return self.down_proj(nnx.silu(gate_out) * up_out, adapter_indices) + + +class Glm4TopkRouter(nnx.Module): + """GLM4 MoE routing gate. 
Returns raw router logits.""" + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + + self.weight = Param( + config.hidden_size, + config.n_routed_experts, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, None)), + rngs=rngs, + ) + + self.e_score_correction_bias = nnx.Variable(jnp.zeros(config.n_routed_experts, dtype=jnp.float32)) + + def __call__(self, hidden_states: jax.Array) -> jax.Array: + hidden_states = hidden_states.reshape(-1, self.config.hidden_size) + router_logits = hidden_states.astype(jnp.float32) @ self.weight[...].astype(jnp.float32) + return router_logits + + +class Glm4Experts(nnx.Module): + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.num_experts = config.n_routed_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + + # NOTE: Huggingface implementation uses a fused gate_up_proj, but the weights are keyed + # by gate_proj and up_proj separately. + self.gate_proj = LoRAExpert( + self.num_experts, + self.hidden_dim, + self.intermediate_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")), + rngs=rngs, + ) + self.up_proj = LoRAExpert( + self.num_experts, + self.hidden_dim, + self.intermediate_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")), + rngs=rngs, + ) + self.down_proj = LoRAExpert( + self.num_experts, + self.intermediate_dim, + self.hidden_dim, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")), + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jax.Array, + top_k_index: jax.Array, + top_k_weights: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + num_experts_per_tok = top_k_index.shape[1] + + # Prepare for ragged_dot by sorting tokens based on their assigned expert + selected_experts_flat = top_k_index.ravel() + hidden_states_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0) + adapter_indices_expanded = ( + jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None + ) + + hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing( + hidden_states_expanded, + selected_experts_flat, + self.num_experts, + adapter_indices=adapter_indices_expanded, + ) + + gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted) + up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted) + down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted) + + # Unsort and combine the expert outputs + unsorted_out = down_out[unsort_indices] + reshaped_out = unsorted_out.reshape(-1, num_experts_per_tok, self.hidden_dim) + return jnp.sum(reshaped_out * top_k_weights[..., None], axis=1) + + +class Glm4MoE(nnx.Module): + """MoE layer with shared experts and group-based routing.""" + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.n_group = config.n_group + + self.gate = Glm4TopkRouter(config, 
dtype=dtype, rngs=rngs) + self.experts = Glm4Experts(config, dtype=dtype, rngs=rngs) + + inter_dim = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Glm4MLP(config, dtype=dtype, rngs=rngs, override_intermediate_size=inter_dim) + + def _compute_routing(self, router_logits: jax.Array) -> tuple[jax.Array, jax.Array]: + num_tokens = router_logits.shape[0] + num_experts = router_logits.shape[1] + + scores = nnx.sigmoid(router_logits) + scores_with_bias = scores + self.gate.e_score_correction_bias[...] + + experts_per_group = num_experts // self.n_group + scores_grouped = scores_with_bias.reshape(num_tokens, self.n_group, experts_per_group) + + top2, _ = jax.lax.top_k(scores_grouped, 2) + group_scores = jnp.sum(top2, axis=-1) + + _, top_group_indices = jax.lax.top_k(group_scores, self.config.topk_group) + + mask = jnp.ones((num_tokens, self.n_group), dtype=bool) + batch_indices = jnp.arange(num_tokens)[:, None] + mask = mask.at[batch_indices, top_group_indices].set(False) + mask = jnp.broadcast_to(mask[:, :, None], scores_grouped.shape) + + scores_with_bias = jnp.where(mask, 0.0, scores_grouped) + scores_with_bias = scores_with_bias.reshape(num_tokens, num_experts) + + _, top_k_index = jax.lax.top_k(scores_with_bias, self.config.num_experts_per_tok) + + # Get weights from original scores + top_k_weights = jnp.take_along_axis(scores, top_k_index, axis=-1) + + if self.config.norm_topk_prob: + top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True) + + top_k_weights = top_k_weights * self.config.routed_scaling_factor + + return top_k_weights.astype(router_logits.dtype), top_k_index + + def __call__( + self, + hidden_states: jax.Array, + *, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + batch_size, seq_len, hidden_size = hidden_states.shape + hidden_states_flat = hidden_states.reshape(-1, hidden_size) + + if adapter_indices is not None: + adapter_indices_flat = jnp.repeat(adapter_indices, seq_len) + else: + adapter_indices_flat = None + + router_logits = self.gate(hidden_states_flat) + top_k_weights, top_k_index = self._compute_routing(router_logits) + + expert_output = self.experts(hidden_states_flat, top_k_index, top_k_weights, adapter_indices_flat) + shared_output = self.shared_experts( + hidden_states_flat.reshape(batch_size, seq_len, hidden_size), adapter_indices + ).reshape(-1, hidden_size) + expert_output = expert_output + shared_output + + return expert_output.reshape(batch_size, seq_len, hidden_size) + + +class Glm4DecoderLayer(nnx.Module): + + def __init__(self, config: Glm4Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + self.self_attn = Glm4Attention(config, dtype=dtype, rngs=rngs) + + # Check if this is an MoE model and if this layer should use MoE + if config.n_routed_experts and layer_idx >= config.first_k_dense_replace: + self.mlp = Glm4MoE(config, dtype=dtype, rngs=rngs) + else: + self.mlp = Glm4MLP(config, dtype=dtype, rngs=rngs) + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None = None, + kv_cache: tuple[jax.Array, jax.Array] | None = None, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + 
hidden_states, updated_cache = self.self_attn( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices) + hidden_states = residual + mlp_output + + return hidden_states, updated_cache + + +class Glm4Model(nnx.Module): + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + + self.embed_tokens = LoRAEmbed( + num_embeddings=config.vocab_size, + features=config.hidden_size, + dtype=dtype, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + param_dtype=dtype, + embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)), + rngs=rngs, + ) + self.layers = nnx.List( + [Glm4DecoderLayer(config, layer_idx=i, dtype=dtype, rngs=rngs) for i in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + ) -> ModelOutput: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) + all_hidden_states: list[jax.Array] = [] + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), + ) + updated_keys.append(k) + updated_values.append(v) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return ModelOutput( + last_hidden_state=hidden_states, + kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + +class Glm4ForCausalLM(nnx.Module, ModelForCausalLM, GeneratorMixin, LogitsProcessorMixin): + + def __init__(self, config: Glm4Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + self.config = config + self.model = Glm4Model(config, dtype=dtype, rngs=rngs) + + if not self.config.tie_word_embeddings: + self.lm_head = LoRALinear( + config.hidden_size, + config.vocab_size, + use_bias=False, + dtype=dtype, + param_dtype=dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")), + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + rngs=rngs, + ) + else: + self.lm_head = self.model.embed_tokens.T + + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + return self.lm_head + + @staticmethod + def is_lora_param(path: tuple, _value) -> bool: + """Return True if a parameter path corresponds to LoRA weights.""" + return any(name in path for name in ("lora_A", "lora_B")) + + def __call__( + self, + input_ids: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array | None = None, + 
output_hidden_states: bool | None = None, + adapter_indices: jax.Array | None = None, + kv_cache: KVCache | None = None, + ) -> CausalLMOutput: + if positions is None: + positions = jnp.arange(attention_mask.shape[1])[None, :] + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + positions=positions, + output_hidden_states=output_hidden_states, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + + return CausalLMOutput( + last_hidden_state=outputs.last_hidden_state, + kv_cache=outputs.kv_cache, + hidden_states=outputs.hidden_states, + ) + + +Glm4MoeForCausalLM = Glm4ForCausalLM diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6e840febf..3e457948d 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -64,6 +64,7 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: import tx.models.llama3 import tx.models.qwen3 import tx.models.deepseekv3 + import tx.models.glm4 for architecture in config.architectures or []: if hasattr(tx.models.llama3, architecture): @@ -72,6 +73,8 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: return getattr(tx.models.qwen3, architecture) if hasattr(tx.models.deepseekv3, architecture): return getattr(tx.models.deepseekv3, architecture) + if hasattr(tx.models.glm4, architecture): + return getattr(tx.models.glm4, architecture) raise ValueError(f"None of the architectures {config.architectures} is currently supported.")
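
For reviewers who want to exercise the new GLM4 path end to end, below is a minimal usage sketch assembled only from calls that already appear in the tests in this diff (Glm4Config wrapping the HF config, the ("fsdp", "tp") mesh, load_safetensors, and a plain forward pass). It is an illustrative sketch under those assumptions, not part of the change itself; return_tensors="np" is just a convenience to skip the torch round-trip the tests use.

from flax import nnx
import jax
import jax.numpy as jnp
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PretrainedConfig

from tx.models.configs import Glm4Config
from tx.models.glm4 import Glm4ForCausalLM
from tx.utils.models import load_safetensors

model_name = "yujiepan/glm-4-moe-tiny-random"
checkpoint_path = snapshot_download(model_name, allow_patterns=["*.safetensors"])

# Wrap the HF config as the tests do; LoRA capacity and head sharding are tx-side settings.
base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True)
config = Glm4Config(base_config, max_lora_adapters=1, max_lora_rank=8, shard_attention_heads=True)

# Single-device mesh with the ("fsdp", "tp") axes that the partitioning annotations in glm4.py expect.
mesh = jax.make_mesh(
    (1, 1),
    ("fsdp", "tp"),
    axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto),
)
with jax.set_mesh(mesh):
    model = Glm4ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
    load_safetensors(checkpoint_path, config, model)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    batch = tokenizer(["The capital of France is"], return_tensors="np", padding=True)
    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], output_hidden_states=True)
    print(outputs.last_hidden_state.shape)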