diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index d222dc30a..501db04e9 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -14,7 +14,7 @@ def create_parser():
     parser = FlexibleArgumentParser()
     # Add engine args
     EngineArgs.add_cli_args(parser)
-    parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+    parser.set_defaults(model="Qwen/Qwen2.5-3B-Instruct")
     parser.set_defaults(max_model_len=1024)

     # Add sampling params
diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py
index 86623167e..3e94a3da3 100644
--- a/tests/lora/test_lora.py
+++ b/tests/lora/test_lora.py
@@ -1,4 +1,5 @@
 # https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
+
 import pytest
 import vllm
 from vllm.lora.request import LoRARequest
@@ -25,10 +26,11 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
     yield


-def setup_vllm(num_loras: int) -> vllm.LLM:
+def setup_vllm(num_loras: int, num_devices: int = 1) -> vllm.LLM:
     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                     max_model_len=256,
                     max_num_seqs=8,
+                    tensor_parallel_size=num_devices,
                     enable_lora=True,
                     max_loras=num_loras,
                     max_lora_rank=8)
@@ -49,7 +51,56 @@ def test_single_lora():
         "lora_adapter_2", 2,
         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
     output = llm.generate(prompt,
-                          sampling_params=vllm.SamplingParams(max_tokens=256,
+                          sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                              temperature=0),
+                          lora_request=lora_request)[0].outputs[0].text
+
+    answer = output.strip()[0]
+
+    assert answer.isdigit()
+    assert int(answer) == 2
+
+
+def test_single_lora_spmd():
+    """
+    This test ensures we can run a single LoRA adapter on the TPU backend
+    with SPMD tensor parallelism. The LoRA adapter below forces
+    Qwen2.5-3B-Instruct to claim that 1+1=2.
+    """
+    # max_loras = 1
+    # engine_args = EngineArgs(
+    #     model="Qwen/Qwen2.5-3B-Instruct",
+    #     max_model_len=256,
+    #     max_num_seqs=8,
+    #     enable_lora=True,
+    #     max_loras=max_loras,
+    #     max_lora_rank=8,
+    # )
+    # vllm_config = engine_args.create_engine_config()
+    # with set_current_vllm_config(vllm_config):
+    #     temp_file = tempfile.mkstemp()[1]
+    #     init_distributed_environment(
+    #         1,
+    #         0,
+    #         local_rank=0,
+    #         distributed_init_method=f"file://{temp_file}",
+    #         backend="gloo")
+    #     ensure_model_parallel_initialized(1, 1)
+
+    # num_devices = jax.local_device_count()  # TODO: this call hangs; find out why.
+    # For the SPMD multi-chip case, only num_devices=2 works for Qwen2.5-3B-Instruct:
+    # the model has num_kv_heads=2 and the attention kernel shards along the KV heads, so the TP size must divide 2. See https://github.com/vllm-project/tpu_commons/blob/a489e59c5b3a4d5c28e93775d5323970eecd66c9/tpu_commons/layers/jax/attention_interface.py#L275.
+    num_devices = 2
+    print(f'Running test_single_lora_spmd with TP={num_devices}')
+    llm = setup_vllm(1, num_devices)
+
+    prompt = "What is 1+1? \n"
+
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    output = llm.generate(prompt,
+                          sampling_params=vllm.SamplingParams(max_tokens=16,
                                                               temperature=0),
                           lora_request=lora_request)[0].outputs[0].text

@@ -82,7 +133,7 @@ def test_lora_hotswapping():
     for i, req in enumerate(lora_requests):
         output = llm.generate(prompt,
                               sampling_params=vllm.SamplingParams(
-                                  max_tokens=256, temperature=0),
+                                  max_tokens=16, temperature=0),
                               lora_request=req)[0].outputs[0].text

         answer = output.strip()[0]
@@ -112,7 +163,7 @@ def test_multi_lora():
     for i, req in enumerate(lora_requests):
         output = llm.generate(prompt,
                               sampling_params=vllm.SamplingParams(
-                                  max_tokens=256, temperature=0),
+                                  max_tokens=16, temperature=0),
                               lora_request=req)[0].outputs[0].text

         answer = output.strip()[0]
diff --git a/tpu_commons/models/jax/attention.py b/tpu_commons/models/jax/attention.py
index d48e63a0d..c036ad3a2 100644
--- a/tpu_commons/models/jax/attention.py
+++ b/tpu_commons/models/jax/attention.py
@@ -21,6 +21,7 @@ def sharded_ragged_paged_attention(
     v_scale: float | None = None,
 ):
     """Shards along KV heads."""
+    # Example shapes, non-SPMD (TP=1): q.shape=(16, 16, 128), k.shape=(16, 2, 128), kv_cache.shape=(40660, 16, 2, 2, 128)
     qkv_spec = P(None, "model", None)
     kv_cache_spec = P(None, None, "model")
     in_specs = (
@@ -86,6 +87,7 @@ def attention(
     md = attention_metadata

     # (T, N, H)
+    # Example shapes, non-SPMD (TP=1): q.shape=(16, 16, 128), k.shape=(16, 2, 128), kv_cache.shape=(40660, 16, 2, 2, 128)
     output, kv_cache = sharded_ragged_paged_attention(
         head_dim_original**-0.5, mesh, attention_chunk_size, q_scale,
         k_scale, v_scale)(
diff --git a/tpu_commons/models/vllm/jax_linear_common.py b/tpu_commons/models/vllm/jax_linear_common.py
index c28de5cdb..2574c553a 100644
--- a/tpu_commons/models/vllm/jax_linear_common.py
+++ b/tpu_commons/models/vllm/jax_linear_common.py
@@ -127,6 +127,8 @@ def torch_to_jax_param(
     tensor = tensor.astype(jax_dtype)

     if fused:
+        # Example (non-LoRA QKV layer): tensor.shape=[3072, 2048], output_sizes=[2048, 512, 512], n_shards=4, dim=0,
+        # sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('model', None)).
         tensor = reorder_concatenated_tensor_for_sharding(
             tensor, output_sizes, n_shards, dim)
         tensor = jax.device_put(tensor, sharding)
diff --git a/tpu_commons/models/vllm/quantization/common.py b/tpu_commons/models/vllm/quantization/common.py
index aad439c94..9762c3f09 100644
--- a/tpu_commons/models/vllm/quantization/common.py
+++ b/tpu_commons/models/vllm/quantization/common.py
@@ -60,6 +60,7 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
                 "Unsupported linear layer type of %s. Can potentially yield "
                 " bad performance.", type(layer))

+        # Non-LoRA case: for QKVParallelLinear, weight_sharding is PartitionSpec('model', None).
         self.bias_sharding = P(self.weight_sharding[0])
         self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

diff --git a/tpu_commons/models/vllm/sharding.py b/tpu_commons/models/vllm/sharding.py
index e4ea1e429..8ec60133e 100644
--- a/tpu_commons/models/vllm/sharding.py
+++ b/tpu_commons/models/vllm/sharding.py
@@ -128,7 +128,47 @@ def _shard_column_parallel_linear_lora(

 def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
                                     mesh: Mesh) -> None:
-    _shard_base_linear_lora(layer, mesh)
+    # Example mesh: Mesh(axis_sizes=(1, 2), axis_names=('data', 'model'))
+    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+    sharded_lora_a_tpu = torch.nn.ParameterList()
+    sharded_lora_b_tpu = torch.nn.ParameterList()
+    sharded_lora_bias_tpu = torch.nn.ParameterList()
+
+    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
+    mesh_lora_b_shape = (1, 1) + (mesh.shape['data'], mesh.shape['model'])
+    mesh_lora_b_axis = ('replica_num_lora', 'replica', 'data', 'model')
+    lora_b_mesh = jax.make_mesh(
+        mesh_lora_b_shape, mesh_lora_b_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device_0, ..., device_n]]
+    lora_b_partition_spec = P(None, None, 'model', None)
+    lora_b_sharding = NamedSharding(lora_b_mesh, lora_b_partition_spec)
+
+    mesh_lora_bias_shape = (1, 1) + (mesh.shape['model'], )
+    mesh_lora_bias_axis = ('replica_num_lora', 'replica', 'model')
+    lora_bias_mesh = jax.make_mesh(
+        mesh_lora_bias_shape, mesh_lora_bias_axis,
+        devices=mesh.devices[0])  # mesh.devices=[[device_0, ..., device_n]]
+    lora_bias_partition_spec = P(None, None, 'model')
+    lora_bias_sharding = NamedSharding(lora_bias_mesh,
+                                       lora_bias_partition_spec)
+
+    for i in range(layer.n_slices):
+        sharded_lora_a_tpu.append(
+            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+
+        sharded_lora_b_tpu.append(
+            _convert_to_torchax_and_shard(layer.lora_b_stacked[i],
+                                          lora_b_sharding))
+
+        if layer.lora_bias_stacked is not None:
+            sharded_lora_bias_tpu.append(
+                _convert_to_torchax_and_shard(layer.lora_bias_stacked[i],
+                                              lora_bias_sharding))
+
+    layer.lora_a_stacked = sharded_lora_a_tpu
+    layer.lora_b_stacked = sharded_lora_b_tpu
+    if layer.lora_bias_stacked is not None:
+        layer.lora_bias_stacked = sharded_lora_bias_tpu


 def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
@@ -152,7 +192,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
 def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
     for path, module in model.named_modules():
         for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
-            if isinstance(module, module_type):
+            if type(module) is module_type:
                 logger.debug("shard %s with %s", path, sharding_func)
                 sharding_func(module, mesh)
                 break
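Note (not part of the patch): a minimal sketch of the TP-size constraint referenced in the test_single_lora_spmd comment, assuming only that the attention kernel shards along KV heads. valid_tp_sizes is a hypothetical helper for illustration, not a tpu_commons or vLLM API.

def valid_tp_sizes(num_kv_heads: int, max_devices: int) -> list[int]:
    """Tensor-parallel sizes that shard the KV heads evenly (hypothetical helper)."""
    return [tp for tp in range(1, max_devices + 1) if num_kv_heads % tp == 0]

# Qwen2.5-3B-Instruct has num_kv_heads=2, so only TP=1 and TP=2 shard evenly.
assert valid_tp_sizes(num_kv_heads=2, max_devices=8) == [1, 2]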
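Note (not part of the patch): a standalone sketch of the layout that _shard_qkv_parallel_linear_lora sets up for lora_b_stacked, assuming a recent JAX with jax.make_mesh; the mesh shape and tensor sizes below are illustrative placeholders, not values taken from the PR.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# 4-axis mesh mirroring ('replica_num_lora', 'replica', 'data', 'model');
# all devices are placed on the 'model' axis.
mesh = jax.make_mesh((1, 1, 1, jax.device_count()),
                     ('replica_num_lora', 'replica', 'data', 'model'))

# Stacked LoRA-B weights: [max_loras, 1, out_features, rank];
# out_features (dim 2) is split across the 'model' axis, matching
# PartitionSpec(None, None, 'model', None) in the patch.
lora_b = jnp.zeros((4, 1, 2048, 8))
sharded = jax.device_put(lora_b,
                         NamedSharding(mesh, P(None, None, 'model', None)))
print(sharded.sharding)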