Reduce the host overhead for LoRA #930
Merged
Description
This PR optimizes the host-side performance of LoRA.

On TPU inference, the current step time is 130 ms while the execution time is only 5.59 ms. So why does it take so long between each execution? I discovered that `shard_model_to_tpu` takes 94% of the step time. This PR optimizes `shard_model_to_tpu` so that it only does what is necessary for LoRA.

During inference, before each forward pass, LoRA needs to call `shard_model_to_tpu` because the `linear.lora` matrices may have been updated. `shard_model_to_tpu` essentially does three things:

- shard the module and move it to TPU (via `_shard_module_to_tpu`),
- extract the params and buffers,
- move the remaining params/buffers to TPU (via `PyTree.tree_map_only`).

In the profile, `_shard_module_to_tpu` and `PyTree.tree_map_only` each take an equally excessive amount of time. This is what this PR proposes:

We don't need to re-shard every time. When vLLM updates the LoRA weights, the sharding remains the same, and the new LoRA weights are copied from CPU into the existing TPU tensors (both via `copy_`), so the device also remains the same. That means we only need to shard once, when the model is first loaded, and never need to re-shard later. (I also wrote a torchax repro that demonstrates `copy_` won't change the tensor's sharding; a minimal sketch of the idea is shown below.)
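The repro itself is not included here; below is a minimal plain-PyTorch sketch of the property it relies on (the tensors are stand-ins, not the real sharded TPU weights): `copy_` writes the source values into the destination's existing storage, so the destination tensor's placement, and in torchax its sharding, stays the same.

```python
import torch

# Stand-ins: `dst` plays the role of an already-sharded weight on TPU,
# `src` the freshly loaded LoRA weight on CPU.
dst = torch.zeros(8, 8)
src = torch.randn(8, 8)

storage_before = dst.data_ptr()
dst.copy_(src)  # in-place copy: values are overwritten, no new tensor is created

assert dst.data_ptr() == storage_before  # same storage, so same placement/sharding
assert torch.equal(dst, src)             # values were updated
```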
`PyTree.tree_map_only` is only used to move the remaining params/buffers to TPU. But for LoRA, the only params that can possibly be updated are the LoRA matrices, and we know those are already torchax tensors on TPU, so there is really no need for the `tree_map_only` step. That said, the only thing that is actually necessary in the LoRA case is to re-extract the LoRA parameters and update the original `params_buffers` (see the sketch below).
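A rough sketch of that idea follows. The helper name `refresh_lora_params`, the shape of the `params_buffers` dict, and the `"lora"` substring check are illustrative assumptions, not the PR's actual code.

```python
import torch


def refresh_lora_params(model: torch.nn.Module, params_buffers: dict) -> None:
    """Re-extract only the LoRA parameters and refresh the cached dict in place.

    Assumes the model was sharded and fully extracted once at load time; after a
    LoRA swap, only entries whose names contain "lora" can have changed, and they
    are already torchax tensors on TPU, so no re-sharding or tree_map pass is needed.
    """
    for name, param in model.named_parameters():
        if "lora" in name:
            params_buffers[name] = param


# Hypothetical wiring: call this once per LoRA update instead of the full
# `shard_model_to_tpu` path.
# refresh_lora_params(model, params_buffers)
```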
With the change, I profiled again and the step time has dropped from 130 ms to 13 ms.
Tests
CI
Checklist
Before submitting this PR, please make sure: