"""
Helicone integration that leverages StandardLoggingPayload and supports batching via CustomBatchLogger.
"""

import asyncio
import json
import os
from typing import Any, Dict, Optional

import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload

__all__ = ["HeliconeLogger"]

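# Configuration sketch: when constructor arguments are omitted, the logger
# reads these environment variables (values below are illustrative):
#
#   HELICONE_API_KEY    - bearer token sent to the Helicone log endpoint
#   HELICONE_API_BASE   - defaults to https://api.hconeai.com
#   HELICONE_BATCH_SIZE - optional override for the batch flush threshold
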
class HeliconeLogger(CustomBatchLogger):
    """Batching Helicone logger that consumes the StandardLoggingPayload."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        base = api_base or os.getenv("HELICONE_API_BASE") or "https://api.hconeai.com"
        self.api_base = base.rstrip("/")
        self.api_key = api_key or os.getenv("HELICONE_API_KEY")
        if not self.api_key:
            verbose_logger.debug(
                "HeliconeLogger: no Helicone API key provided; requests will be unauthenticated"
            )

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.flush_lock: Optional[asyncio.Lock] = None
        try:
            asyncio.create_task(self.periodic_flush())
            self.flush_lock = asyncio.Lock()
        except Exception as exc:  # pragma: no cover - depends on runtime loop availability
            verbose_logger.debug(
                "HeliconeLogger async batching disabled; running synchronously. %s",
                exc,
            )
            self.flush_lock = None

        super().__init__(flush_lock=self.flush_lock, **kwargs)

        batch_size_override = os.getenv("HELICONE_BATCH_SIZE")
        if batch_size_override:
            try:
                self.batch_size = int(batch_size_override)
            except ValueError:
                verbose_logger.debug(
                    "HeliconeLogger: ignoring invalid HELICONE_BATCH_SIZE=%s",
                    batch_size_override,
                )

    def log_success_event(
        self,
        kwargs: Dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
    ) -> None:
        try:
            data = self._build_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return
            self._send_sync(data)
        except Exception:
            verbose_logger.exception("HeliconeLogger: sync logging failed")

    async def async_log_success_event(
        self,
        kwargs: Dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
    ) -> None:
        try:
            data = self._build_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            if self.flush_lock is None:
                await self._send_async(data)
                return

            # Queue for the periodic flusher; flush eagerly once the batch fills.
            self.log_queue.append(data)
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception("HeliconeLogger: async logging failed")

    async def async_log_failure_event(
        self,
        kwargs: Dict[str, Any],
        response_obj: Any,
        start_time: Any,
        end_time: Any,
    ) -> None:
        try:
            verbose_logger.debug(
                "HeliconeLogger: async failure logging - entering for model %s",
                kwargs.get("model"),
            )
            data = self._build_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            if self.flush_lock is None:
                await self._send_async(data)
                return

            self.log_queue.append(data)
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception("HeliconeLogger: async failure logging failed")

    async def async_send_batch(self, *args: Any, **kwargs: Any) -> None:
        if not self.log_queue:
            return

        # Copy the queue so concurrent appends don't extend this loop;
        # CustomBatchLogger.flush_queue clears log_queue after this returns.
        events = list(self.log_queue)
        for event in events:
            try:
                await self._send_async(event)
            except Exception:
                verbose_logger.exception(
                    "HeliconeLogger: failed to send batched Helicone event"
                )

    def _build_data(
        self, kwargs: Dict[str, Any], response_obj: Any, start_time: Any, end_time: Any
    ) -> Optional[Dict[str, Any]]:
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if logging_payload is None:
            verbose_logger.debug(
                "HeliconeLogger: standard_logging_object missing from kwargs; skipping event"
            )
            return None

        provider_request = {
            "url": logging_payload.get("api_base", ""),
            "json": self._pick_request_json(kwargs),
            "meta": {},
        }

        provider_response = {
            "json": self._pick_response(logging_payload),
            "headers": self._pick_response_headers(logging_payload),
            "status": self._pick_status_code(logging_payload),
        }

        # Helicone timing uses epoch seconds plus a millisecond remainder.
        start_time_seconds = int(start_time.timestamp())
        start_time_milliseconds = int(
            (start_time.timestamp() - start_time_seconds) * 1000
        )
        end_time_seconds = int(end_time.timestamp())
        end_time_milliseconds = int((end_time.timestamp() - end_time_seconds) * 1000)
        timing = {
            "startTime": {
                "seconds": start_time_seconds,
                "milliseconds": start_time_milliseconds,
            },
            "endTime": {
                "seconds": end_time_seconds,
                "milliseconds": end_time_milliseconds,
            },
        }

        payload_json = {
            "providerRequest": provider_request,
            "providerResponse": provider_response,
            "timing": timing,
        }
        return self._sanitize(payload_json)
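    # Example of the payload shape assembled by _build_data (values are
    # illustrative; field names mirror the code above, and the exact
    # /custom/v1/log contract is Helicone's to define):
    #
    #   {
    #     "providerRequest":  {"url": "https://...", "json": {...}, "meta": {}},
    #     "providerResponse": {"json": {...}, "headers": {...}, "status": 200},
    #     "timing": {
    #       "startTime": {"seconds": 1700000000, "milliseconds": 123},
    #       "endTime":   {"seconds": 1700000001, "milliseconds": 456}
    #     }
    #   }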

    def _pick_request_json(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        additional_args = (kwargs or {}).get("additional_args") or {}
        if isinstance(additional_args, dict):
            complete_input_dict = additional_args.get("complete_input_dict")
            if isinstance(complete_input_dict, dict):
                return complete_input_dict
        return {}

    def _pick_response(self, logging_payload: StandardLoggingPayload) -> Any:
        if logging_payload.get("status") == "success":
            return logging_payload.get("response", {})
        # error_str is a string, so fall back to an empty string rather than a dict.
        return logging_payload.get("error_str") or ""

    def _pick_response_headers(
        self, logging_payload: StandardLoggingPayload
    ) -> Dict[str, Any]:
        headers: Dict[str, Any] = {}
        hidden_params = logging_payload.get("hidden_params")
        if isinstance(hidden_params, dict):
            provider_headers = hidden_params.get("response_headers")
            if isinstance(provider_headers, dict):
                headers.update(provider_headers)
        return headers

    def _pick_status_code(self, logging_payload: StandardLoggingPayload) -> int:
        error_information = logging_payload.get("error_information") or {}
        if isinstance(error_information, dict):
            error_code = error_information.get("error_code")
            if isinstance(error_code, str) and error_code:
                try:
                    return int(error_code)
                except ValueError:
                    # Non-numeric error codes fall through to the defaults below.
                    pass
        return 500 if logging_payload.get("status") == "failure" else 200

    @staticmethod
    def _sanitize(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return a JSON-serializable representation of the payload."""
        return json.loads(safe_dumps(payload))

    def _send_sync(self, data: Dict[str, Any]) -> None:
        url = f"{self.api_base}/custom/v1/log"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        response = litellm.module_level_client.post(
            url=url,
            headers=headers,
            json=data,
        )
        verbose_logger.debug(
            "HeliconeLogger: logged Helicone event (status %s)",
            getattr(response, "status_code", "unknown"),
        )

    async def _send_async(self, data: Dict[str, Any]) -> None:
        url = f"{self.api_base}/custom/v1/log"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = await self.async_httpx_client.post(
            url=url,
            headers=headers,
            json=data,
        )
        response.raise_for_status()
        verbose_logger.debug(
            "HeliconeLogger: logged Helicone event (status %s)",
            response.status_code,
        )
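
# Minimal usage sketch: registering a CustomLogger instance on litellm.callbacks
# is standard litellm practice; the API key and model below are illustrative,
# not taken from this module.
#
#   import litellm
#
#   litellm.callbacks = [HeliconeLogger(api_key="sk-helicone-...")]
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "hello"}],
#   )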