diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py
index 5f9fb17cf..e70a0aa11 100644
--- a/tpu_inference/layers/vllm/sharding.py
+++ b/tpu_inference/layers/vllm/sharding.py
@@ -102,22 +102,15 @@ def _shard_base_linear_lora(layer: BaseLinearLayerWithLoRA,
     # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
     sharded_lora_a_tpu = torch.nn.ParameterList()
     sharded_lora_b_tpu = torch.nn.ParameterList()
-    sharded_lora_bias_tpu = torch.nn.ParameterList()
 
     for i in range(layer.n_slices):
         sharded_lora_a_tpu.append(
             _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
         sharded_lora_b_tpu.append(
             _shard_tensor_to_tpu_replicated(layer.lora_b_stacked[i], mesh))
-        if layer.lora_bias_stacked is not None:
-            sharded_lora_bias_tpu.append(
-                _shard_tensor_to_tpu_replicated(layer.lora_bias_stacked[i],
-                                                mesh))
 
     layer.lora_a_stacked = sharded_lora_a_tpu
     layer.lora_b_stacked = sharded_lora_b_tpu
-    if layer.lora_bias_stacked is not None:
-        layer.lora_bias_stacked = sharded_lora_bias_tpu
 
 
 # TODO: Add custom sharding logic for following lora layers
diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py
index 51a064d39..b22977ba2 100644
--- a/tpu_inference/models/vllm/vllm_model_wrapper.py
+++ b/tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -255,11 +255,9 @@ def _tpu_set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
     ):
         with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor,
-                                    bias)
+            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
 
     def _tpu_reset_lora(self, index: int):
         with torchax.default_env():