Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions backend/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ class CollectionContext(BaseModel):

class QueryPrompt(BaseModel):
user_input: str
account_id: str # Added for cross-collection schema fetching
db_context: DBContext
collection_context: CollectionContext | None = (
None # Optional context for specific collection
)
collection_context: list[CollectionContext] = (
[]
) # List of contexts for selected collections
intermediate_context: object | None = (
None # Optional intermediate context for complex queries
)
Expand All @@ -54,3 +55,22 @@ class DebugQueryRequest(BaseModel):

class DebugSuggestionResponse(BaseModel):
suggestion: str


class SchemaRelationshipsRequest(BaseModel):
account_id: str
database_name: str
collection_names: list[str]


class Relationship(BaseModel):
source_collection: str
source_field: str
target_collection: str
target_field: str
description: str
confidence: float # 0.0 to 1.0


class SchemaRelationshipsResponse(BaseModel):
relationships: list[Relationship]
68 changes: 64 additions & 4 deletions backend/routes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,90 @@
ExecuteInput,
DebugQueryRequest,
DebugSuggestionResponse,
SchemaRelationshipsRequest,
SchemaRelationshipsResponse,
)
from services.gemini_service import (
generate_query_from_prompt,
generate_suggestion_from_query_error,
generate_schema_relationships,
)
from services.mongo_service import (
execute_mongo_query,
transform_mongo_result,
get_database_schema_summary,
)
from services.mongo_service import execute_mongo_query, transform_mongo_result
from models.analyze import AnalyzeRequest, AnalyzeResponse
from services.analyze_service import analyze_query_result

router = APIRouter()


@router.post("/nl2query")
def nl2query(prompt: QueryPrompt = Body(...)):
def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid token format")

try:
user_token = authorization.replace("Bearer ", "")
access_token = exchange_token_obo(user_token)
# Use provided contexts if available to avoid re-fetch and ensure consistency
if prompt.collection_context:
summary = []
for ctx in prompt.collection_context:
doc_str = (
str(ctx.sampleDocument)
if ctx.sampleDocument
else "No documents found"
)
summary.append(f"Collection: {ctx.name}\nSample Document: {doc_str}")
schema_summary = "\n\n".join(summary)
# Fallback: fetch schema summary from DB
else:
schema_summary = get_database_schema_summary(
prompt.account_id, prompt.db_context.name, access_token
)
except Exception as e:
print(f"Error fetching schema context: {e}")
schema_summary = "Could not fetch schema summary."

collections = [col.name for col in prompt.db_context.collections]
return generate_query_from_prompt(
prompt.user_input,
collections,
prompt.db_context.name,
prompt.collection_context,
prompt.intermediate_context,
collection_context=None,
intermediate_context=prompt.intermediate_context,
all_collections_schema=schema_summary,
)


@router.post("/infer-relationships", response_model=SchemaRelationshipsResponse)
def infer_relationships(
request: SchemaRelationshipsRequest = Body(...), authorization: str = Header(...)
):
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid token format")

try:
user_token = authorization.replace("Bearer ", "")
access_token = exchange_token_obo(user_token)

# Fetch schema summary ONLY for correct collections
schema_summary = get_database_schema_summary(
request.account_id,
request.database_name,
access_token,
collection_filter=request.collection_names,
)

return generate_schema_relationships(schema_summary)

except Exception as e:
print(f"Error inferring relationships: {e}")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/execute")
def execute(query: ExecuteInput = Body(...), authorization: str = Header(...)):
if not authorization.startswith("Bearer "):
Expand Down
65 changes: 64 additions & 1 deletion backend/services/gemini_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from google import genai
from google.genai import types
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
from models.schemas import (
GeneratedCode,
CollectionContext,
DebugSuggestionResponse,
SchemaRelationshipsResponse,
)
from pydantic import BaseModel, Field
from typing import Optional, List

Expand Down Expand Up @@ -31,6 +36,7 @@ class AuditSummaryResponse(BaseModel):
Database: {database}
Available collections: {collections}
Sample collection document (optional): {collection_context}
Schema summary for ALL collections (for JOINs/lookups): {all_collections_schema}
Intermediate context (optional): {intermediate_context}
Return:
only one line of pure pymongo query code (e.g., db["collection"].find(...))
Expand Down Expand Up @@ -112,6 +118,7 @@ def generate_query_from_prompt(
database: str,
collection_context: CollectionContext = None,
intermediate_context: dict = None,
all_collections_schema: str = "",
) -> GeneratedCode:
# Prune intermediate_context to remove image/large data
safe_intermediate_context = (
Expand All @@ -124,6 +131,7 @@ def generate_query_from_prompt(
collection_context=(
collection_context.sampleDocument if collection_context else ""
),
all_collections_schema=all_collections_schema,
intermediate_context=safe_intermediate_context,
)
client = genai.Client()
Expand Down Expand Up @@ -244,3 +252,58 @@ def summarize_audit_results(
summary="Could not generate summary due to parsing error.",
visualization=VisualizationConfig(available=False),
)


PROMPT_TEMPLATE_RELATIONSHIPS = """
You are a database architect. Analyze the provided MongoDB document samples to identify likely foreign key relationships and JOIN conditions between collections.

Schema/Samples:
{schema_summary}

Tasks:
1. Identify likely relationships (e.g., `userId` in `orders` -> `_id` in `users`).
2. Provide a confidence score (0.0 - 1.0) and a brief description for each.
3. Return a JSON object with a "relationships" key containing a list of these findings.

Output Format (Json):
{{
"relationships": [
{{
"source_collection": "orders",
"source_field": "userId",
"target_collection": "users",
"target_field": "_id",
"description": "Orders belong to Users",
"confidence": 0.95
}}
]
}}
"""


def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
from models.schemas import SchemaRelationshipsResponse

full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
contents=full_prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=SchemaRelationshipsResponse,
thinking_config=types.ThinkingConfig(thinking_budget=0),
),
)

if hasattr(response, "parsed") and response.parsed:
return response.parsed

import json

try:
data = json.loads(response.text)
return SchemaRelationshipsResponse(**data)
except Exception as e:
print(f"Error parsing Gemini relationship response: {e}")
return SchemaRelationshipsResponse(relationships=[])
35 changes: 35 additions & 0 deletions backend/services/mongo_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,38 @@ def transform_mongo_result(result):
elif isinstance(result, DeleteResult):
return {"deleted_count": result.deleted_count}
return result


def get_database_schema_summary(
account_id: str,
database: str,
access_token: str,
collection_filter: list[str] = None,
) -> str:
from services.azure_cosmos_resources import get_connection_string

try:
connection_string = get_connection_string(account_id, access_token)
client = pymongo.MongoClient(connection_string)
db = client[database]
summary = []

# Determine which collections to scan
if collection_filter:
target_collections = collection_filter
else:
target_collections = db.list_collection_names()

for collection_name in target_collections:
# Skip system collections if scanning all (if explicit filter, try to fetch)
if not collection_filter and collection_name.startswith("system."):
continue

doc = db[collection_name].find_one()
doc_str = str(doc) if doc else "No documents found"
summary.append(f"Collection: {collection_name}\nSample Document: {doc_str}")

return "\n\n".join(summary)
except Exception as e:
print(f"Error fetching schema summary: {e}")
return "Could not fetch schema summary."
13 changes: 10 additions & 3 deletions backend/tests/test_query_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
def test_nl2query(client):
"""Test natural language to query conversion."""
# Mock dependencies
with patch("routes.query.generate_query_from_prompt") as mock_generate:
with (
patch("routes.query.generate_query_from_prompt") as mock_generate,
patch("routes.query.exchange_token_obo") as mock_exchange,
):
mock_generate.return_value = {"generated_code": "db.users.find({})"}

# Create test data
Expand All @@ -25,11 +28,15 @@ def test_nl2query(client):

prompt = QueryPrompt(
user_input="Find all users",
account_id="test-account",
db_context=db_context,
collection_context=collection_context,
collection_context=[collection_context],
)

response = client.post("/query/nl2query", json=prompt.model_dump())
headers = {"authorization": "Bearer valid-token"}
response = client.post(
"/query/nl2query", json=prompt.model_dump(), headers=headers
)

assert response.status_code == 200
data = response.json()
Expand Down
11 changes: 7 additions & 4 deletions backend/tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,22 @@ def test_query_prompt():

prompt = QueryPrompt(
user_input="Find all users",
account_id="test-account",
db_context=db_context,
collection_context=collection_context,
collection_context=[collection_context],
intermediate_context={"key": "value"},
)

assert prompt.user_input == "Find all users"
assert prompt.db_context.name == "test-db"
assert prompt.collection_context.name == "users"
assert prompt.collection_context[0].name == "users"
assert prompt.intermediate_context == {"key": "value"}

# Test with minimal required fields
minimal_prompt = QueryPrompt(user_input="Find all users", db_context=db_context)
assert minimal_prompt.collection_context is None
minimal_prompt = QueryPrompt(
user_input="Find all users", account_id="test-account", db_context=db_context
)
assert minimal_prompt.collection_context == []
assert minimal_prompt.intermediate_context is None


Expand Down
Loading