Improved stability of litellm models for reasoning models. #538
base: main
Changes from 1 commit
@@ -22,6 +22,7 @@
 import logging
 import os
+import re
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
@@ -93,20 +94,25 @@ def __init__(self, config, env_config) -> None:
         litellm.drop_params = True
         litellm.set_verbose = False

+    def is_reasoning_model(self):
+        return "o1" in self.model or "o3" in self.model or "R1" in self.model
+
     def _prepare_stop_sequence(self, stop_sequence):
         """Prepare and validate stop sequence."""
         if self.provider == "anthropic":
             # Filter out whitespace-only stop sequences
             if stop_sequence:
                 stop_sequence = [s for s in stop_sequence if s and s.strip()]
             if not stop_sequence:  # If empty after filtering
                 stop_sequence = ["\n"]
         return stop_sequence

     def _prepare_max_new_tokens(self, max_new_tokens):
         """Calculate completion tokens based on max_new_tokens."""
         if not max_new_tokens or max_new_tokens <= 0:
             return None

-        if "o1" in self.model:
+        if self.is_reasoning_model():
             # We need to allow more tokens to include reasoning tokens
             max_new_tokens = min(max_new_tokens * 10, 32000)
         return max_new_tokens
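To sanity-check the new helpers in isolation, here is a minimal standalone sketch of the same logic. The free-standing function names and the example model strings are illustrative only; they are not part of this PR:

def is_reasoning_model(model: str) -> bool:
    # Same check as the new method: o1/o3-style and R1-style model names count as reasoning models.
    return "o1" in model or "o3" in model or "R1" in model

def prepare_stop_sequence(provider: str, stop_sequence):
    # Mirrors _prepare_stop_sequence: Anthropic rejects whitespace-only stop strings,
    # so they are filtered out, with a newline as the fallback if nothing survives.
    if provider == "anthropic":
        if stop_sequence:
            stop_sequence = [s for s in stop_sequence if s and s.strip()]
        if not stop_sequence:
            stop_sequence = ["\n"]
    return stop_sequence

def prepare_max_new_tokens(model: str, max_new_tokens):
    # Mirrors _prepare_max_new_tokens: reasoning models get a 10x larger budget, capped at 32000,
    # so hidden reasoning tokens do not eat the whole completion budget.
    if not max_new_tokens or max_new_tokens <= 0:
        return None
    if is_reasoning_model(model):
        max_new_tokens = min(max_new_tokens * 10, 32000)
    return max_new_tokens

print(prepare_stop_sequence("anthropic", ["\n\n", " ", "Question:"]))  # ['Question:']
print(prepare_max_new_tokens("gpt-4o", 2048))       # 2048
print(prepare_max_new_tokens("o3-mini", 2048))      # 20480
print(prepare_max_new_tokens("deepseek-R1", 8192))  # 32000 (capped)
print(prepare_max_new_tokens("o1-preview", 0))      # None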
@@ -132,8 +138,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
             "n": num_samples,
             "caching": True,
         }
-        if "o1" in self.model:
-            logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+        if self.is_reasoning_model():
+            logger.warning("Reasoning models do not support temperature, top_p, stop sequence. Disabling.")
         else:
             kwargs["temperature"] = self.TEMPERATURE
             kwargs["top_p"] = self.TOP_P
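The effect of this branch is easiest to see on the request dictionary itself. The sketch below shows how the arguments might be assembled in isolation; apart from "n" and "caching", which appear in the diff, the keys, constant values, and the build_request_kwargs name are assumptions for illustration rather than the exact lighteval code:

TEMPERATURE = 0.3  # assumed placeholder values; the real ones are class attributes
TOP_P = 0.95

def build_request_kwargs(model, messages, max_completion_tokens, num_samples, stop_sequence):
    # Reasoning endpoints (o1/o3/R1) reject sampling controls, so they are simply left out;
    # litellm.drop_params=True additionally drops anything a given provider does not accept.
    kwargs = {
        "model": model,
        "messages": messages,
        "max_completion_tokens": max_completion_tokens,
        "n": num_samples,
        "caching": True,
    }
    if "o1" in model or "o3" in model or "R1" in model:
        print("Reasoning model: skipping temperature, top_p and stop sequence.")
    else:
        kwargs["temperature"] = TEMPERATURE
        kwargs["top_p"] = TOP_P
        kwargs["stop"] = stop_sequence
    return kwargs

print(build_request_kwargs("o3-mini", [{"role": "user", "content": "Hi"}], 20480, 1, ["\n"]))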
@@ -142,10 +148,17 @@
             response = litellm.completion(**kwargs)

             # If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
-            if response.choices[0].message.content is None:
+            content = response.choices[0].message.content
+            if not content:
                 kwargs["caching"] = False
                 logger.info("Response is empty, retrying without caching")
                 response = litellm.completion(**kwargs)

+            if content and "<think>" in content:
+                logger.debug(f"Removing <think> tags from response: {content}")
Review thread on the line above:
Comment: Why are we removing think tags from the answer here? I think it should be done in the metric function, no?
Reply: If we are evaluating a reasoning model, the grader will look at the thinking tokens unless we remove them. We would need to remove them in every metric function otherwise.
Reply: Yeah, but in that case you lose the thinking traces in the details.
Reply: True. Maybe we can open an issue for that and add that improvement in a later PR?
+                response.choices[0].message.content = re.sub(
+                    r"<think>.*?</think>", "", content, flags=re.DOTALL
+                ).strip()
             return response
         except litellm.BadRequestError as e:
             if "message" in e.__dict__:
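Since the stripping happens here rather than in the metric functions, it is worth checking the regex against a multi-line trace. A minimal sketch, using a made-up response string; the strip_think_tags helper name is only for illustration:

import re

def strip_think_tags(content: str) -> str:
    # Same pattern as the PR: non-greedy match so only the <think>...</think> block is removed,
    # re.DOTALL so the reasoning trace can span multiple lines, then trim leftover whitespace.
    return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

sample = """<think>
The question asks for the capital of France; that is Paris.
</think>

The answer is Paris."""

print(strip_think_tags(sample))  # -> The answer is Paris.

Note that, as raised in the review thread above, removing the trace at this point means it will not appear in the stored details; the suggested follow-up is to address that in a later PR.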