
Conversation

@vanbasten23 vanbasten23 commented Oct 23, 2025

Description

This PR reduces the host-side overhead of LoRA inference.

On tpu-inference, the current step time is 130 ms while the device execution time is only 5.59 ms, so most of each step is spent on the host between executions. Profiling shows that shard_model_to_tpu accounts for 94% of the step time.

This PR optimizes shard_model_to_tpu so that it does only what is necessary for LoRA.

During inference, LoRA needs to call shard_model_to_tpu before each forward pass because the linear.lora matrices may have been updated. shard_model_to_tpu essentially does three things: it shards the module and moves it to TPU (via _shard_module_to_tpu), extracts the params and buffers, and moves the remaining params/buffers to TPU (via PyTree.tree_map_only). In the profile, _shard_module_to_tpu and PyTree.tree_map_only each take an excessive amount of time. This PR proposes the following:

  1. We don't need to re-shard every time. When vLLM updates the LoRA adapter, the sharding stays the same, and the new LoRA weights are copied from CPU into the existing TPU tensors (via copy_), so the device placement stays the same as well. That means we only need to shard once when the model is first loaded and never need to re-shard afterwards. (I also wrote a torchax repro demonstrating that copy_ does not change a tensor's sharding; see the first sketch below.)

  2. PyTree.tree_map_only is only used to move the remaining params/buffers to TPU. For LoRA, the only params that can be updated are the LoRA matrices, and we know those are already torchax tensors on TPU, so this tree_map_only step is unnecessary.

  3. That leaves the only work the LoRA path actually needs: re-extract the LoRA parameters and update the original params_buffers (see the second sketch below).
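
A minimal standalone sketch of the idea behind point 1, using plain CPU torch tensors rather than the actual torchax/TPU setup (the tensor names are made up for illustration): copy_ writes new values into an existing tensor object instead of creating a new one, so whatever device placement and sharding were attached at model-load time are left untouched. The torchax repro mentioned above demonstrates the same property for sharded tensors on TPU.

```python
import torch

# Stand-in for a LoRA matrix that was sharded and moved to TPU once at model
# load time (in the real code this happens inside _shard_module_to_tpu).
lora_a = torch.zeros(16, 4096)
original_object_id = id(lora_a)

# When vLLM swaps in a new LoRA adapter, the fresh weights arrive on CPU and
# are copied into the existing tensor in place.
new_weights = torch.randn(16, 4096)
lora_a.copy_(new_weights)

# The tensor object is the same one that was sharded initially, so its device
# and sharding metadata are unchanged; only the values were overwritten, which
# is why re-sharding on every step is unnecessary.
assert id(lora_a) == original_object_id
assert torch.equal(lora_a, new_weights)
```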

With this change, profiling again shows the step time has dropped from 130 ms to 13 ms.
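
For concreteness, here is a rough sketch of what the LoRA-only refresh from point 3 could look like. refresh_lora_params and params_buffers are illustrative names, not the actual tpu-inference helpers, and the sketch assumes the LoRA tensors are already torchax tensors on TPU, so nothing needs to be sharded or moved.

```python
import torch
import torch.nn as nn

def refresh_lora_params(model: nn.Module, params_buffers: dict) -> None:
    """Re-extract only the LoRA parameters and update the cached mapping.

    Because the LoRA matrices are the only params that can change between
    steps, and they already live on the target device, updating the cached
    name -> tensor dict is all that is needed; no re-sharding, no tree_map.
    """
    for name, param in model.named_parameters():
        if "lora" in name:
            params_buffers[name] = param

# Toy usage: a linear layer with a LoRA-style extra parameter.
layer = nn.Linear(4, 4)
layer.lora_a = nn.Parameter(torch.zeros(2, 4))
params_buffers = dict(layer.named_parameters())

layer.lora_a = nn.Parameter(torch.ones(2, 4))  # simulate an adapter swap
refresh_lora_params(layer, params_buffers)
assert params_buffers["lora_a"] is layer.lora_a
```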

Tests

CI

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@vanbasten23 vanbasten23 changed the title from "Xiowei/lora perf 0" to "Reduce the host overhead for LoRA" on Oct 23, 2025

@vanbasten23 vanbasten23 marked this pull request as ready for review October 23, 2025 20:59
@vanbasten23
Collaborator Author

cc: @qihqi @kyuyeunk

@vanbasten23 vanbasten23 requested a review from kyuyeunk October 24, 2025 17:05
@vanbasten23
Collaborator Author

Thanks for the review.

The CI is green: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/4711#. The failures are unrelated to this PR, as they also occur on head: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/4711#

@vanbasten23 vanbasten23 merged commit 03d76de into main Oct 28, 2025
3 of 4 checks passed
sierraisland pushed a commit that referenced this pull request Oct 28, 2025