@@ -24,11 +24,9 @@ def shard_model_to_tpu(model: torch.nn.Module,
                         mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
     """
     Shard the model weights and move them to TPU.
-
     At the same time, also turn the weight tensors into torchax tensors so that
     jax code can interop with it and the overall program can be traced and
     compiled in XLA.
-
     Args:
         model: A PyTorch model whose weights are on CPU main memory.
         mesh: JAX mesh object for sharding.
@@ -51,6 +49,18 @@ def shard_model_to_tpu(model: torch.nn.Module,
     return {**params, **buffers}
 
 
+def update_lora(model: torch.nn.Module,
+                initial_params_buffers) -> dict[str, torchax.torch.Tensor]:
+    params, buffers = _extract_all_params_buffers(model)
+    params_buffers = {**params, **buffers}
+    for k, v in params_buffers.items():
+        if 'lora_a_stacked' in k or 'lora_b_stacked' in k:
+            assert k in initial_params_buffers, f"{k} not in initial_params_buffers"
+            initial_params_buffers[k] = v
+
+    return initial_params_buffers
+
+
 def _extract_all_params_buffers(model: torch.nn.Module):
     return dict(model.named_parameters()), dict(model.named_buffers())
 
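A minimal usage sketch of how `update_lora` could be combined with `shard_model_to_tpu` after swapping adapter weights. This is not part of the commit; `model`, `mesh`, and the LoRA-loading step are assumed to exist in the caller:

```python
# Hypothetical usage sketch; assumes a PyTorch `model` with vLLM LoRA layers
# and a JAX `mesh` are already constructed.
params_buffers = shard_model_to_tpu(model, mesh)

# ... load or swap LoRA adapters on `model` (updates the lora_a_stacked /
# lora_b_stacked buffers in place) ...

# Re-sync only the LoRA entries into the previously sharded dict so the
# traced/compiled program sees the new adapter weights.
params_buffers = update_lora(model, params_buffers)
```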
@@ -116,11 +126,11 @@ def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
 # TODO: Add custom sharding logic for following lora layers
 def _shard_merged_column_parallel_linear_lora(
         layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
+    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
     # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
     sharded_lora_a_tpu = torch.nn.ParameterList()
     sharded_lora_b_tpu = torch.nn.ParameterList()
 
-    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
     # lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank]
     lora_b_partition_spec = P(None, None, 'model', None)
     lora_b_sharding = NamedSharding(mesh, lora_b_partition_spec)
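For reference, a self-contained JAX sketch (assumed setup, not the repository's code) of what `NamedSharding(mesh, P(None, None, 'model', None))` does to a `lora_b_stacked` slice: the `out_features` dimension is split across the `'model'` mesh axis while the other dimensions stay replicated. Array sizes and the mesh shape below are illustrative only:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1-D mesh over all available devices, with a single 'model' axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('model',))

# lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank];
# the sizes here are placeholders.
lora_b = np.zeros((8, 1, 4096, 16), dtype=np.float32)

# Shard out_features (dim 2) across 'model'; replicate the other dims.
lora_b_sharding = NamedSharding(mesh, P(None, None, 'model', None))
lora_b_tpu = jax.device_put(lora_b, lora_b_sharding)
print(lora_b_tpu.sharding)  # PartitionSpec(None, None, 'model', None)
```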