[tx] Make sharding explicit in LoRA constructors #997
pcmoritz merged 4 commits into NovaSky-AI:main
Conversation
Code Review
This pull request refactors the LoRA layer constructors (LoRAEmbed, LoRALinear, LoRAExpert) to explicitly accept a sharding parameter. This change streamlines the sharding configuration by removing the need for nnx.with_partitioning calls at the instantiation sites within the model definitions (deepseekv3.py, llama3.py, qwen3.py) and eliminates the associated runtime assertions. The refactoring is applied consistently across all affected files, improving clarity and control over sharding for LoRA layers. All changes are well implemented and align with the stated objective of making sharding explicit.
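For context, here is a minimal sketch of the before/after shape of the change. The class, parameter names, shapes, and axis names are illustrative assumptions, not the repository's actual code:

```python
import jax
import jax.numpy as jnp
from flax import nnx


class LoRALinear(nnx.Module):
    """Illustrative LoRA adapter; not the repository's implementation."""

    def __init__(self, in_features: int, out_features: int, rank: int,
                 sharding: tuple, *, rngs: nnx.Rngs):
        init = jax.nn.initializers.lecun_normal()
        # New style: the sharding tuple is attached directly to the LoRA
        # parameters here, so call sites in the model files no longer wrap
        # initializers in nnx.with_partitioning.
        self.lora_a = nnx.Param(
            init(rngs.params(), (in_features, rank)),
            sharding=(sharding[0], None),
        )
        self.lora_b = nnx.Param(
            init(rngs.params(), (rank, out_features)),
            sharding=(None, sharding[1]),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.lora_a.value @ self.lora_b.value


# Instantiation site inside a model file: the sharding is now passed
# explicitly instead of being configured via nnx.with_partitioning.
layer = LoRALinear(1024, 4096, rank=8,
                   sharding=("fsdp", "tp"), rngs=nnx.Rngs(0))
y = layer(jnp.ones((2, 1024)))
```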
raulchen
left a comment
we can define constants for the shardings, as they are repeated.
Do you mean across models? Within models there is generally not a lot of repetition, only a little bit, and it is also nice that each tensor has its sharding explicitly right next to it. I'm going to merge this for now so we can unblock the other PR, but we can think about whether there is a better way to structure this to get rid of repetition (and avoid forgetting certain shardings) across models. Let me know if you have some ideas :)
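One possible shape of that suggestion, as a hypothetical sketch; the constant names and axis names are assumptions, not anything in the repository:

```python
# Hypothetical shared constants for the recurring sharding tuples; each model
# file would import and pass these instead of repeating the literals.
EMBED_SHARDING = (None, "tp")        # vocab replicated, hidden dim split over tensor parallel
ATTN_PROJ_SHARDING = ("fsdp", "tp")  # weights split over the FSDP and tensor-parallel axes
MLP_DOWN_SHARDING = ("tp", "fsdp")   # transposed layout for the down projection

# A call site would then read e.g.
#   LoRALinear(..., sharding=ATTN_PROJ_SHARDING, rngs=rngs)
# at the cost of the sharding no longer sitting literally next to each tensor,
# which is the trade-off discussed above.
```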
This is in preparation for merging #996, so we don't need to depend on the jax tracer. It is also slightly cleaner this way and the assert is not needed any more, since the error is "defined away".
It also adds the FSDP sharding for llama3.
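As a rough illustration of what the FSDP sharding involves, the mesh shape and axis names below are assumptions for illustration, not the repository's actual configuration:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# A device mesh with an "fsdp" axis alongside the tensor-parallel axis; the
# explicit sharding tuples passed to the LoRA constructors can then name it.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1))
mesh = Mesh(devices, axis_names=("fsdp", "tp"))

# With such a mesh in scope, a llama3 LoRA projection might be constructed as
#   LoRALinear(hidden, hidden, rank, sharding=("fsdp", "tp"), rngs=rngs)
# so its weights are sharded over the FSDP axis in addition to tensor parallel.
```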