
[Bug]: Outlines broken on vLLM 0.8+ #15636

Open
cpfiffer opened this issue Mar 27, 2025 · 9 comments
Labels
bug Something isn't working structured-output

Comments

@cpfiffer

Your current environment

The output of `python collect_env.py`
```
(oss-debug) λ ~/dottxt/oss-debug/ python vllm_test.py
INFO 03-27 10:53:14 [__init__.py:239] Automatically detected platform cuda.
INFO 03-27 10:53:24 [config.py:585] This model supports multiple tasks: {'reward', 'classify', 'embed', 'score', 'generate'}. Defaulting to 'generate'.
INFO 03-27 10:53:24 [config.py:1697] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 03-27 10:53:26 [core.py:54] Initializing a V1 LLM engine (v0.8.2) with config: model='microsoft/Phi-3-mini-4k-instruct', speculative_config=None, tokenizer='microsoft/Phi-3-mini-4k-instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=microsoft/Phi-3-mini-4k-instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 03-27 10:53:26 [utils.py:2321] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x763983533110>
INFO 03-27 10:53:27 [parallel_state.py:954] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 03-27 10:53:27 [cuda.py:220] Using Flash Attention backend on V1 engine.
INFO 03-27 10:53:27 [gpu_model_runner.py:1174] Starting to load model microsoft/Phi-3-mini-4k-instruct...
WARNING 03-27 10:53:28 [topk_topp_sampler.py:63] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 03-27 10:53:28 [weight_utils.py:265] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.37it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.74it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.67it/s]

INFO 03-27 10:53:29 [loader.py:447] Loading weights took 1.25 seconds
INFO 03-27 10:53:29 [gpu_model_runner.py:1186] Model loading took 7.1184 GB and 2.009474 seconds
INFO 03-27 10:53:38 [backends.py:415] Using cache directory: /home/cameron/.cache/vllm/torch_compile_cache/dc009c0fc6/rank_0_0 for vLLM's torch.compile
INFO 03-27 10:53:38 [backends.py:425] Dynamo bytecode transform time: 8.92 s
INFO 03-27 10:53:41 [backends.py:132] Cache the graph of shape None for later use
INFO 03-27 10:54:06 [backends.py:144] Compiling a graph for general shape takes 27.13 s
INFO 03-27 10:54:15 [monitor.py:33] torch.compile takes 36.05 s in total
INFO 03-27 10:54:16 [kv_cache_utils.py:566] GPU KV cache size: 92,752 tokens
INFO 03-27 10:54:16 [kv_cache_utils.py:569] Maximum concurrency for 4,096 tokens per request: 22.64x
INFO 03-27 10:54:35 [gpu_model_runner.py:1534] Graph capturing finished in 19 secs, took 0.47 GiB
INFO 03-27 10:54:35 [core.py:151] init engine (profile, create kv cache, warmup model) took 65.58 seconds
Traceback (most recent call last):
  File "/home/cameron/dottxt/oss-debug/vllm_test.py", line 13, in <module>
    response = generator(prompt)
               ^^^^^^^^^^^^^^^^^
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/outlines/generate/api.py", line 504, in __call__
    completions = self.model.generate(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/outlines/models/vllm.py", line 130, in generate
    results = self.model.generate(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/utils.py", line 1072, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 457, in generate
    self._validate_and_add_requests(
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1308, in _validate_and_add_requests
    self._add_request(
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1326, in _add_request
    self.llm_engine.add_request(
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 184, in add_request
    request = self.processor.process_inputs(request_id, prompt, params,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 183, in process_inputs
    self._validate_params(params)
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 114, in _validate_params
    self._validate_supported_sampling_params(params)
  File "/home/cameron/dottxt/oss-debug/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 97, in _validate_supported_sampling_params
    raise ValueError("vLLM V1 does not support per request "
ValueError: vLLM V1 does not support per request user provided logits processors.
```

🐛 Describe the bug

Please see the downstream issue in Outlines for additional context: dottxt-ai/outlines#1517.

Essentially, I am getting the error

ValueError: vLLM V1 does not support per request user provided logits processors.

even though I am on vLLM 0.8. Why am I getting an error about V1 on a version that isn't a 1.x release?

MWE:

```python
from outlines import models, generate
from pydantic import BaseModel

model = models.vllm("microsoft/Phi-3-mini-4k-instruct")

class Example(BaseModel):
    name: str
    description: str


prompt = "France: "
generator = generate.json(model, Example)
response = generator(prompt)

print(response)
```

Requirements (`requirements.txt`):

```
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aiosignal==1.3.2
airportsdata==20250224
annotated-types==0.7.0
anyio==4.9.0
astor==0.8.1
attrs==25.3.0
blake3==1.0.4
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
compressed-tensors==0.9.2
cupy-cuda12x==13.4.1
depyf==0.18.0
dill==0.3.9
diskcache==5.6.3
distro==1.9.0
dnspython==2.7.0
einops==0.8.1
email-validator==2.2.0
fastapi==0.115.12
fastapi-cli==0.0.7
fastrlock==0.8.3
filelock==3.18.0
frozenlist==1.5.0
fsspec==2025.3.0
gguf==0.10.0
h11==0.14.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.29.3
idna==3.10
importlib-metadata==8.6.1
interegular==0.3.3
jinja2==3.1.6
jiter==0.9.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lark==1.2.2
llguidance==0.7.10
llvmlite==0.43.0
lm-format-enforcer==0.10.11
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
mistral-common==1.5.4
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.19.0
multidict==6.2.0
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
openai==1.68.2
opencv-python-headless==4.11.0.86
outlines==0.1.11
outlines-core==0.1.26
packaging==24.2
partial-json-parser==0.2.1.1.post5
pillow==11.1.0
prometheus-client==0.21.1
prometheus-fastapi-instrumentator==7.1.0
propcache==0.3.1
protobuf==6.30.2
psutil==7.0.0
py-cpuinfo==9.0.0
pycountry==24.6.1
pydantic==2.10.6
pydantic-core==2.27.2
pygments==2.19.1
python-dotenv==1.1.0
python-json-logger==3.3.0
python-multipart==0.0.20
pyyaml==6.0.2
pyzmq==26.3.0
ray==2.44.1
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rich-toolkit==0.14.0
rpds-py==0.24.0
safetensors==0.5.3
scipy==1.15.2
sentencepiece==0.2.0
setuptools==78.1.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
starlette==0.46.1
sympy==1.13.1
tiktoken==0.9.0
tokenizers==0.21.1
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.50.2
triton==3.2.0
typer==0.15.2
typing-extensions==4.13.0
urllib3==2.3.0
uvicorn==0.34.0
uvloop==0.21.0
vllm==0.8.2
watchfiles==1.0.4
websockets==15.0.1
xformers==0.0.29.post2
xgrammar==0.1.16
yarl==1.18.3
zipp==3.21.0
```

@russellb
Member

The V1 engine became the default as of 0.8.0.

https://github.com/vllm-project/vllm/releases/tag/v0.8.0

@russellb
Member

You can set `VLLM_USE_V1=0` if you want to use V0. It seems to be required for this integration.

There will be a different way to use custom logits processors in V1, but it's not implemented yet.
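A minimal sketch of opting out of V1 from inside the Python script rather than the shell (the shell form is `VLLM_USE_V1=0 python vllm_test.py`). The helper name `use_v0_engine` is hypothetical. It sets the variable before `vllm` or Outlines' vLLM integration is imported, which is the safe choice because this sketch does not depend on exactly when vLLM reads the variable.

```python
import os
import sys
import warnings


def use_v0_engine() -> None:
    """Hypothetical helper: select the V0 engine by setting VLLM_USE_V1=0.

    Call it before importing vllm or outlines.models.vllm so the setting
    is already in place when the engine is constructed.
    """
    if "vllm" in sys.modules:
        # vllm is already imported; depending on when it reads the variable,
        # setting it now may be too late, so warn instead of failing silently.
        warnings.warn("vllm was imported before VLLM_USE_V1 was set", RuntimeWarning)
    os.environ["VLLM_USE_V1"] = "0"


use_v0_engine()
# Import the vLLM-backed code only after the call above, e.g.:
# from outlines import models, generate
```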

@br3no
Contributor

br3no commented Mar 28, 2025

I think the decision to make V1 the default before it reaches feature parity was ill taken.

@simon-mo
Collaborator

Unfortunately, we don't have a way to fall back gracefully, because these can only be detected per request; there is no engine parameter like `--enable-per-request-logits-processors`. Definitely a lesson learned. Sorry about the breakage. 🙇
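Because the check only fires when a request arrives, a script like the MWE above waits through the full engine start-up (about 65 seconds in the log) before failing. A caller can instead check up front. This is a simplified sketch based only on the rules stated in this thread: V1 rejects per-request logits processors, V1 is the default from 0.8.0, and `VLLM_USE_V1=0` selects V0. It ignores cases where vLLM falls back to V0 on its own, and both function names are hypothetical.

```python
import os
import re
from typing import Mapping, Optional


def logits_processors_likely_supported(vllm_version: str,
                                       env: Optional[Mapping[str, str]] = None) -> bool:
    """Guess whether per-request logits processors will be accepted."""
    env = os.environ if env is None else env
    match = re.match(r"(\d+)\.(\d+)", vllm_version)
    if match is None:
        raise ValueError(f"unrecognized vLLM version: {vllm_version!r}")
    flag = env.get("VLLM_USE_V1")
    if flag is not None:
        # An explicit setting wins: "0" selects V0; other values are treated
        # as V1 here, which this sketch assumes rejects the processors.
        return flag == "0"
    # No explicit setting: V1 is the default from 0.8.0 onward.
    v1_is_default = (int(match.group(1)), int(match.group(2))) >= (0, 8)
    return not v1_is_default


def require_logits_processor_support(vllm_version: str,
                                     env: Optional[Mapping[str, str]] = None) -> None:
    """Fail fast, before paying for model loading and engine start-up."""
    if not logits_processors_likely_supported(vllm_version, env):
        raise RuntimeError(
            f"vLLM {vllm_version} defaults to the V1 engine, which rejects per-request "
            "logits processors; set VLLM_USE_V1=0 before creating the LLM."
        )


# Usage (assumes vllm is installed):
#   import vllm
#   require_logits_processor_support(vllm.__version__)
#   llm = vllm.LLM(model="microsoft/Phi-3-mini-4k-instruct")
```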

@mgoin
Member

mgoin commented Mar 28, 2025

@cpfiffer why do you need to create custom logits processors? I thought Outlines would use the structured output feature in vLLM directly, by passing in a grammar or JSON schema to be respected:

```python
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="microsoft/Phi-3-mini-4k-instruct")

class Example(BaseModel):
    name: str
    description: str

sampling_params = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=Example.model_json_schema()),
)

prompt = "France: "
outputs = llm.generate(
    prompts=prompt,
    sampling_params=sampling_params,
)

print(outputs[0].outputs[0].text)
```

Output:

```
{"name": "France", "description": "Paris and its romantic ambiance, plethora of museums, historical monuments, gastronomic delight, and affordable fees for tourists."}
```

@hjlee1995

Hi, I'm trying to use custom logits processors to block foreign language output.
To do this, I need to pass my processors through `logits_processors` in `SamplingParams`, but I'm running into the same error as the original poster.

vLLM V1 does not support per request user provided logits processors.

I'm wondering if there's a way to fix this.
If it's not possible, I would really appreciate it if you could suggest an alternative method instead of using logits_processors.

Hope you have a great day!

@russellb
Member

VLLM_USE_V1=0 is the only way to get custom logits processors right now (switching back to the V0 engine).

@hjlee1995

hjlee1995 commented Apr 15, 2025

Got it thx!
it works for me.

@BramVanroy

`VLLM_USE_V1=0` works. But I'd like to add my voice: from a user's perspective, it is very odd to get an error about "V1" when the installed library's version number is "0.8". It is very confusing that the engine and the library have different version numbers that are not aligned. I also would not have expected such breaking changes in a minor version; this easily breaks implementations.
