|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import logging |
6 | | -import os |
7 | 6 | import random |
8 | 7 | import time |
9 | 8 | from collections.abc import Callable |
10 | | -from typing import Final, TypeVar, cast |
11 | | - |
12 | | -from openai import OpenAI |
13 | | -from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam |
| 9 | +from typing import TypeVar |
14 | 10 |
|
15 | 11 | from edvise.genai.mapping.identity_agent.grain_inference.schemas import GrainContract |
16 | | -from edvise.genai.mapping.shared.mlflow_gateway_bootstrap import ( |
17 | | - disable_mlflow_side_effects_for_openai_gateway, |
18 | | -) |
| 12 | +from edvise.genai.mapping.shared import databricks_ai_gateway as _databricks_ai_gateway |
19 | 13 |
|
# Gateway constants and helpers are taken from the shared ``databricks_ai_gateway``
# module; the aliases below bind the same names at this module's top level.
# NOTE(review): an earlier comment here said the default endpoint is the same one
# used by ``schema_mapping_agent.manifest.eval`` (MLflow serving / gateway) — confirm.
DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL = (
    _databricks_ai_gateway.DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL
)
DEFAULT_GATEWAY_MODEL_ID = _databricks_ai_gateway.DEFAULT_GATEWAY_MODEL_ID
# NOTE(review): presumably the separator placed between system and user text when both
# are sent as a single role=user message (IA / SMA) — verify against the shared module.
LLM_COMPLETE_SYSTEM_USER_SEP = _databricks_ai_gateway.LLM_COMPLETE_SYSTEM_USER_SEP
DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS = (
    _databricks_ai_gateway.DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS
)
llm_complete_combined_message_content = (
    _databricks_ai_gateway.llm_complete_combined_message_content
)
disable_mlflow_tracing_for_openai_gateway_client = (
    _databricks_ai_gateway.disable_mlflow_tracing_for_openai_gateway_client
)
resolve_ai_gateway_base_url = _databricks_ai_gateway.resolve_ai_gateway_base_url
resolve_gateway_model_id = _databricks_ai_gateway.resolve_gateway_model_id
require_databricks_token = _databricks_ai_gateway.require_databricks_token
create_openai_client_for_databricks_gateway = (
    _databricks_ai_gateway.create_openai_client_for_databricks_gateway
)
make_databricks_gateway_llm_complete = (
    _databricks_ai_gateway.make_databricks_gateway_llm_complete
)

# Module-level logger named after this module.
_LOG = logging.getLogger(__name__)
32 | 39 |
|
33 | | - |
def llm_complete_combined_message_content(system: str, user: str) -> str:
    """Build the exact ``content`` string sent to the gateway for ``llm_complete(system, user)``.

    The result is the system text, then :data:`LLM_COMPLETE_SYSTEM_USER_SEP`,
    then the user text, concatenated in that order.
    """
    system_with_separator = system + LLM_COMPLETE_SYSTEM_USER_SEP
    return system_with_separator + user
37 | | - |
38 | | - |
def disable_mlflow_tracing_for_openai_gateway_client() -> None:
    """
    Turn off MLflow tracing / OpenAI autolog for gateway calls (see module docstring).

    This is a thin wrapper: it only delegates to
    ``disable_mlflow_side_effects_for_openai_gateway`` and returns ``None``.

    Job scripts should also call :func:`~edvise.genai.mapping.shared.mlflow_gateway_bootstrap.disable_mlflow_side_effects_for_openai_gateway`
    at import time **before** loading packages that import ``openai``.
    """
    # All of the actual work happens in the shared bootstrap helper.
    disable_mlflow_side_effects_for_openai_gateway()
47 | | - |
48 | | - |
def resolve_ai_gateway_base_url() -> str:
    """Return the gateway base URL.

    The ``AI_GATEWAY_BASE_URL`` environment variable wins when it is set (even to an
    empty string); otherwise :data:`DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL` is returned.
    """
    override = os.environ.get("AI_GATEWAY_BASE_URL")
    if override is None:
        return DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL
    return override
54 | | - |
55 | | - |
def resolve_gateway_model_id() -> str:
    """Return the gateway model id.

    The ``GATEWAY_MODEL_ID`` environment variable wins when it is set (even to an
    empty string); otherwise :data:`DEFAULT_GATEWAY_MODEL_ID` is returned.
    """
    override = os.environ.get("GATEWAY_MODEL_ID")
    if override is None:
        return DEFAULT_GATEWAY_MODEL_ID
    return override
59 | | - |
60 | | - |
61 | | -def _token_from_authorization_header(headers: dict[str, str]) -> str | None: |
62 | | - auth = headers.get("Authorization") or headers.get("authorization") |
63 | | - if not auth or not isinstance(auth, str): |
64 | | - return None |
65 | | - parts = auth.split(None, 1) |
66 | | - if len(parts) == 2 and parts[0].lower() == "bearer": |
67 | | - return parts[1].strip() |
68 | | - return None |
69 | | - |
70 | | - |
def _token_from_databricks_sdk_default_auth() -> str | None:
    """
    Try to obtain a workspace bearer token via ``Config().authenticate()`` (Databricks SDK).

    Typical sources: job/cluster identity metadata service, OAuth M2M / service principal,
    or a local ``databricks auth login`` profile when ``DATABRICKS_HOST`` is set.

    Returns ``None`` (after a debug log) when the SDK is not installed or when
    authentication raises; otherwise returns whatever
    :func:`_token_from_authorization_header` extracts from the returned headers.
    """
    # The SDK is imported lazily so this module still loads without it.
    try:
        from databricks.sdk.core import Config
    except ImportError:
        _LOG.debug(
            "databricks-sdk not installed; cannot resolve runtime workspace token"
        )
        return None

    try:
        auth_headers = Config().authenticate()
    except Exception as exc:
        # Best-effort lookup: any auth failure is reported as "no token", not raised.
        _LOG.debug("Databricks SDK default auth unavailable (%s)", exc)
        return None
    else:
        return _token_from_authorization_header(auth_headers)
91 | | - |
92 | | - |
def require_databricks_token() -> str:
    """
    Return the workspace bearer used as the gateway ``api_key``.

    The token comes from :func:`_token_from_databricks_sdk_default_auth`.
    Personal access tokens (``DATABRICKS_TOKEN``) are not used for this path,
    and ``OPENAI_API_KEY`` is not used for this gateway.

    Raises:
        ValueError: if the SDK lookup yields no (or an empty) token.
    """
    token = _token_from_databricks_sdk_default_auth()
    if not token:
        raise ValueError(
            "No Databricks workspace token for the MLflow AI gateway: databricks-sdk "
            "Config().authenticate() did not return a Bearer token. Run on Databricks compute "
            "with job/cluster identity, configure OAuth / service principal credentials, or "
            "use ``databricks auth login`` locally with DATABRICKS_HOST set. "
            "OPENAI_API_KEY is not used here."
        )
    return token
112 | | - |
113 | | - |
def create_openai_client_for_databricks_gateway(
    *,
    api_key: str | None = None,
    base_url: str | None = None,
) -> OpenAI:
    """
    Create an :class:`openai.OpenAI` client that targets the Databricks gateway.

    MLflow tracing is disabled first via
    :func:`disable_mlflow_tracing_for_openai_gateway_client`.

    Args:
        api_key: Bearer to send; when ``None``, :func:`require_databricks_token` supplies it.
        base_url: Gateway URL; when ``None``, :func:`resolve_ai_gateway_base_url` supplies it.
    """
    disable_mlflow_tracing_for_openai_gateway_client()

    # Resolve the key before the URL so a missing token is reported first.
    if api_key is None:
        api_key = require_databricks_token()
    if base_url is None:
        base_url = resolve_ai_gateway_base_url()

    return OpenAI(api_key=api_key, base_url=base_url)
129 | | - |
130 | | - |
def make_databricks_gateway_llm_complete(
    client: OpenAI,
    *,
    model: str | None = None,
    max_tokens: int = DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS,
) -> Callable[[str, str], str]:
    """
    Build the ``llm_complete(system, user)`` callable for
    :mod:`~edvise.genai.mapping.identity_agent.grain_inference.runner`.

    Each call sends one role=user message whose content is ``system``, a separator,
    then ``user`` (matches ``ia_dev`` / SMA notebook patterns), and returns the
    assistant text or raises if the response has none.

    Args:
        client: Client used for ``chat.completions.create``.
        model: Model id; when ``None`` it is resolved once, here, via
            :func:`resolve_gateway_model_id`.
        max_tokens: ``max_tokens`` forwarded on every request.
    """
    # Resolved at factory time, so later environment changes do not affect the closure.
    if model is None:
        gateway_model = resolve_gateway_model_id()
    else:
        gateway_model = model

    def complete(system: str, user: str) -> str:
        combined = llm_complete_combined_message_content(system, user)
        payload = cast(
            list[ChatCompletionMessageParam],
            [{"role": "user", "content": combined}],
        )
        completion = client.chat.completions.create(
            model=gateway_model,
            messages=payload,
            max_tokens=max_tokens,
        )
        return _assistant_text_from_chat_completion_or_raise(
            completion, log=_LOG, default_model=gateway_model
        )

    return complete
165 | | - |
166 | | - |
167 | | -def _text_from_message_content( |
168 | | - content: object, |
169 | | -) -> str: |
170 | | - """ |
171 | | - Best-effort string from ``message.content`` (OpenAI is usually ``str | None``; |
172 | | - some routes may return list-shaped multimodal content). |
173 | | - """ |
174 | | - if content is None: |
175 | | - return "" |
176 | | - if isinstance(content, str): |
177 | | - return content |
178 | | - if isinstance(content, list): |
179 | | - parts: list[str] = [] |
180 | | - for block in content: |
181 | | - if isinstance(block, dict): |
182 | | - if block.get("type") == "text" and "text" in block: |
183 | | - parts.append(str(block.get("text", ""))) |
184 | | - else: |
185 | | - t = block.get("text") |
186 | | - if t is not None: |
187 | | - parts.append(str(t)) |
188 | | - else: |
189 | | - tx = getattr(block, "text", None) |
190 | | - if tx is not None: |
191 | | - parts.append(str(tx)) |
192 | | - return "".join(parts) |
193 | | - return str(content) |
194 | | - |
195 | | - |
196 | | -def _assistant_text_from_chat_completion_or_raise( |
197 | | - resp: object, *, log: logging.Logger, default_model: str | None = None |
198 | | -) -> str: |
199 | | - """ |
200 | | - Return the assistant's output text, or raise if there is nothing usable to parse as JSON. |
201 | | -
|
202 | | - A ``200`` response with ``content=None`` and no text was previously turned into ``""``, |
203 | | - which only surfaces as JSONDecodeError on empty input. We fail fast with diagnostics |
204 | | - and surface refusals (e.g. Claude) explicitly. |
205 | | - """ |
206 | | - choices = getattr(resp, "choices", None) or [] |
207 | | - if not choices: |
208 | | - msg = "AI Gateway returned no choices on chat.completions" |
209 | | - log.error("%s: model=%r", msg, getattr(resp, "model", default_model)) |
210 | | - raise RuntimeError(msg) from None |
211 | | - |
212 | | - ch0 = choices[0] |
213 | | - msg = ch0.message |
214 | | - raw = _text_from_message_content(getattr(msg, "content", None)) |
215 | | - if raw.strip(): |
216 | | - return raw |
217 | | - |
218 | | - ref = getattr(msg, "refusal", None) |
219 | | - if isinstance(ref, str) and ref.strip(): |
220 | | - short = ref.strip()[:2000] |
221 | | - log.error( |
222 | | - "AI Gateway: model refusal (not valid JSON for downstream parse): %s", short |
223 | | - ) |
224 | | - raise RuntimeError( |
225 | | - "The model refused to return structured output. Refusal: " |
226 | | - + ref.strip()[:4000] |
227 | | - ) from None |
228 | | - |
229 | | - u = getattr(resp, "usage", None) |
230 | | - udump: object |
231 | | - if u is not None and hasattr(u, "model_dump"): |
232 | | - udump = u.model_dump() # type: ignore[assignment] |
233 | | - else: |
234 | | - udump = u |
235 | | - fr = getattr(ch0, "finish_reason", None) |
236 | | - mod = getattr(resp, "model", None) or default_model |
237 | | - c_raw = getattr(msg, "content", None) |
238 | | - log.error( |
239 | | - "AI Gateway: empty assistant message: finish_reason=%r model=%r usage=%r content=%r", |
240 | | - fr, |
241 | | - mod, |
242 | | - udump, |
243 | | - c_raw, |
244 | | - ) |
245 | | - raise RuntimeError( |
246 | | - "AI Gateway returned an empty assistant message. " |
247 | | - f"finish_reason={fr!r}, model={mod!r}, usage={udump!r}. " |
248 | | - "The prompt may exceed the model context, max_tokens may be exhausted, " |
249 | | - "or the model emitted no text — try a smaller input batch or higher limits." |
250 | | - ) from None |
251 | | - |
252 | | - |
# Unconstrained generic type variable.
# NOTE(review): its users are outside the visible range — presumably generic helpers further down.
_T = TypeVar("_T")
254 | 41 |
|
255 | 42 |
|
|
0 commit comments