Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import query, azure, system, user_queries, data_documents
from routes import query, azure, system, user_queries, data_documents, audit

app = FastAPI()

Expand Down Expand Up @@ -55,6 +55,7 @@ async def health_check():
app.include_router(system.router, prefix="/system", tags=["System"])
app.include_router(user_queries.router, prefix="/user", tags=["User Queries"])
app.include_router(data_documents.router, prefix="/data", tags=["Data Documents"])
app.include_router(audit.router, prefix="/audit", tags=["Audit"])

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
33 changes: 33 additions & 0 deletions backend/routes/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from fastapi import APIRouter, Header, Body, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from services.audit_service import process_audit_question
from services.gemini_service import VisualizationConfig

router = APIRouter()


class AuditQueryRequest(BaseModel):
    """Request payload accepted by the audit query endpoint."""

    # Natural-language question to be answered from the audit log.
    question: str


class AuditQueryResponse(BaseModel):
    """Response payload produced by the audit query endpoint."""

    # SQL text that was generated for the question.
    sql_query: str
    # Result rows, each mapping a column name to its value.
    results: List[Dict[str, Any]]
    # Human-readable summary, or an error description when the query failed.
    summary: str
    # Chart configuration; left as None when none is supplied.
    visualization: Optional[VisualizationConfig] = None


@router.post("/query", response_model=AuditQueryResponse)
def query_audit_log(
    body: AuditQueryRequest = Body(...), authorization: str = Header(...)
):
    """Answer a natural-language question about the audit log.

    Args:
        body: Request payload carrying the question.
        authorization: Value of the ``Authorization`` header; it must use
            the ``Bearer `` scheme.

    Returns:
        An ``AuditQueryResponse`` built from the audit service result.

    Raises:
        HTTPException: 401 when the header does not start with ``Bearer ``.
    """
    uses_bearer_scheme = authorization.startswith("Bearer ")
    if not uses_bearer_scheme:
        raise HTTPException(status_code=401, detail="Invalid token format")

    # NOTE(review): only the *shape* of the header is checked here. The token
    # itself is never validated and is not used for the pg connection.
    # TODO: validate it (e.g. via an on-behalf-of token exchange) before
    # serving audit data.
    answer = process_audit_question(body.question)
    return AuditQueryResponse(**answer)
73 changes: 73 additions & 0 deletions backend/services/audit_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Any
from services.pg_connection import get_connection
from services.gemini_service import generate_audit_sql, summarize_audit_results


def execute_audit_query(sql_query: str) -> list:
    """
    Executes a read-only SQL query against the audit database.

    Args:
        sql_query: SQL text to run. It must be a single statement that
            starts with ``SELECT`` (case-insensitive, surrounding whitespace
            ignored). One trailing semicolon is tolerated.

    Returns:
        A list of row dicts keyed by column name. Values exposing
        ``isoformat`` (dates, datetimes) are converted to ISO strings; all
        other values are passed through unchanged. When the query is
        rejected or fails, a single-element list ``[{"error": message}]`` is
        returned instead of raising.
    """
    normalized = sql_query.lower().strip()
    if not normalized.startswith("select"):
        return [{"error": "Only SELECT queries are allowed."}]

    # A prefix check alone lets stacked statements through
    # ("SELECT 1; DROP TABLE ..."), which some drivers execute in one call.
    # Reject any semicolon other than trailing ones. This is deliberately
    # conservative: a semicolon inside a string literal is rejected as well.
    # NOTE(review): this is still not a sandbox (e.g. expensive functions in
    # a SELECT); the connection should also use a read-only role/timeout.
    if ";" in normalized.rstrip(";"):
        return [{"error": "Only a single SELECT statement is allowed."}]

    conn = None
    cur = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(sql_query)

        # Column names come first in each cursor description entry.
        columns = [desc[0] for desc in cur.description]
        return [
            {
                column: value.isoformat() if hasattr(value, "isoformat") else value
                for column, value in zip(columns, row)
            }
            for row in cur.fetchall()
        ]
    except Exception as e:
        # Best-effort boundary: callers expect an error row, not an exception.
        print(f"Error executing audit query: {e}")
        return [{"error": str(e)}]
    finally:
        # Always release the cursor and connection, including on failure.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_error:
                print(f"Error closing audit query resource: {close_error}")


def process_audit_question(question: str) -> Dict[str, Any]:
    """
    Answer a user's audit question end to end.

    The question is converted to SQL (via Gemini), the SQL is executed, and
    the returned rows are summarized (via Gemini).

    Args:
        question: Natural-language question about the audit log.

    Returns:
        A dict with ``sql_query``, ``results`` and ``summary``. The
        ``visualization`` key is present only when the query ran and was
        summarized; the two failure paths return empty ``results`` and an
        explanatory ``summary``.
    """
    generated_sql = generate_audit_sql(question)

    # The generator reports failure by embedding "Error:" in the SQL text.
    generation_failed = "Error:" in generated_sql
    if generation_failed:
        return {
            "sql_query": generated_sql,
            "results": [],
            "summary": "Could not generate a valid query for your request.",
        }

    rows = execute_audit_query(generated_sql)

    # Execution failures come back as a leading row carrying an "error" key.
    if rows and "error" in rows[0]:
        failure_reason = rows[0]["error"]
        return {
            "sql_query": generated_sql,
            "results": [],
            "summary": f"Error executing query: {failure_reason}",
        }

    analysis = summarize_audit_results(question, generated_sql, rows)
    return {
        "sql_query": generated_sql,
        "results": rows,
        "summary": analysis.summary,
        "visualization": analysis.visualization,
    }
112 changes: 112 additions & 0 deletions backend/services/gemini_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
from google import genai
from google.genai import types
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
from pydantic import BaseModel, Field
from typing import Optional, List


class VisualizationConfig(BaseModel):
    """Chart recommendation attached to an audit summary.

    Only ``available`` is mandatory. Every other field defaults to ``None``
    so a "no chart" configuration can be built as
    ``VisualizationConfig(available=False)``. Without explicit defaults,
    pydantic v2 treats ``Optional[...]`` fields as required-but-nullable and
    that construction fails validation.
    """

    available: bool = Field(description="Whether a chart is recommended for this data")
    type: Optional[str] = Field(
        default=None, description="Type of chart: 'bar', 'line', 'pie', 'scatter'"
    )
    x_key: Optional[str] = Field(default=None, description="Key for X-axis data")
    y_key: Optional[str] = Field(default=None, description="Key for Y-axis data")
    title: Optional[str] = Field(default=None, description="Title for the chart")
    data_keys: Optional[List[str]] = Field(
        default=None,
        description="Keys to include in the chart data points (e.g. ['count', 'date'])",
    )


class AuditSummaryResponse(BaseModel):
    """Structured answer requested from Gemini when summarizing audit rows."""

    # Markdown text describing the results.
    summary: str = Field(description="Markdown summary of the results")
    # Chart recommendation accompanying the summary.
    visualization: VisualizationConfig = Field(
        description="Configuration for data visualization"
    )


PROMPT_TEMPLATE_QUERY = """
You are an assistant that converts user requests into MongoDB query code.
Expand Down Expand Up @@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
response.text.strip() if hasattr(response, "text") else str(response).strip
)
return DebugSuggestionResponse(suggestion=suggestion)


# Prompt used by generate_audit_sql to turn a natural-language question into
# a read-only SELECT over the `write_audit_log` table.
# The `{user_input}` placeholder is filled with str.format, so any literal
# brace added to this template must be doubled ("{{" / "}}").
PROMPT_TEMPLATE_AUDIT_SQL = """
You are a PostgreSQL expert. Convert the user's natural language question into a read-only SQL query for the `write_audit_log` table.
Table Schema:
- user_email (text): Email of the user who performed the operation.
- operation (text): 'insert', 'update', or 'delete'.
- database_name (text): Name of the database (format: account.database).
- collection_name (text): Name of the collection.
- document_id (text): ID of the affected document.
- diff_data (jsonb): JSON containing the changes (for updates, it has 'before' and 'after' fields).
- timestamp_utc (timestamptz): When the operation occurred.

User Question: "{user_input}"

Rules:
1. Return ONLY the SQL query. No markdown, no explanations.
2. The query MUST be a SELECT statement.
3. Use LIMIT 100 if no limit is specified.
4. If the user asks for "recent", order by timestamp_utc DESC.
"""

# Prompt used by summarize_audit_results to request a markdown summary and
# an optional chart configuration for a query's result rows.
# Placeholders `{user_input}`, `{sql_query}` and `{results}` are filled with
# str.format, so any literal brace added to this template must be doubled.
PROMPT_TEMPLATE_AUDIT_SUMMARY = """
You are a data analyst. Analyze the following SQL query and its results.

User Question: "{user_input}"
SQL Query: "{sql_query}"
Results:
{results}

Tasks:
1. Provide a concise markdown summary identifying patterns or answering the specific question.
2. Determine if the data is suitable for visualization (e.g., time series, counts, comparisons).
3. If suitable, structure a visualization configuration (type, keys, title).
- For time series, prefer 'line' or 'bar'.
- For categorical counts, use 'bar' or 'pie'.
"""


def generate_audit_sql(user_input: str) -> str:
    """Generate a SELECT statement for the audit log from a user question.

    Args:
        user_input: Natural-language question about the audit log.

    Returns:
        The generated SQL with surrounding whitespace removed. When the
        model output is not a SELECT statement, a sentinel query whose text
        contains ``Error:`` is returned instead. The audit service detects
        failure by looking for that marker, so this case does not raise.
    """
    full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
    client = genai.Client()
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=full_prompt,
        config=types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0)
        ),
    )
    extracted = extract_python_code(response.text)
    # Guard against a non-string result and strip surrounding whitespace so
    # a leading newline does not make a valid SELECT look invalid. This
    # matches the whitespace-tolerant check done before execution.
    sql = extracted.strip() if isinstance(extracted, str) else ""

    # Basic safety check
    if not sql.lower().startswith("select"):
        return "SELECT 'Error: Generated query was not a SELECT statement' as error;"
    return sql


def summarize_audit_results(
    user_input: str, sql_query: str, results: list
) -> AuditSummaryResponse:
    """Ask Gemini for a summary and chart recommendation for query results.

    Args:
        user_input: The user's original question.
        sql_query: SQL that produced ``results``.
        results: Result rows; their string form is truncated to 10,000
            characters before being placed in the prompt.

    Returns:
        The response's ``parsed`` value when it is present and truthy.
        Otherwise an ``AuditSummaryResponse`` built from the response's JSON
        text. If that text cannot be parsed or validated, a placeholder
        summary with ``visualization.available`` set to False is returned.
    """
    import json

    # Truncate results if too large to avoid token limits
    results_str = str(results)[:10000]
    full_prompt = PROMPT_TEMPLATE_AUDIT_SUMMARY.format(
        user_input=user_input, sql_query=sql_query, results=results_str
    )
    client = genai.Client()
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=full_prompt,
        config=types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=AuditSummaryResponse,
            thinking_config=types.ThinkingConfig(thinking_budget=0),
        ),
    )

    parsed = getattr(response, "parsed", None)
    if parsed:
        return parsed

    try:
        return AuditSummaryResponse(**json.loads(response.text))
    except Exception as e:
        # Best-effort boundary: a malformed model reply must not fail the
        # whole request, so fall back to a placeholder summary.
        print(f"Error parsing Gemini response: {e}")
        return AuditSummaryResponse(
            summary="Could not generate summary due to parsing error.",
            # Every field is passed explicitly: the Optional fields declare
            # no default, so they are required under pydantic v2 and
            # omitting them would make this fallback raise.
            visualization=VisualizationConfig(
                available=False,
                type=None,
                x_key=None,
                y_key=None,
                title=None,
                data_keys=None,
            ),
        )
56 changes: 56 additions & 0 deletions backend/tests/test_audit_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
from unittest.mock import patch, MagicMock
from services.audit_service import process_audit_question


class TestAuditService(unittest.TestCase):
    """Unit tests for ``process_audit_question`` with Gemini and the DB mocked."""

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    @patch("services.audit_service.summarize_audit_results")
    def test_process_audit_question_success(
        self, mock_summarize, mock_get_conn, mock_generate_sql
    ):
        """A valid SELECT is executed and its rows and summary are returned."""
        # Setup mocks
        mock_generate_sql.return_value = "SELECT * FROM write_audit_log LIMIT 5"

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_get_conn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor

        # Mock DB results
        mock_cursor.description = [("user_email",), ("operation",)]
        mock_cursor.fetchall.return_value = [("user@example.com", "insert")]

        # Mock Summary Response
        mock_response = MagicMock()
        mock_response.summary = "Summary of results."
        mock_response.visualization = None
        mock_summarize.return_value = mock_response

        # Execute
        result = process_audit_question("Show me inserts")

        # Assertions
        self.assertEqual(result["sql_query"], "SELECT * FROM write_audit_log LIMIT 5")
        self.assertEqual(len(result["results"]), 1)
        self.assertEqual(result["results"][0]["user_email"], "user@example.com")
        self.assertEqual(result["results"][0]["operation"], "insert")
        self.assertEqual(result["summary"], "Summary of results.")
        self.assertIsNone(result["visualization"])

        mock_generate_sql.assert_called_once()
        mock_cursor.execute.assert_called_with("SELECT * FROM write_audit_log LIMIT 5")
        mock_summarize.assert_called_once()

        # The cursor and connection must be released after the query.
        mock_cursor.close.assert_called_once()
        mock_conn.close.assert_called_once()

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    def test_process_audit_question_invalid_sql(
        self, mock_get_conn, mock_generate_sql
    ):
        """A non-SELECT statement is rejected before the database is touched."""
        mock_generate_sql.return_value = "DELETE FROM write_audit_log"

        result = process_audit_question("Delete everything")

        self.assertIn("Error executing query", result["summary"])
        self.assertIn("Only SELECT", result["summary"])
        self.assertEqual(result["results"], [])
        mock_get_conn.assert_not_called()


if __name__ == "__main__":
unittest.main()
Loading