2 changes: 1 addition & 1 deletion .buildkite/pipeline_jax.yml
@@ -156,7 +156,7 @@ steps:
commands:
- |
.buildkite/scripts/run_in_docker.sh \
- bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py -k multi_lora'
+ bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py'


# -----------------------------------------------------------------
21 changes: 10 additions & 11 deletions tests/lora/test_lora.py
@@ -1,5 +1,6 @@
# https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
import os
+ import time

import pytest
import vllm
@@ -16,17 +17,6 @@
# 100 training iterations with a training batch size of 100.


- @pytest.fixture(scope="function", autouse=True)
- def use_v1_only(monkeypatch: pytest.MonkeyPatch):
-     """
-     Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
-     for all tests in this file
-     """
-     with monkeypatch.context() as m:
-         m.setenv("VLLM_USE_V1", "1")
-         yield


def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    max_model_len=256,
@@ -67,6 +57,9 @@ def test_single_lora(tp):
    assert answer.isdigit()
    assert int(answer) == 2

+     del llm
+     time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_lora_hotswapping(tp):
@@ -99,6 +92,9 @@ def test_lora_hotswapping(tp):
        assert answer.isdigit()
        assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"

+     del llm
+     time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_multi_lora(tp):
@@ -132,3 +128,6 @@ def test_multi_lora(tp):
        assert int(
            output.strip()
            [0]) == i + 1, f"Expected {i + 1}, got {int(output.strip()[0])}"

+     del llm
+     time.sleep(10)
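
The teardown added at the end of each test above (del llm followed by time.sleep(10)) releases the previous engine before the next test builds a new one on the same TPU. Below is a minimal sketch of how that cleanup could be centralised in a pytest fixture; the fixture name llm_factory and the explicit gc.collect() call are assumptions for illustration and do not appear in this change.

import gc
import time

import pytest
import vllm


@pytest.fixture
def llm_factory():
    # Hypothetical helper (not part of this change): hand the test a
    # constructor for vllm.LLM and tear every engine down afterwards.
    instances = []

    def _create(*args, **kwargs):
        llm = vllm.LLM(*args, **kwargs)
        instances.append(llm)
        return llm

    yield _create

    # Mirror the inline teardown in the tests above: drop the references,
    # force a collection, and pause so the TPU runtime can release device
    # memory before the next test starts another engine.
    instances.clear()
    gc.collect()
    time.sleep(10)

A test would then take llm_factory as an argument and build its engine through it, keeping the del/sleep bookkeeping out of each test body; whether the 10-second pause is strictly required is an assumption carried over from the tests above.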