
Commit f17eca7

vanbasten23 authored and sierraisland committed
Reduce the host overhead for LoRA (#930)
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent b5080b4 commit f17eca7

File tree

tpu_inference/layers/vllm/sharding.py
tpu_inference/runner/lora_utils.py
tpu_inference/worker/tpu_worker_jax.py

3 files changed: +18 -6 lines changed


tpu_inference/layers/vllm/sharding.py

Lines changed: 13 additions & 3 deletions
@@ -24,11 +24,9 @@ def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
     """
     Shard the model weights and move them to TPU.
-
     At the same time, also turn the weight tensors into torchax tensors so that
     jax code can interop with it and the overall program can be traced and
     compiled in XLA.
-
     Args:
         model: A PyTorch model whose weights are on CPU main memory.
         mesh: JAX mesh object for sharding.
@@ -51,6 +49,18 @@ def shard_model_to_tpu(model: torch.nn.Module,
     return {**params, **buffers}


+def update_lora(model: torch.nn.Module,
+                initial_params_buffers) -> dict[str, torchax.torch.Tensor]:
+    params, buffers = _extract_all_params_buffers(model)
+    params_buffers = {**params, **buffers}
+    for k, v in params_buffers.items():
+        if 'lora_a_stacked' in k or 'lora_b_stacked' in k:
+            assert k in initial_params_buffers, f"{k} not in initial_params_buffers"
+            initial_params_buffers[k] = v
+
+    return initial_params_buffers
+
+
 def _extract_all_params_buffers(model: torch.nn.Module):
     return dict(model.named_parameters()), dict(model.named_buffers())

@@ -116,11 +126,11 @@ def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
 # TODO: Add custom sharding logic for following lora layers
 def _shard_merged_column_parallel_linear_lora(
         layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
+    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
     # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
     sharded_lora_a_tpu = torch.nn.ParameterList()
     sharded_lora_b_tpu = torch.nn.ParameterList()

-    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
     # lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank]
     lora_b_partition_spec = P(None, None, 'model', None)
     lora_b_sharding = NamedSharding(mesh, lora_b_partition_spec)
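
The new update_lora helper only overwrites the stacked LoRA entries of an already-sharded state dict; the base weights are never re-extracted or re-transferred. Below is a hedged toy illustration of that key-matching loop (the key names and placeholder values are made up for illustration; only the filtering logic mirrors the diff):

# Toy illustration (not part of the commit): only entries whose names contain
# 'lora_a_stacked' or 'lora_b_stacked' are overwritten; base weights stay untouched.
initial_params_buffers = {
    "layers.0.qkv_proj.weight": "<sharded base weight>",        # hypothetical key
    "layers.0.qkv_proj.lora_a_stacked.0": "<old LoRA A slot>",  # hypothetical key
    "layers.0.qkv_proj.lora_b_stacked.0": "<old LoRA B slot>",  # hypothetical key
}
fresh_params_buffers = {
    "layers.0.qkv_proj.weight": "<base weight, ignored>",
    "layers.0.qkv_proj.lora_a_stacked.0": "<new LoRA A slot>",
    "layers.0.qkv_proj.lora_b_stacked.0": "<new LoRA B slot>",
}
for k, v in fresh_params_buffers.items():
    if 'lora_a_stacked' in k or 'lora_b_stacked' in k:
        initial_params_buffers[k] = v  # only the LoRA slots change
# initial_params_buffers now carries the new LoRA tensors next to the untouched base weight.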

tpu_inference/runner/lora_utils.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.lora.request import LoRARequest

-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
+from tpu_inference.layers.vllm.sharding import update_lora

 if TYPE_CHECKING:
     from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
@@ -41,8 +41,8 @@ def set_active_loras(self, num_scheduled_tokens_per_req,
         self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                       lora_requests)

-        params_and_buffers = shard_model_to_tpu(self.runner.model.model,
-                                                self.runner.mesh)
+        params_and_buffers = update_lora(
+            self.runner.model.model, initial_params_buffers=self.runner.state)
         self.runner.state = jax_view(params_and_buffers)

     def extract_lora_metadata(self):
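
For context, the old path re-ran shard_model_to_tpu over every parameter and buffer on each LoRA activation, while the new path only patches the stacked LoRA tensors into the runner's existing sharded state. A minimal sketch of the two call shapes, assuming a runner object with model, mesh, and state attributes as in the diff, that jax_view comes from torchax.interop, and with activate_loras_before/after as hypothetical wrappers (not functions from the repo):

from torchax.interop import jax_view  # assumption: the jax_view helper used by lora_utils.py

from tpu_inference.layers.vllm.sharding import shard_model_to_tpu, update_lora

def activate_loras_before(runner):
    # Old behaviour: walk and re-shard *all* params/buffers on every activation.
    runner.state = jax_view(shard_model_to_tpu(runner.model.model, runner.mesh))

def activate_loras_after(runner):
    # New behaviour: overwrite only the lora_a_stacked / lora_b_stacked entries in place.
    runner.state = jax_view(update_lora(runner.model.model,
                                        initial_params_buffers=runner.state))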

tpu_inference/worker/tpu_worker_jax.py

Lines changed: 2 additions & 0 deletions
@@ -241,7 +241,9 @@ def add_lora(
     def profile(self, is_start: bool = True):
         if is_start:
             options = jax.profiler.ProfileOptions()
+            # default: https://docs.jax.dev/en/latest/profiling.html#general-options
             options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+            options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
         else:
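
Both tracer levels now come from the environment, with the defaults matching the JAX profiling options linked in the added comment. A hedged usage sketch (the variable names are taken from the diff; the values shown are the same defaults the code falls back to):

# Example: set the tracer levels in the environment before the worker starts profiling.
# Level semantics are documented at
# https://docs.jax.dev/en/latest/profiling.html#general-options
import os

os.environ["PYTHON_TRACER_LEVEL"] = "0"  # 0 disables Python-level tracing
os.environ["HOST_TRACER_LEVEL"] = "1"    # 1 keeps host (CPU) tracing at its least verbose enabled level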
