diff --git a/skyrl-tx/tests/models/test_llama3.py b/skyrl-tx/tests/models/test_llama3.py
index fa195567f..91b8575bc 100644
--- a/skyrl-tx/tests/models/test_llama3.py
+++ b/skyrl-tx/tests/models/test_llama3.py
@@ -39,7 +39,7 @@ def test_llama3(tp: int):
     base_config = AutoConfig.from_pretrained(model_name)
     config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True)
 
-    mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
+    mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
     with jax.set_mesh(mesh):
         model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
         load_safetensors(tmp, config, model)
diff --git a/skyrl-tx/tests/models/test_llama3_lora_training.py b/skyrl-tx/tests/models/test_llama3_lora_training.py
index af91d373e..61fa029c6 100644
--- a/skyrl-tx/tests/models/test_llama3_lora_training.py
+++ b/skyrl-tx/tests/models/test_llama3_lora_training.py
@@ -18,7 +18,7 @@ def test_lora_training():
     config = Llama3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
 
-    mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
+    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
     with jax.set_mesh(mesh):
         model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
         load_safetensors(checkpoint_path, config, model)
diff --git a/skyrl-tx/tests/models/test_models_common.py b/skyrl-tx/tests/models/test_models_common.py
index f0dc261bc..a90973371 100644
--- a/skyrl-tx/tests/models/test_models_common.py
+++ b/skyrl-tx/tests/models/test_models_common.py
@@ -13,7 +13,7 @@ from tx.utils.models import load_safetensors
 
 
 MODEL_PARAMS = [
-    ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")),
+    ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
     ("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")),
 ]
 MODEL_IDS = ["llama3", "qwen3"]
diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index 573b83adb..776b7af59 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -117,6 +117,7 @@ def __init__(
         num_embeddings: int,
         features: int,
         *,
+        sharding: tuple[str | None, ...],
         max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -131,13 +132,9 @@ def __init__(
             features=features,
             dtype=dtype,
             param_dtype=param_dtype,
-            embedding_init=embedding_init,
+            embedding_init=nnx.with_partitioning(embedding_init, sharding),
             rngs=rngs,
         )
-        assert (
-            self.embedding[...].sharding is not None
-        ), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init"
-        sharding = self.embedding[...].sharding.spec
 
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
@@ -181,6 +178,7 @@ def __init__(
         in_features: int,
         out_features: int,
         *,
+        sharding: tuple[str | None, ...],
         max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -200,14 +198,11 @@ def __init__(
             use_bias=use_bias,
             dtype=dtype,
             param_dtype=param_dtype,
-            kernel_init=kernel_init,
-            bias_init=bias_init,
+            kernel_init=nnx.with_partitioning(kernel_init, sharding),
+            bias_init=nnx.with_partitioning(bias_init, (sharding[-1],)),
             rngs=rngs,
         )
-        assert (
-            self.kernel[...].sharding is not None
-        ), "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init"
-        sharding = self.kernel[...].sharding.spec
+
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
             max_lora_rank=max_lora_rank,
@@ -233,6 +228,7 @@ def __init__(
         in_features: int,
         out_features: int,
         *,
+        sharding: tuple[str | None, ...],
        max_lora_adapters: int = 0,
         max_lora_rank: int = 8,
         dtype: jnp.dtype = jnp.float32,
@@ -243,10 +239,15 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
 
-        self.weight = Param(num_experts, in_features, out_features, dtype=dtype, kernel_init=kernel_init, rngs=rngs)
+        self.weight = Param(
+            num_experts,
+            in_features,
+            out_features,
+            dtype=dtype,
+            kernel_init=nnx.with_partitioning(kernel_init, sharding),
+            rngs=rngs,
+        )
 
-        assert self.weight[...].sharding is not None, "LoRAExpert layer needs sharding"
-        sharding = self.weight[...].sharding.spec
         self.init_lora(
             max_lora_adapters=max_lora_adapters,
             max_lora_rank=max_lora_rank,
diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py
index 07aea272b..9232832d1 100644
--- a/skyrl-tx/tx/models/deepseekv3.py
+++ b/skyrl-tx/tx/models/deepseekv3.py
@@ -37,12 +37,13 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.qk_head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.q_a_proj = None
@@ -53,36 +54,39 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.q_a_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.q_lora_rank,
+            sharding=("fsdp", None),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
         self.q_b_proj = LoRALinear(
             in_features=self.q_lora_rank,
             out_features=self.num_heads * self.qk_head_dim,
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.kv_a_proj_with_mqa = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.kv_lora_rank + self.qk_rope_head_dim,
+            sharding=("fsdp", None),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
@@ -90,24 +94,26 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.kv_b_proj = LoRALinear(
             in_features=self.kv_lora_rank,
             out_features=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            sharding=(None, tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.v_head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=config.attention_bias,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -189,10 +195,11 @@ def __init__(
         self.gate_proj = LoRALinear(
             config.hidden_size,
             intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -200,10 +207,11 @@ def __init__(
         self.up_proj = LoRALinear(
             config.hidden_size,
             intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -211,10 +219,11 @@ def __init__(
         self.down_proj = LoRALinear(
             intermediate_size,
             config.hidden_size,
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -260,30 +269,33 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
             config.n_routed_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
             config.n_routed_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
             config.n_routed_experts,
             config.moe_intermediate_size,
             config.hidden_size,
+            sharding=("ep", "tp", "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -452,11 +464,12 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -520,10 +533,11 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py
index 0522f75be..b1ae1027b 100644
--- a/skyrl-tx/tx/models/llama3.py
+++ b/skyrl-tx/tx/models/llama3.py
@@ -31,48 +31,52 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.k_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.v_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -115,10 +119,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.gate_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -126,10 +131,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.up_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -137,10 +143,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.down_proj = LoRALinear(
             config.intermediate_size,
             config.hidden_size,
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", None)),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -196,11 +203,12 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -263,10 +271,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index e35ce8069..1348cac09 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -32,45 +32,49 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.q_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.k_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.v_proj = LoRALinear(
             in_features=config.hidden_size,
             out_features=self.num_kv_heads * self.head_dim,
+            sharding=("fsdp", tp_shard),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.o_proj = LoRALinear(
             in_features=self.num_heads * self.head_dim,
             out_features=config.hidden_size,
+            sharding=(tp_shard, "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
             param_dtype=dtype,
             use_bias=False,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -116,10 +120,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.gate_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -127,10 +132,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.up_proj = LoRALinear(
             config.hidden_size,
             config.intermediate_size,
+            sharding=("fsdp", "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -138,10 +144,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.down_proj = LoRALinear(
             config.intermediate_size,
             config.hidden_size,
+            sharding=("tp", "fsdp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,
@@ -161,30 +168,33 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             config.num_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
             config.num_experts,
             config.hidden_size,
             config.moe_intermediate_size,
+            sharding=("ep", "fsdp", "tp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
             config.num_experts,
             config.moe_intermediate_size,
             config.hidden_size,
+            sharding=("ep", "tp", "fsdp"),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             rngs=rngs,
         )
 
@@ -311,11 +321,12 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.embed_tokens = LoRAEmbed(
             num_embeddings=config.vocab_size,
             features=config.hidden_size,
+            sharding=("tp", None),
             dtype=dtype,
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
+            embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
         self.layers = nnx.List(
@@ -378,10 +389,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
         self.lm_head = LoRALinear(
             config.hidden_size,
             config.vocab_size,
+            sharding=(None, "tp"),
             use_bias=False,
             dtype=dtype,
             param_dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
+            kernel_init=nnx.initializers.lecun_normal(),
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             rngs=rngs,