@@ -1,20 +1,26 @@
+import contextlib
 from collections.abc import AsyncGenerator
-from typing import Any, TypedDict
+from typing import Any
 
 from pydantic import BaseModel, Field
 from ray import serve
+from typing_extensions import TypedDict
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.model_executor.utils import set_random_seed
+
+with contextlib.suppress(ImportError):
+    from vllm.model_executor.utils import (
+        set_random_seed,  # Ignore if we don't have GPU and only run on CPU with test cache
+    )
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
-from vllm.utils import get_gpu_memory, random_uuid
+from vllm.utils import random_uuid
 
 from aana.deployments.base_deployment import BaseDeployment
 from aana.exceptions.general import InferenceException, PromptTooLongException
 from aana.models.pydantic.chat_message import ChatDialog, ChatMessage
 from aana.models.pydantic.sampling_params import SamplingParams
 from aana.utils.chat_template import apply_chat_template
-from aana.utils.general import merged_options
+from aana.utils.general import get_gpu_memory, merged_options
 from aana.utils.test import test_cache
 
 
@@ -28,6 +34,9 @@ class VLLMConfig(BaseModel):
         gpu_memory_reserved (float): the GPU memory reserved for the model in mb
         default_sampling_params (SamplingParams): the default sampling parameters.
         max_model_len (int): the maximum generated text length in tokens (optional, default: None)
+        chat_template (str): the name of the chat template; if not provided, the chat template from the model will be used,
+            but some models may not have a chat template (optional, default: None)
+        enforce_eager (bool): whether to enforce eager execution (optional, default: False)
     """
 
     model: str
@@ -37,6 +46,7 @@ class VLLMConfig(BaseModel):
     default_sampling_params: SamplingParams
     max_model_len: int | None = Field(default=None)
     chat_template: str | None = Field(default=None)
+    enforce_eager: bool | None = Field(default=False)
 
 
 class LLMOutput(TypedDict):
@@ -107,6 +117,7 @@ async def apply_config(self, config: dict[str, Any]):
             model=config_obj.model,
             dtype=config_obj.dtype,
             quantization=config_obj.quantization,
+            enforce_eager=config_obj.enforce_eager,
             gpu_memory_utilization=self.gpu_memory_utilization,
             max_model_len=config_obj.max_model_len,
         )
@@ -116,7 +127,7 @@ async def apply_config(self, config: dict[str, Any]):
 
         # create the engine
         self.engine = AsyncLLMEngine.from_engine_args(args)
-        self.tokenizer = self.engine.engine.tokenizer
+        self.tokenizer = self.engine.engine.tokenizer.tokenizer
         self.model_config = await self.engine.get_model_config()
 
     @test_cache
@@ -148,7 +159,7 @@ async def generate_stream(
         try:
             # convert SamplingParams to VLLMSamplingParams
             sampling_params_vllm = VLLMSamplingParams(
-                **sampling_params.dict(exclude_unset=True)
+                **sampling_params.model_dump(exclude_unset=True)
             )
             # start the request
             request_id = random_uuid()
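
Two of the changes above may be worth a note for reviewers. The `.dict()` → `.model_dump()` swap follows the Pydantic v2 rename of the serialization method, and the guarded `set_random_seed` import uses `contextlib.suppress` so that a CPU-only environment (e.g. when running against the test cache) does not fail at import time. A minimal sketch of that optional-import pattern, with a hypothetical module name standing in for the vLLM import:

```python
import contextlib

# Skip the import silently when the optional GPU-only dependency is absent.
with contextlib.suppress(ImportError):
    from gpu_only_backend import set_random_seed  # hypothetical module

# Downstream code can check whether the import actually succeeded.
if "set_random_seed" in globals():
    set_random_seed(42)
else:
    print("set_random_seed unavailable; running without seeding")
```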