Skip to content

Fix mypy errors in openhands/llm directory #6812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
66bd8fd
Enable strict type checking with mypy
openhands-agent Jan 21, 2025
7a25991
Update .github/workflows/lint.yml
neubig Jan 21, 2025
64ebef3
Update .github/workflows/lint.yml
neubig Jan 21, 2025
66a7920
Merge branch 'main' into feature/strict-mypy-checks
neubig Feb 10, 2025
d309455
Merge branch 'main' into feature/strict-mypy-checks
neubig Feb 11, 2025
592aca0
Merge branch 'main' into feature/strict-mypy-checks
neubig Feb 19, 2025
2589b13
Fix mypy errors in openhands/llm directory
openhands-agent Feb 19, 2025
63d9c3d
Revert changes outside openhands/llm directory
openhands-agent Feb 19, 2025
18543a2
Merge branch 'main' into fix/llm-mypy-errors
neubig Feb 19, 2025
5b8db98
Merge branch 'main' into fix/llm-mypy-errors
xingyaoww Feb 19, 2025
5b1c8bc
Merge branch 'main' into fix/llm-mypy-errors
enyst Feb 19, 2025
41068f6
Merge branch 'main' into fix/llm-mypy-errors
neubig Feb 21, 2025
9ccf680
Remove parallel test execution to fix failing tests
openhands-agent Feb 21, 2025
f74ce56
Fix litellm import path for Choices and StreamingChoices
openhands-agent Feb 21, 2025
4b49ffb
Fix litellm type imports to use OpenAI types directly
openhands-agent Feb 21, 2025
7886c1f
Fix litellm type imports to use OpenAI types directly
openhands-agent Feb 21, 2025
1b9a2b4
Merge branch 'main' into fix/llm-mypy-errors
neubig Feb 22, 2025
e25e676
Fix ruff and ruff-format issues
openhands-agent Feb 22, 2025
5266187
Update .github/workflows/py-unit-tests.yml
neubig Feb 23, 2025
1e5c4da
Merge branch 'main' into fix/llm-mypy-errors
neubig Mar 3, 2025
cb705b7
Remove duplicate default values in retry_mixin.py
openhands-agent Mar 3, 2025
e63dfca
Replace partial() with cast() for type hints
openhands-agent Mar 3, 2025
7f3202d
Replace partial() with direct function wrappers
openhands-agent Mar 3, 2025
096b259
Replace partial() with proper type hints
openhands-agent Mar 3, 2025
b9f330d
Merge main into fix/llm-mypy-errors (keeping our poetry.lock)
openhands-agent Mar 24, 2025
ac07afb
fix: make LLM response handling more robust for deepseek and other pr…
openhands-agent Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions openhands/llm/async_llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable, Coroutine

from litellm import acompletion as litellm_acompletion
from litellm.types.utils import ModelResponse

from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
Expand All @@ -17,7 +18,9 @@
class AsyncLLM(LLM):
"""Asynchronous LLM class."""

def __init__(self, *args, **kwargs):
_async_completion: Callable[..., Coroutine[Any, Any, ModelResponse]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self._async_completion = partial(
Expand All @@ -37,7 +40,9 @@ def __init__(self, *args, **kwargs):
seed=self.config.seed,
)

async_completion_unwrapped = self._async_completion
async_completion_unwrapped: Callable[
..., Coroutine[Any, Any, ModelResponse]
] = self._async_completion

@self.retry_decorator(
num_retries=self.config.num_retries,
Expand All @@ -46,7 +51,7 @@ def __init__(self, *args, **kwargs):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []

Expand Down Expand Up @@ -77,7 +82,7 @@ async def async_completion_wrapper(*args, **kwargs):

self.log_prompt(messages)

async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
Expand All @@ -97,10 +102,8 @@ async def check_stopped():
self.log_response(message_back)

# log costs and tokens used
self._post_completion(resp)

# We do not support streaming in this method, thus return resp
return resp
return dict(resp)

except UserCancelledError:
logger.debug('LLM request cancelled by user.')
Expand All @@ -117,14 +120,15 @@ async def check_stopped():
except asyncio.CancelledError:
pass

self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = async_completion_wrapper

async def _call_acompletion(self, *args, **kwargs):
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm acompletion function."""
# Used in testing?
return await litellm_acompletion(*args, **kwargs)
resp = await litellm_acompletion(*args, **kwargs)
return ModelResponse(**resp)

@property
def async_completion(self):
def async_completion(self) -> Callable[..., Coroutine[Any, Any, ModelResponse]]:
"""Decorator for the async litellm acompletion function."""
return self._async_completion
2 changes: 1 addition & 1 deletion openhands/llm/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def list_foundation_models(
return []


def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
14 changes: 7 additions & 7 deletions openhands/llm/debug_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
Expand All @@ -24,30 +24,30 @@ def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
else:
logger.debug('No completion messages!')

def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)

def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
self._format_content_element(element) for element in content
)
return str(content)

def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any]) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)

# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError
10 changes: 5 additions & 5 deletions openhands/llm/fn_call_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import copy
import json
import re
from typing import Iterable
from typing import Any, Iterable

from litellm import ChatCompletionToolParam

Expand Down Expand Up @@ -265,7 +265,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
return ret


def convert_tools_to_description(tools: list[dict]) -> str:
def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
ret = ''
for i, tool in enumerate(tools):
assert tool['type'] == 'function'
Expand Down Expand Up @@ -474,8 +474,8 @@ def convert_fncall_messages_to_non_fncall_messages(


def _extract_and_validate_params(
matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
) -> dict:
matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
) -> dict[str, Any]:
params = {}
# Parse and validate parameters
required_params = set()
Expand Down Expand Up @@ -712,7 +712,7 @@ def convert_non_fncall_messages_to_fncall_messages(
# Parse parameters
param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
params = _extract_and_validate_params(
matching_tool, param_matches, fn_name
dict(matching_tool), param_matches, fn_name
)

# Create tool call with unique ID
Expand Down
111 changes: 62 additions & 49 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import time
import warnings
from functools import partial
from typing import Any, Callable

import requests
Expand All @@ -13,14 +12,15 @@
warnings.simplefilter('ignore')
import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
RateLimitError,
)
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.types.utils import ModelInfo as UtilsModelInfo
from litellm.utils import create_pretrained_tokenizer

from openhands.core.logger import openhands_logger as logger
Expand Down Expand Up @@ -87,6 +87,8 @@ class LLM(RetryMixin, DebugMixin):
config: an LLMConfig object specifying the configuration of the LLM.
"""

_completion: Callable[..., ModelResponse]

def __init__(
self,
config: LLMConfig,
Expand All @@ -108,7 +110,7 @@ def __init__(
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)

self.model_info: ModelInfo | None = None
self.model_info: RouterModelInfo | UtilsModelInfo | None = None
self.retry_listener = retry_listener
if self.config.log_completions:
if self.config.log_completions_folder is None:
Expand Down Expand Up @@ -153,23 +155,26 @@ def __init__(
kwargs['max_tokens'] = self.config.max_output_tokens
kwargs.pop('max_completion_tokens')

self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
timeout=self.config.timeout,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
seed=self.config.seed,
**kwargs,
)

self._completion_unwrapped = self._completion
# Create a wrapper function that captures the config values
def completion_with_config(*args: Any, **user_kwargs: Any) -> ModelResponse:
"""Wrapper for litellm_completion that includes the config values."""
merged_kwargs = {
'model': self.config.model,
'api_key': self.config.api_key.get_secret_value()
if self.config.api_key
else None,
'base_url': self.config.base_url,
'api_version': self.config.api_version,
'custom_llm_provider': self.config.custom_llm_provider,
'timeout': self.config.timeout,
'top_p': self.config.top_p,
'drop_params': self.config.drop_params,
**kwargs,
**user_kwargs,
}
return litellm_completion(*args, **merged_kwargs)

self._completion_unwrapped = completion_with_config

@self.retry_decorator(
num_retries=self.config.num_retries,
Expand All @@ -179,7 +184,7 @@ def __init__(
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json

Expand Down Expand Up @@ -257,20 +262,18 @@ def wrapper(*args, **kwargs):

# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
logger.debug(f'Response choices: {len(resp.choices)}')
assert len(resp.choices) >= 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [non_fncall_response_message], mock_fncall_tools
assert len(resp.choices) == 1
if isinstance(resp.choices[0], dict) and 'message' in resp.choices[0]:
non_fncall_response_message = resp.choices[0]['message']
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [dict(non_fncall_response_message)],
mock_fncall_tools,
)
)
)
fn_call_response_message = fn_call_messages_with_response[-1]
if not isinstance(fn_call_response_message, LiteLLMMessage):
fn_call_response_message = LiteLLMMessage(
**fn_call_response_message
)
resp.choices[0].message = fn_call_response_message
fn_call_response_message = fn_call_messages_with_response[-1]
fn_call_response_message = dict(fn_call_response_message)
resp.choices[0]['message'] = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
Expand Down Expand Up @@ -327,14 +330,14 @@ def wrapper(*args, **kwargs):
self._completion = wrapper

@property
def completion(self):
def completion(self) -> Callable[..., ModelResponse]:
"""Decorator for the litellm completion function.

Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion

def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
Expand Down Expand Up @@ -464,11 +467,11 @@ def _supports_vision(self) -> bool:
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
# Check both the full model name and the name after proxy prefix for vision support
return (
litellm.supports_vision(self.config.model)
or litellm.supports_vision(self.config.model.split('/')[-1])
bool(litellm.supports_vision(self.config.model))
or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
and bool(self.model_info.get('supports_vision', False))
)
)

Expand Down Expand Up @@ -624,7 +627,7 @@ def _is_local(self) -> bool:
return True
return False

def _completion_cost(self, response) -> float:
def _completion_cost(self, response: ModelResponse) -> float:
"""Calculate completion cost and update metrics with running total.

Calculate the cost of a completion response based on the model. Local models are treated as free.
Expand Down Expand Up @@ -663,35 +666,45 @@ def _completion_cost(self, response) -> float:
try:
if cost is None:
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
custom_cost_per_token=extra_kwargs.get(
'custom_cost_per_token'
),
)
)
except Exception as e:
logger.error(f'Error getting cost from litellm: {e}')

if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
cost = litellm_completion_cost(
completion_response=response, model=_model_name, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
model=_model_name,
custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
)
)
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
cost_float = float(cost)
self.metrics.add_cost(cost_float)
return cost_float
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0

def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'

def __repr__(self):
def __repr__(self) -> str:
return str(self)

def reset(self) -> None:
Expand Down
Loading
Loading