diff --git a/src/engine.py b/src/engine.py
index 00cf401..83254bd 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -107,7 +107,7 @@ async def generate(self, job_input: JobInput):
             yield {"error": create_error_response(str(e)).model_dump()}
 
     async def _generate_vllm(self, llm_input, validated_sampling_params, batch_size, stream, apply_chat_template, request_id, batch_size_growth_factor, min_batch_size: str) -> AsyncGenerator[dict, None]:
-        if apply_chat_template or isinstance(llm_input, list):
+        if apply_chat_template:
             tokenizer_wrapper = self._get_tokenizer_for_chat_template()
             llm_input = tokenizer_wrapper.apply_chat_template(llm_input)
         results_generator = self.llm.generate(llm_input, validated_sampling_params, request_id)
@@ -299,4 +299,4 @@ async def _handle_chat_or_completion_request(self, openai_request: JobInput):
             if self.raw_openai_output:
                 batch = "".join(batch)
             yield batch
-
\ No newline at end of file
+
diff --git a/src/utils.py b/src/utils.py
index ee1b927..5718c5a 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -42,12 +42,39 @@ def count_physical_cores():
     return len(cores)
 
 
+# Helpers to support sending multiple prompts or token arrays in a single request
+def prompt_to_vllm_prompt(prompt):
+    if len(prompt) == 0:
+        return vllm.TextPrompt(prompt=prompt)
+    elif isinstance(prompt, list):
+        return [vllm.TextPrompt(prompt=p) for p in prompt]
+    else:
+        return vllm.TextPrompt(prompt=prompt)
+
+def tokens_to_vllm_prompt(tokens):
+    if len(tokens) == 0:
+        return vllm.TokensPrompt(prompt_token_ids=tokens)
+    elif isinstance(tokens[0], list):  # Multiple token arrays in one request
+        return [vllm.TokensPrompt(prompt_token_ids=toks) for toks in tokens]
+    else:
+        return vllm.TokensPrompt(prompt_token_ids=tokens)
+
+def get_llm_input(job):
+    for k, fn in [
+        ("messages", lambda messages: messages),
+        ("prompt", prompt_to_vllm_prompt),
+        ("tokens", tokens_to_vllm_prompt)]:
+        value = job.get(k)
+        if value:
+            return fn(value)
+    return None
+
 class JobInput:
     def __init__(self, job):
-        self.llm_input = job.get("messages", job.get("prompt"))
+        self.llm_input = get_llm_input(job)
         self.stream = job.get("stream", False)
         self.max_batch_size = job.get("max_batch_size")
-        self.apply_chat_template = job.get("apply_chat_template", False)
+        self.apply_chat_template = job.get("apply_chat_template", job.get("messages") is not None)
         self.use_openai_format = job.get("use_openai_format", False)
         samp_param = job.get("sampling_params", {})
         if "max_tokens" not in samp_param:
@@ -103,4 +130,4 @@ def wrapper(*args, **kwargs):
         end = time()
         logging.info(f"{func.__name__} completed in {end - start:.2f} seconds")
         return result
-    return wrapper
\ No newline at end of file
+    return wrapper
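Below is a minimal usage sketch (not part of the patch) showing how the new get_llm_input helper would map the three supported job payload shapes onto vLLM prompt objects. It assumes the patched src/utils.py is importable as `utils` and that the installed vllm version exposes TextPrompt and TokensPrompt, as the patch itself expects.

# Hypothetical usage sketch; assumes the patched src/utils.py is on the import path.
from utils import get_llm_input

# A single prompt string becomes one vllm.TextPrompt.
single = get_llm_input({"prompt": "Hello, world"})

# A list of prompt strings becomes a list of vllm.TextPrompt objects, one per prompt.
batched = get_llm_input({"prompt": ["First prompt", "Second prompt"]})

# Pre-tokenized input: a list of token lists becomes a list of vllm.TokensPrompt
# objects; a single flat token list would become one vllm.TokensPrompt.
tokenized = get_llm_input({"tokens": [[101, 2023, 102], [101, 2003, 102]]})

# Chat messages are passed through unchanged; JobInput now defaults
# apply_chat_template to True whenever "messages" is present.
chat = get_llm_input({"messages": [{"role": "user", "content": "Hi"}]})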