
Commit 53db3b4

Add the bgmv tests (#942)
Signed-off-by: Xiongfei Wei <[email protected]>
1 parent 608909d

File tree

2 files changed: +44 / -1 lines

.buildkite/pipeline_jax.yml

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ steps:
       commands:
       - |
         .buildkite/scripts/run_in_docker.sh \
-          bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py'
+          bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/'

   - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips"
     key: test_11

tests/lora/test_bgmv.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
import jax
import torch
import torchax

from tpu_inference.lora.torch_lora_ops import bgmv_torch


def test_bgmv_torch():
    num_tokens = 16
    hidden_size = 128
    max_loras = 9
    max_lora_rank = 8

    with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
        inputs = torch.rand(num_tokens, hidden_size, device='jax')
        loras = torch.rand(max_loras,
                           1,
                           max_lora_rank,
                           hidden_size,
                           device='jax')
        idxs = torch.randint(0, max_loras, (num_tokens, ), device='jax')

        actual = bgmv_torch(inputs, loras, idxs)
        expected = _ref_bgmv_torch(inputs, loras, idxs)
        torch.testing.assert_close(actual, expected, atol=3e-2, rtol=1e-3)


def _ref_bgmv_torch(inputs, loras, idxs):
    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)

    # Another equivalent ref impl is the 2 lines below:
    # selected_loras = loras[idxs]
    # return torch.einsum('td,tld->tl', inputs, selected_loras)
    num_tokens, _ = inputs.shape
    outputs = []
    for i in range(num_tokens):
        input = inputs[i]  # [hidden_size]
        lora = loras[idxs[i]]  # [max_lora_rank, hidden_size]
        out = torch.matmul(lora, input)
        outputs.append(out)

    return torch.stack(outputs, axis=0)
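
For context, "bgmv" here is the batched gather-then-matrix-vector product used for multi-LoRA serving: each token i selects the adapter loras[idxs[i]] and is multiplied by that adapter's [max_lora_rank, hidden_size] matrix. The real bgmv_torch lives in tpu_inference/lora/torch_lora_ops.py and is not shown in this diff; the snippet below is only a plausible einsum-based sketch matching the reference implementation in the test, not the repository's actual op.

# Plausible sketch only -- NOT the implementation from
# tpu_inference.lora.torch_lora_ops. It mirrors the einsum-based
# reference hinted at in the test's comments.
import torch


def bgmv_sketch(inputs, loras, idxs):
    # inputs: [num_tokens, hidden_size]
    # loras:  [max_loras, 1, max_lora_rank, hidden_size] (or already squeezed)
    # idxs:   [num_tokens], adapter index per token
    # returns: [num_tokens, max_lora_rank]
    if loras.ndim == 4:
        loras = loras.squeeze(1)
    selected_loras = loras[idxs]  # gather one adapter per token: [T, R, D]
    return torch.einsum('td,trd->tr', inputs, selected_loras)

The test only checks numerical agreement against a straightforward reference within atol=3e-2 / rtol=1e-3, so the production op is free to fuse the gather and contraction or use lower-precision accumulation.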
