Fix mypy errors in openhands/llm directory #6812

Open · wants to merge 19 commits into base: main
20 changes: 10 additions & 10 deletions openhands/llm/async_llm.py
@@ -3,6 +3,7 @@
from typing import Any

from litellm import acompletion as litellm_acompletion
from litellm.types.utils import ModelResponse

from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -17,7 +18,7 @@
class AsyncLLM(LLM):
"""Asynchronous LLM class."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self._async_completion = partial(
@@ -45,7 +46,7 @@ def __init__(self, *args, **kwargs):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []

@@ -76,7 +77,7 @@ async def async_completion_wrapper(*args, **kwargs):

self.log_prompt(messages)

async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
@@ -96,10 +97,8 @@ async def check_stopped():
self.log_response(message_back)

# log costs and tokens used
self._post_completion(resp)

# We do not support streaming in this method, thus return resp
return resp
return dict(resp)

except UserCancelledError:
logger.debug('LLM request cancelled by user.')
@@ -116,14 +115,15 @@ async def check_stopped():
except asyncio.CancelledError:
pass

self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = partial(async_completion_wrapper)

async def _call_acompletion(self, *args, **kwargs):
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm acompletion function."""
# Used in testing?
return await litellm_acompletion(*args, **kwargs)
resp = await litellm_acompletion(*args, **kwargs)
return ModelResponse(**resp)

@property
def async_completion(self):
def async_completion(self) -> Any:
"""Decorator for the async litellm acompletion function."""
return self._async_completion
2 changes: 1 addition & 1 deletion openhands/llm/bedrock.py
@@ -28,5 +28,5 @@ def list_foundation_models(
return []


def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
14 changes: 7 additions & 7 deletions openhands/llm/debug_mixin.py
@@ -7,7 +7,7 @@


class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
@@ -24,30 +24,30 @@ def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
else:
logger.debug('No completion messages!')

def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)

def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
self._format_content_element(element) for element in content
)
return str(content)

def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any]) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)

# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError
10 changes: 5 additions & 5 deletions openhands/llm/fn_call_converter.py
@@ -9,7 +9,7 @@
import copy
import json
import re
from typing import Iterable
from typing import Any, Iterable

from litellm import ChatCompletionToolParam

@@ -265,7 +265,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
return ret


def convert_tools_to_description(tools: list[dict]) -> str:
def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
ret = ''
for i, tool in enumerate(tools):
assert tool['type'] == 'function'
@@ -474,8 +474,8 @@ def convert_fncall_messages_to_non_fncall_messages(


def _extract_and_validate_params(
matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
) -> dict:
matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
) -> dict[str, Any]:
params = {}
# Parse and validate parameters
required_params = set()
@@ -712,7 +712,7 @@ def convert_non_fncall_messages_to_fncall_messages(
# Parse parameters
param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
params = _extract_and_validate_params(
matching_tool, param_matches, fn_name
dict(matching_tool), param_matches, fn_name
)

# Create tool call with unique ID
73 changes: 43 additions & 30 deletions openhands/llm/llm.py
@@ -13,15 +13,17 @@
warnings.simplefilter('ignore')
import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
RateLimitError,
)
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.types.utils import ModelInfo as UtilsModelInfo
from litellm.utils import create_pretrained_tokenizer
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
@@ -104,7 +106,7 @@ def __init__(
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)

self.model_info: ModelInfo | None = None
self.model_info: RouterModelInfo | UtilsModelInfo | None = None
self.retry_listener = retry_listener
if self.config.log_completions:
if self.config.log_completions_folder is None:
@@ -170,7 +172,7 @@ def __init__(
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json

@@ -244,18 +246,19 @@ def wrapper(*args, **kwargs):
# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
assert len(resp.choices) == 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [non_fncall_response_message], mock_fncall_tools
)
)
fn_call_response_message = fn_call_messages_with_response[-1]
if not isinstance(fn_call_response_message, LiteLLMMessage):
fn_call_response_message = LiteLLMMessage(
**fn_call_response_message
if isinstance(
resp.choices[0], (ChatCompletion.Choice, ChatCompletionChunk.Choice)
):
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [dict(non_fncall_response_message)],
mock_fncall_tools,
)
)
resp.choices[0].message = fn_call_response_message
fn_call_response_message = fn_call_messages_with_response[-1]
fn_call_response_message = dict(fn_call_response_message)
resp.choices[0].message = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
@@ -305,17 +308,17 @@ def wrapper(*args, **kwargs):

return resp

self._completion = wrapper
self._completion = partial(wrapper)
Collaborator:
Just a question here, does this imply that during execution, it goes through another layer of indirection? I worry a bit that our code in this area is already not easy to follow, and while the majority of the changes in this PR make it more readable, I'm not really sure that wrapping it again does. 🤔
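
For reference, a minimal standalone sketch of the two assignment styles being discussed — the class and names below are made up for illustration, not taken from the PR. A zero-argument functools.partial forwards the call unchanged, so the extra layer costs one call frame at runtime; the plausible motivation is that mypy no longer flags assigning a bare local function to the instance attribute, which is what the old `# type: ignore` suppressed.

    from functools import partial
    from typing import Any, Callable


    class _Demo:
        """Illustrative only -- not the openhands LLM class."""

        def __init__(self) -> None:
            def wrapper(*args: Any, **kwargs: Any) -> str:
                # Stand-in for the real completion wrapper.
                return 'ok'

            # Direct assignment of a local function to an instance attribute;
            # this is the shape that previously carried a `# type: ignore`.
            self._direct: Callable[..., str] = wrapper

            # Zero-argument partial: behaves identically at call time, but the
            # attribute now holds a functools.partial object rather than a bare
            # function.
            self._wrapped: Callable[..., str] = partial(wrapper)


    demo = _Demo()
    assert demo._direct() == demo._wrapped() == 'ok'  # same result, one extra frame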


@property
def completion(self):
def completion(self) -> Callable[..., ModelResponse]:
"""Decorator for the litellm completion function.

Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion

def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
@@ -443,11 +446,11 @@ def _supports_vision(self) -> bool:
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
# Check both the full model name and the name after proxy prefix for vision support
return (
litellm.supports_vision(self.config.model)
or litellm.supports_vision(self.config.model.split('/')[-1])
bool(litellm.supports_vision(self.config.model))
or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
and bool(self.model_info.get('supports_vision', False))
)
)

@@ -592,7 +595,7 @@ def _is_local(self) -> bool:
return True
return False

def _completion_cost(self, response) -> float:
def _completion_cost(self, response: ModelResponse) -> float:
"""Calculate completion cost and update metrics with running total.

Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -631,35 +634,45 @@ def _completion_cost(self, response) -> float:
try:
if cost is None:
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
custom_cost_per_token=extra_kwargs.get(
'custom_cost_per_token'
),
)
)
except Exception as e:
logger.error(f'Error getting cost from litellm: {e}')

if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
cost = litellm_completion_cost(
completion_response=response, model=_model_name, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
model=_model_name,
custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
)
)
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
cost_float = float(cost)
self.metrics.add_cost(cost_float)
return cost_float
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0

def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'

def __repr__(self):
def __repr__(self) -> str:
return str(self)

def reset(self) -> None:
6 changes: 3 additions & 3 deletions openhands/llm/metrics.py
@@ -82,18 +82,18 @@ def get(self) -> dict:
],
}

def reset(self):
def reset(self) -> None:
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []

def log(self):
def log(self) -> str:
"""Log the metrics."""
metrics = self.get()
logs = ''
for key, value in metrics.items():
logs += f'{key}: {value}\n'
return logs

def __repr__(self):
def __repr__(self) -> str:
return f'Metrics({self.get()}'