Skip to content

Commit 1114d02

Browse files
author
Vishakh Pillai
committed
feat: grain resolution module + prompt tweaks for manifest refinement
1 parent 4f3c664 commit 1114d02

16 files changed

Lines changed: 1486 additions & 324 deletions

File tree

src/edvise/genai/mapping/identity_agent/grain_inference/databricks_gateway.py

Lines changed: 24 additions & 237 deletions
Original file line numberDiff line numberDiff line change
@@ -3,253 +3,40 @@
33
from __future__ import annotations
44

55
import logging
6-
import os
76
import random
87
import time
98
from collections.abc import Callable
10-
from typing import Final, TypeVar, cast
11-
12-
from openai import OpenAI
13-
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
9+
from typing import TypeVar
1410

1511
from edvise.genai.mapping.identity_agent.grain_inference.schemas import GrainContract
16-
from edvise.genai.mapping.shared.mlflow_gateway_bootstrap import (
17-
disable_mlflow_side_effects_for_openai_gateway,
18-
)
12+
from edvise.genai.mapping.shared import databricks_ai_gateway as _databricks_ai_gateway
1913

20-
# Same default endpoint as ``schema_mapping_agent.manifest.eval`` (MLflow serving / gateway).
21-
DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL: str = (
22-
"https://4437281602191762.ai-gateway.gcp.databricks.com/mlflow/v1"
14+
DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL = (
15+
_databricks_ai_gateway.DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL
16+
)
17+
DEFAULT_GATEWAY_MODEL_ID = _databricks_ai_gateway.DEFAULT_GATEWAY_MODEL_ID
18+
LLM_COMPLETE_SYSTEM_USER_SEP = _databricks_ai_gateway.LLM_COMPLETE_SYSTEM_USER_SEP
19+
DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS = (
20+
_databricks_ai_gateway.DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS
21+
)
22+
llm_complete_combined_message_content = (
23+
_databricks_ai_gateway.llm_complete_combined_message_content
24+
)
25+
disable_mlflow_tracing_for_openai_gateway_client = (
26+
_databricks_ai_gateway.disable_mlflow_tracing_for_openai_gateway_client
27+
)
28+
resolve_ai_gateway_base_url = _databricks_ai_gateway.resolve_ai_gateway_base_url
29+
resolve_gateway_model_id = _databricks_ai_gateway.resolve_gateway_model_id
30+
require_databricks_token = _databricks_ai_gateway.require_databricks_token
31+
create_openai_client_for_databricks_gateway = (
32+
_databricks_ai_gateway.create_openai_client_for_databricks_gateway
33+
)
34+
make_databricks_gateway_llm_complete = (
35+
_databricks_ai_gateway.make_databricks_gateway_llm_complete
2336
)
24-
25-
DEFAULT_GATEWAY_MODEL_ID: str = "claude-sonnet-edvise-genai"
26-
27-
# System + user are concatenated into one role=user message (IA / SMA).
28-
LLM_COMPLETE_SYSTEM_USER_SEP: Final[str] = "\n\n---\n\n"
29-
DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS: Final[int] = 16_000
3037

3138
_LOG = logging.getLogger(__name__)
3239

33-
34-
def llm_complete_combined_message_content(system: str, user: str) -> str:
35-
"""Exact ``content`` string sent to the gateway for ``llm_complete(system, user)``."""
36-
return system + LLM_COMPLETE_SYSTEM_USER_SEP + user
37-
38-
39-
def disable_mlflow_tracing_for_openai_gateway_client() -> None:
40-
"""
41-
Turn off MLflow tracing / OpenAI autolog for gateway calls (see module docstring).
42-
43-
Job scripts should also call :func:`~edvise.genai.mapping.shared.mlflow_gateway_bootstrap.disable_mlflow_side_effects_for_openai_gateway`
44-
at import time **before** loading packages that import ``openai``.
45-
"""
46-
disable_mlflow_side_effects_for_openai_gateway()
47-
48-
49-
def resolve_ai_gateway_base_url() -> str:
50-
"""``AI_GATEWAY_BASE_URL`` env, else :data:`DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL`."""
51-
return os.environ.get(
52-
"AI_GATEWAY_BASE_URL", DEFAULT_DATABRICKS_MLFLOW_AI_GATEWAY_URL
53-
)
54-
55-
56-
def resolve_gateway_model_id() -> str:
57-
"""``GATEWAY_MODEL_ID`` env, else :data:`DEFAULT_GATEWAY_MODEL_ID`."""
58-
return os.environ.get("GATEWAY_MODEL_ID", DEFAULT_GATEWAY_MODEL_ID)
59-
60-
61-
def _token_from_authorization_header(headers: dict[str, str]) -> str | None:
62-
auth = headers.get("Authorization") or headers.get("authorization")
63-
if not auth or not isinstance(auth, str):
64-
return None
65-
parts = auth.split(None, 1)
66-
if len(parts) == 2 and parts[0].lower() == "bearer":
67-
return parts[1].strip()
68-
return None
69-
70-
71-
def _token_from_databricks_sdk_default_auth() -> str | None:
72-
"""
73-
Resolve a short-lived workspace bearer via ``Config().authenticate()`` (Databricks SDK).
74-
75-
Typical sources: job/cluster identity metadata service, OAuth M2M / service principal,
76-
or a local ``databricks auth login`` profile when ``DATABRICKS_HOST`` is set.
77-
"""
78-
try:
79-
from databricks.sdk.core import Config
80-
except ImportError:
81-
_LOG.debug(
82-
"databricks-sdk not installed; cannot resolve runtime workspace token"
83-
)
84-
return None
85-
try:
86-
headers = Config().authenticate()
87-
except Exception as e:
88-
_LOG.debug("Databricks SDK default auth unavailable (%s)", e)
89-
return None
90-
return _token_from_authorization_header(headers)
91-
92-
93-
def require_databricks_token() -> str:
94-
"""
95-
Return a workspace bearer for the gateway ``api_key`` via :func:`_token_from_databricks_sdk_default_auth`.
96-
97-
Personal access tokens (``DATABRICKS_TOKEN``) are not used for this path.
98-
99-
``OPENAI_API_KEY`` is not used for this gateway.
100-
"""
101-
from_sdk = _token_from_databricks_sdk_default_auth()
102-
if from_sdk:
103-
return from_sdk
104-
msg = (
105-
"No Databricks workspace token for the MLflow AI gateway: databricks-sdk "
106-
"Config().authenticate() did not return a Bearer token. Run on Databricks compute "
107-
"with job/cluster identity, configure OAuth / service principal credentials, or "
108-
"use ``databricks auth login`` locally with DATABRICKS_HOST set. "
109-
"OPENAI_API_KEY is not used here."
110-
)
111-
raise ValueError(msg)
112-
113-
114-
def create_openai_client_for_databricks_gateway(
115-
*,
116-
api_key: str | None = None,
117-
base_url: str | None = None,
118-
) -> OpenAI:
119-
"""
120-
Build an :class:`openai.OpenAI` client pointed at the Databricks gateway.
121-
122-
If ``api_key`` is omitted, :func:`require_databricks_token` is used.
123-
If ``base_url`` is omitted, :func:`resolve_ai_gateway_base_url` is used.
124-
"""
125-
disable_mlflow_tracing_for_openai_gateway_client()
126-
key = api_key if api_key is not None else require_databricks_token()
127-
url = base_url if base_url is not None else resolve_ai_gateway_base_url()
128-
return OpenAI(api_key=key, base_url=url)
129-
130-
131-
def make_databricks_gateway_llm_complete(
132-
client: OpenAI,
133-
*,
134-
model: str | None = None,
135-
max_tokens: int = DEFAULT_GATEWAY_COMPLETION_MAX_TOKENS,
136-
) -> Callable[[str, str], str]:
137-
"""
138-
Return ``llm_complete(system, user)`` for :mod:`~edvise.genai.mapping.identity_agent.grain_inference.runner`.
139-
140-
The gateway is called with a single user message: ``system``, a separator, then ``user``
141-
(matches ``ia_dev`` / SMA notebook patterns).
142-
"""
143-
resolved_model = model if model is not None else resolve_gateway_model_id()
144-
145-
def complete(system: str, user: str) -> str:
146-
messages = cast(
147-
list[ChatCompletionMessageParam],
148-
[
149-
{
150-
"role": "user",
151-
"content": llm_complete_combined_message_content(system, user),
152-
}
153-
],
154-
)
155-
resp = client.chat.completions.create(
156-
model=resolved_model,
157-
messages=messages,
158-
max_tokens=max_tokens,
159-
)
160-
return _assistant_text_from_chat_completion_or_raise(
161-
resp, log=_LOG, default_model=resolved_model
162-
)
163-
164-
return complete
165-
166-
167-
def _text_from_message_content(
168-
content: object,
169-
) -> str:
170-
"""
171-
Best-effort string from ``message.content`` (OpenAI is usually ``str | None``;
172-
some routes may return list-shaped multimodal content).
173-
"""
174-
if content is None:
175-
return ""
176-
if isinstance(content, str):
177-
return content
178-
if isinstance(content, list):
179-
parts: list[str] = []
180-
for block in content:
181-
if isinstance(block, dict):
182-
if block.get("type") == "text" and "text" in block:
183-
parts.append(str(block.get("text", "")))
184-
else:
185-
t = block.get("text")
186-
if t is not None:
187-
parts.append(str(t))
188-
else:
189-
tx = getattr(block, "text", None)
190-
if tx is not None:
191-
parts.append(str(tx))
192-
return "".join(parts)
193-
return str(content)
194-
195-
196-
def _assistant_text_from_chat_completion_or_raise(
197-
resp: object, *, log: logging.Logger, default_model: str | None = None
198-
) -> str:
199-
"""
200-
Return the assistant's output text, or raise if there is nothing usable to parse as JSON.
201-
202-
A ``200`` response with ``content=None`` and no text was previously turned into ``""``,
203-
which only surfaces as JSONDecodeError on empty input. We fail fast with diagnostics
204-
and surface refusals (e.g. Claude) explicitly.
205-
"""
206-
choices = getattr(resp, "choices", None) or []
207-
if not choices:
208-
msg = "AI Gateway returned no choices on chat.completions"
209-
log.error("%s: model=%r", msg, getattr(resp, "model", default_model))
210-
raise RuntimeError(msg) from None
211-
212-
ch0 = choices[0]
213-
msg = ch0.message
214-
raw = _text_from_message_content(getattr(msg, "content", None))
215-
if raw.strip():
216-
return raw
217-
218-
ref = getattr(msg, "refusal", None)
219-
if isinstance(ref, str) and ref.strip():
220-
short = ref.strip()[:2000]
221-
log.error(
222-
"AI Gateway: model refusal (not valid JSON for downstream parse): %s", short
223-
)
224-
raise RuntimeError(
225-
"The model refused to return structured output. Refusal: "
226-
+ ref.strip()[:4000]
227-
) from None
228-
229-
u = getattr(resp, "usage", None)
230-
udump: object
231-
if u is not None and hasattr(u, "model_dump"):
232-
udump = u.model_dump() # type: ignore[assignment]
233-
else:
234-
udump = u
235-
fr = getattr(ch0, "finish_reason", None)
236-
mod = getattr(resp, "model", None) or default_model
237-
c_raw = getattr(msg, "content", None)
238-
log.error(
239-
"AI Gateway: empty assistant message: finish_reason=%r model=%r usage=%r content=%r",
240-
fr,
241-
mod,
242-
udump,
243-
c_raw,
244-
)
245-
raise RuntimeError(
246-
"AI Gateway returned an empty assistant message. "
247-
f"finish_reason={fr!r}, model={mod!r}, usage={udump!r}. "
248-
"The prompt may exceed the model context, max_tokens may be exhausted, "
249-
"or the model emitted no text — try a smaller input batch or higher limits."
250-
) from None
251-
252-
25340
_T = TypeVar("_T")
25441

25542

0 commit comments

Comments
 (0)