Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion litellm/integrations/custom_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#### What this does ####
# On success, logs events to Promptlayer
import asyncio
import re
import threading
import traceback
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -81,7 +83,13 @@ def __init__(
pass

def log_pre_api_call(self, model, messages, kwargs):
pass
async_impl = type(self).async_log_pre_api_call
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why run the async method for pre api call? isn't it always a sync call?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is what the codebase showed. But since I saw those methods and client was using the async method, I thought it was not implemented

if async_impl is not CustomLogger.async_log_pre_api_call:
self._run_async_method_sync(
self.async_log_pre_api_call, model, messages, kwargs
)
else:
pass

def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
pass
Expand Down Expand Up @@ -737,3 +745,27 @@ def _process_messages(self, messages: List[Any], max_depth: int = DEFAULT_MAX_RE
msg[key] = self._redact_base64(value=val, max_depth=max_depth)
filtered_messages.append(msg)
return filtered_messages

def _run_async_method_sync(self, async_fn, *args, **kwargs) -> None:
try:
asyncio.get_running_loop()
except RuntimeError:
asyncio.run(async_fn(*args, **kwargs))
return

# Running inside an existing event loop – execute in a new thread with its own loop
thread_exc: Optional[BaseException] = None

def _thread_target():
nonlocal thread_exc
try:
asyncio.run(async_fn(*args, **kwargs))
except BaseException as exc: # noqa: BLE001
thread_exc = exc

thread = threading.Thread(target=_thread_target, daemon=True)
thread.start()
thread.join()

if thread_exc is not None:
raise thread_exc
136 changes: 122 additions & 14 deletions litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# What is this?
## Common Utility file for Logging handler
# Logging function -> log the exact model details + what's being sent | Non-Blocking
import asyncio
import copy
import datetime
import json
Expand All @@ -10,6 +11,7 @@
import sys
import time
import traceback
import threading
from datetime import datetime as dt_object
from functools import lru_cache
from typing import (
Expand Down Expand Up @@ -59,6 +61,7 @@
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.sqs import SQSLogger
from litellm.litellm_core_utils.cached_imports import get_coroutine_checker
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
Expand Down Expand Up @@ -317,6 +320,9 @@ def __init__(
self.dynamic_input_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = dynamic_input_callbacks
self.dynamic_async_input_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None
self.dynamic_success_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = dynamic_success_callbacks
Expand Down Expand Up @@ -417,28 +423,41 @@ def _process_dynamic_callback_list(
return None

processed_list: List[Union[str, Callable, CustomLogger]] = []
coroutine_checker = get_coroutine_checker()
for callback in callback_list:
resolved_callback: Union[str, Callable, CustomLogger] = callback
if (
isinstance(callback, str)
and callback in litellm._known_custom_logger_compatible_callbacks
):
callback_class = _init_custom_logger_compatible_class(
callback, internal_usage_cache=None, llm_router=None # type: ignore
)
if callback_class is not None:
processed_list.append(callback_class)

# If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks
if dynamic_callbacks_type == "success":
if self.dynamic_async_success_callbacks is None:
self.dynamic_async_success_callbacks = []
self.dynamic_async_success_callbacks.append(callback_class)
elif dynamic_callbacks_type == "failure":
if self.dynamic_async_failure_callbacks is None:
self.dynamic_async_failure_callbacks = []
self.dynamic_async_failure_callbacks.append(callback_class)
else:
processed_list.append(callback)
if callback_class is None:
continue
resolved_callback = callback_class

# If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks
if dynamic_callbacks_type == "success":
if self.dynamic_async_success_callbacks is None:
self.dynamic_async_success_callbacks = []
self.dynamic_async_success_callbacks.append(callback_class)
elif dynamic_callbacks_type == "failure":
if self.dynamic_async_failure_callbacks is None:
self.dynamic_async_failure_callbacks = []
self.dynamic_async_failure_callbacks.append(callback_class)

if (
dynamic_callbacks_type == "input"
and coroutine_checker.is_async_callable(resolved_callback)
):
if self.dynamic_async_input_callbacks is None:
self.dynamic_async_input_callbacks = []
if resolved_callback not in self.dynamic_async_input_callbacks:
self.dynamic_async_input_callbacks.append(resolved_callback)
continue

processed_list.append(resolved_callback)
return processed_list

def initialize_standard_callback_dynamic_params(
Expand Down Expand Up @@ -907,6 +926,7 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR
)
if capture_exception: # log this error to sentry for debugging
capture_exception(e)
self._run_async_pre_call_callbacks()
except Exception as e:
verbose_logger.exception(
"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format(
Expand All @@ -919,6 +939,94 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR
if capture_exception: # log this error to sentry for debugging
capture_exception(e)

def _get_async_input_callbacks(self) -> List[Union[str, Callable, CustomLogger]]:
callbacks: List[Union[str, Callable, CustomLogger]] = []
try:
if isinstance(litellm._async_input_callback, list):
callbacks.extend(litellm._async_input_callback)
except Exception:
pass

if self.dynamic_async_input_callbacks:
for callback in self.dynamic_async_input_callbacks:
if callback not in callbacks:
callbacks.append(callback)

return callbacks

def _run_async_pre_call_callbacks(self) -> None:
callbacks = self._get_async_input_callbacks()
if len(callbacks) == 0:
return

async def runner():
for callback in callbacks:
try:
await self._execute_async_pre_call_callback(callback)
except Exception as e:
verbose_logger.exception(
"litellm.Logging.pre_call(): Exception occurred while running async callback - {}".format(
str(e)
)
)
if capture_exception:
capture_exception(e)

try:
self._run_coroutine_blocking(runner)
except Exception as e:
verbose_logger.exception(
"litellm.Logging.pre_call(): Exception occurred while executing async callbacks - {}".format(
str(e)
)
)
if capture_exception:
capture_exception(e)

async def _execute_async_pre_call_callback(
self, callback: Union[str, Callable, CustomLogger]
) -> None:
if isinstance(callback, CustomLogger):
await callback.async_log_pre_api_call(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
)
elif callable(callback):
global customLogger
if customLogger is None:
customLogger = CustomLogger()
await customLogger.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
print_verbose=print_verbose,
callback_func=callback,
)

def _run_coroutine_blocking(self, async_callable: Callable[[], Any]) -> None:
try:
asyncio.get_running_loop()
except RuntimeError:
asyncio.run(async_callable())
return

thread_exc: Optional[BaseException] = None

def _thread_target():
nonlocal thread_exc
try:
asyncio.run(async_callable())
except BaseException as exc: # noqa: BLE001
thread_exc = exc

thread = threading.Thread(target=_thread_target, daemon=True)
thread.start()
thread.join()

if thread_exc is not None:
raise thread_exc

def _print_llm_call_debugging_log(
self,
api_base: str,
Expand Down
Loading