2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_llama3.py
@@ -39,7 +39,7 @@ def test_llama3(tp: int):

base_config = AutoConfig.from_pretrained(model_name)
config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True)
mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(tmp, config, model)
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_llama3_lora_training.py
@@ -18,7 +18,7 @@ def test_lora_training():
config = Llama3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)

checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
load_safetensors(checkpoint_path, config, model)
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_models_common.py
@@ -13,7 +13,7 @@
from tx.utils.models import load_safetensors

MODEL_PARAMS = [
("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("dp", "tp")),
("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
("Qwen/Qwen3-0.6B", Qwen3Config, Qwen3ForCausalLM, ("fsdp", "tp")),
]
MODEL_IDS = ["llama3", "qwen3"]
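
The tests now build their device mesh with an `"fsdp"` data axis instead of `"dp"`, matching the axis names used in the partition specs attached to the model parameters. A minimal sketch of the updated mesh setup, with model construction elided:

```python
import jax

# The data-parallel mesh axis is named "fsdp" so it lines up with the
# ("fsdp", "tp") partition specs attached to the model parameters.
mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
    ...  # build the model and load weights under this mesh, as in the tests above
```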
29 changes: 15 additions & 14 deletions skyrl-tx/tx/layers/lora.py
@@ -117,6 +117,7 @@ def __init__(
num_embeddings: int,
features: int,
*,
sharding: tuple[str | None, ...],
max_lora_adapters: int = 0,
max_lora_rank: int = 8,
dtype: jnp.dtype = jnp.float32,
@@ -131,13 +132,9 @@
features=features,
dtype=dtype,
param_dtype=param_dtype,
embedding_init=embedding_init,
embedding_init=nnx.with_partitioning(embedding_init, sharding),
rngs=rngs,
)
assert (
self.embedding[...].sharding is not None
), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init"
sharding = self.embedding[...].sharding.spec

self.init_lora(
max_lora_adapters=max_lora_adapters,
@@ -181,6 +178,7 @@ def __init__(
in_features: int,
out_features: int,
*,
sharding: tuple[str | None, ...],
max_lora_adapters: int = 0,
max_lora_rank: int = 8,
dtype: jnp.dtype = jnp.float32,
@@ -200,14 +198,11 @@
use_bias=use_bias,
dtype=dtype,
param_dtype=param_dtype,
kernel_init=kernel_init,
bias_init=bias_init,
kernel_init=nnx.with_partitioning(kernel_init, sharding),
bias_init=nnx.with_partitioning(bias_init, (sharding[-1],)),
rngs=rngs,
)
assert (
self.kernel[...].sharding is not None
), "LoRALinear layer needs sharding, you can specify it by using nnx.with_partitioning on the kernel_init"
sharding = self.kernel[...].sharding.spec

self.init_lora(
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
@@ -233,6 +228,7 @@ def __init__(
in_features: int,
out_features: int,
*,
sharding: tuple[str | None, ...],
max_lora_adapters: int = 0,
max_lora_rank: int = 8,
dtype: jnp.dtype = jnp.float32,
@@ -243,10 +239,15 @@
self.in_features = in_features
self.out_features = out_features

self.weight = Param(num_experts, in_features, out_features, dtype=dtype, kernel_init=kernel_init, rngs=rngs)
self.weight = Param(
num_experts,
in_features,
out_features,
dtype=dtype,
kernel_init=nnx.with_partitioning(kernel_init, sharding),
rngs=rngs,
)

assert self.weight[...].sharding is not None, "LoRAExpert layer needs sharding"
sharding = self.weight[...].sharding.spec
self.init_lora(
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
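
With this change, the LoRA layers take the partition spec as an explicit `sharding` argument and wrap the plain initializers with `nnx.with_partitioning` themselves, instead of requiring callers to pre-wrap `kernel_init`/`embedding_init` and then asserting that a sharding was attached. A minimal sketch of the pattern, simplified and not the repository's actual `LoRALinear` class:

```python
import jax
import jax.numpy as jnp
from flax import nnx


class ShardedLinear(nnx.Module):
    """Simplified stand-in for LoRALinear, showing only the sharding handling."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        sharding: tuple[str | None, ...],  # e.g. ("fsdp", "tp")
        use_bias: bool = True,
        dtype: jnp.dtype = jnp.float32,
        kernel_init=nnx.initializers.lecun_normal(),
        bias_init=nnx.initializers.zeros_init(),
        rngs: nnx.Rngs,
    ):
        self.linear = nnx.Linear(
            in_features=in_features,
            out_features=out_features,
            use_bias=use_bias,
            dtype=dtype,
            # The layer attaches the partition spec; callers pass plain initializers.
            kernel_init=nnx.with_partitioning(kernel_init, sharding),
            # The bias only has the output dimension, so it takes the last axis name.
            bias_init=nnx.with_partitioning(bias_init, (sharding[-1],)),
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)


# Call sites pass the axis names directly instead of wrapping the initializer.
layer = ShardedLinear(16, 32, sharding=("fsdp", "tp"), rngs=nnx.Rngs(0))
```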
42 changes: 28 additions & 14 deletions skyrl-tx/tx/models/deepseekv3.py
@@ -37,12 +37,13 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
self.q_proj = LoRALinear(
in_features=config.hidden_size,
out_features=self.num_heads * self.qk_head_dim,
sharding=("fsdp", tp_shard),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=False,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", tp_shard)),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)
self.q_a_proj = None
@@ -53,61 +54,66 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
self.q_a_proj = LoRALinear(
in_features=config.hidden_size,
out_features=self.q_lora_rank,
sharding=("fsdp", None),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=config.attention_bias,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
self.q_b_proj = LoRALinear(
in_features=self.q_lora_rank,
out_features=self.num_heads * self.qk_head_dim,
sharding=(None, tp_shard),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=False,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)

self.kv_a_proj_with_mqa = LoRALinear(
in_features=config.hidden_size,
out_features=self.kv_lora_rank + self.qk_rope_head_dim,
sharding=("fsdp", None),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=config.attention_bias,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", None)),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)

self.kv_b_proj = LoRALinear(
in_features=self.kv_lora_rank,
out_features=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
sharding=(None, tp_shard),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=False,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, tp_shard)),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)

self.o_proj = LoRALinear(
in_features=self.num_heads * self.v_head_dim,
out_features=config.hidden_size,
sharding=(tp_shard, "fsdp"),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
param_dtype=dtype,
use_bias=config.attention_bias,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (tp_shard, "fsdp")),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)

@@ -189,32 +195,35 @@ def __init__(
self.gate_proj = LoRALinear(
config.hidden_size,
intermediate_size,
sharding=("fsdp", "tp"),
use_bias=False,
dtype=dtype,
param_dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
kernel_init=nnx.initializers.lecun_normal(),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
rngs=rngs,
)
self.up_proj = LoRALinear(
config.hidden_size,
intermediate_size,
sharding=("fsdp", "tp"),
use_bias=False,
dtype=dtype,
param_dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("fsdp", "tp")),
kernel_init=nnx.initializers.lecun_normal(),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
rngs=rngs,
)
self.down_proj = LoRALinear(
intermediate_size,
config.hidden_size,
sharding=("tp", "fsdp"),
use_bias=False,
dtype=dtype,
param_dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("tp", "fsdp")),
kernel_init=nnx.initializers.lecun_normal(),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
rngs=rngs,
@@ -260,30 +269,33 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
config.n_routed_experts,
config.hidden_size,
config.moe_intermediate_size,
sharding=("ep", "fsdp", "tp"),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)
self.up_proj = LoRAExpert(
config.n_routed_experts,
config.hidden_size,
config.moe_intermediate_size,
sharding=("ep", "fsdp", "tp"),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)
self.down_proj = LoRAExpert(
config.n_routed_experts,
config.moe_intermediate_size,
config.hidden_size,
sharding=("ep", "tp", "fsdp"),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
kernel_init=nnx.initializers.lecun_normal(),
rngs=rngs,
)

@@ -452,11 +464,12 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
self.embed_tokens = LoRAEmbed(
num_embeddings=config.vocab_size,
features=config.hidden_size,
sharding=("tp", None),
dtype=dtype,
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
param_dtype=dtype,
embedding_init=nnx.with_partitioning(nnx.initializers.normal(), ("tp", None)),
embedding_init=nnx.initializers.normal(),
rngs=rngs,
)
self.layers = nnx.List(
@@ -520,10 +533,11 @@ def __init__(self, config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs
self.lm_head = LoRALinear(
config.hidden_size,
config.vocab_size,
sharding=(None, "tp"),
use_bias=False,
dtype=dtype,
param_dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp")),
kernel_init=nnx.initializers.lecun_normal(),
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
rngs=rngs,
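
As a rough illustration (toy shapes, not code from this PR), the sharding tuples passed above correspond to column-parallel versus row-parallel layouts on the ("fsdp", "tp") mesh: projections that expand into per-head outputs, such as q_proj, shard their kernel as `("fsdp", tp_shard)`, while o_proj, which reduces back to the hidden size, uses `(tp_shard, "fsdp")` so its input dimension follows the tensor-parallel axis:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))

hidden, heads_times_dim = 8, 16
q_kernel = jnp.zeros((hidden, heads_times_dim))  # hidden_size -> num_heads * head_dim
o_kernel = jnp.zeros((heads_times_dim, hidden))  # num_heads * head_dim -> hidden_size

# Column-parallel: the output (per-head) dimension is split across the "tp" axis.
q_sharded = jax.device_put(q_kernel, NamedSharding(mesh, P("fsdp", "tp")))
# Row-parallel: the input (per-head) dimension is split across "tp" instead.
o_sharded = jax.device_put(o_kernel, NamedSharding(mesh, P("tp", "fsdp")))

print(q_sharded.sharding)  # NamedSharding(..., PartitionSpec('fsdp', 'tp'))
print(o_sharded.sharding)  # NamedSharding(..., PartitionSpec('tp', 'fsdp'))
```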