Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import query, azure, system, user_queries, data_documents
from routes import query, azure, system, user_queries, data_documents, audit

app = FastAPI()

Expand Down Expand Up @@ -55,6 +55,7 @@ async def health_check():
app.include_router(system.router, prefix="/system", tags=["System"])
app.include_router(user_queries.router, prefix="/user", tags=["User Queries"])
app.include_router(data_documents.router, prefix="/data", tags=["Data Documents"])
app.include_router(audit.router, prefix="/audit", tags=["Audit"])

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
33 changes: 33 additions & 0 deletions backend/routes/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from fastapi import APIRouter, Header, Body, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from services.audit_service import process_audit_question
from services.gemini_service import VisualizationConfig

router = APIRouter()


class AuditQueryRequest(BaseModel):
    # Request body for the POST "/query" audit endpoint.
    # Natural-language question; the route handler forwards it unchanged to
    # process_audit_question.
    question: str


class AuditQueryResponse(BaseModel):
    # Response body for the POST "/query" audit endpoint; the route handler
    # builds it by unpacking the dict returned by process_audit_question.
    sql_query: str  # SQL text generated for the question
    results: List[Dict[str, Any]]  # one dict per returned row
    summary: str  # summary text, or an error message when generation/execution failed
    # Chart configuration; stays None when the service dict has no "visualization" key.
    visualization: Optional[VisualizationConfig] = None


@router.post("/query", response_model=AuditQueryResponse)
def query_audit_log(
body: AuditQueryRequest = Body(...), authorization: str = Header(...)
):
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid token format")

# We might want to validate the token here even if we don't use it for the pg connection directly yet
# user_token = authorization.replace("Bearer ", "")
Comment on lines +28 to +29
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The token validation only checks if the authorization header starts with "Bearer " but doesn't actually validate the token itself. The commented code suggests token exchange should happen, but it's not implemented. This leaves the endpoint vulnerable to unauthorized access with any malformed bearer token. Either implement proper token validation or add a TODO comment explaining why it's deferred.

Suggested change
# We might want to validate the token here even if we don't use it for the pg connection directly yet
# user_token = authorization.replace("Bearer ", "")
user_token = authorization[len("Bearer ") :].strip()
if not user_token:
raise HTTPException(status_code=401, detail="Missing bearer token")
# TODO: Implement full token validation or token exchange (e.g. via exchange_token_obo)
# before trusting this token for authorization. For example:

Copilot uses AI. Check for mistakes.
# access_token = exchange_token_obo(user_token)

response = process_audit_question(body.question)
return AuditQueryResponse(**response)
73 changes: 73 additions & 0 deletions backend/services/audit_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Any
from services.pg_connection import get_connection
from services.gemini_service import generate_audit_sql, summarize_audit_results


def execute_audit_query(sql_query: str) -> list:
"""
Executes a read-only SQL query against the audit database.
"""
if not sql_query.lower().strip().startswith("select"):
return [{"error": "Only SELECT queries are allowed."}]

try:
conn = get_connection()
cur = conn.cursor()
cur.execute(sql_query)

# Get column names
columns = [desc[0] for desc in cur.description]
results = [dict(zip(columns, row)) for row in cur.fetchall()]

# Serialize datetime and json objects
for row in results:
for key, value in row.items():
if hasattr(value, "isoformat"):
row[key] = value.isoformat()
elif isinstance(value, dict):
# Ensure dicts (like diff_data) are kept as dicts for the frontend
pass

cur.close()
conn.close()
return results
except Exception as e:
print(f"Error executing audit query: {e}")
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error handling uses print statements instead of proper logging. This is not suitable for production environments where logs need to be structured, searchable, and routable to appropriate monitoring systems. Replace print statements with proper logging using Python's logging module.

Copilot uses AI. Check for mistakes.
return [{"error": str(e)}]
Comment on lines +13 to +36
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The database connection is not properly closed in the error case. If an exception occurs after conn.cursor() but before cur.close() or conn.close(), the connection and cursor will remain open, potentially causing resource leaks. Consider using a try-finally block or context managers to ensure proper cleanup.

Copilot uses AI. Check for mistakes.


def process_audit_question(question: str) -> Dict[str, Any]:
    """
    Answer a natural-language audit question end to end.

    The pipeline has three stages:
    1. Ask Gemini to translate the question into SQL.
    2. Run that SQL through execute_audit_query.
    3. Ask Gemini to summarize the returned rows.

    Returns a dict with "sql_query", "results" and "summary". On the success
    path it also carries "visualization"; on either failure path "results" is
    an empty list and "summary" holds the error text.
    """
    generated_sql = generate_audit_sql(question)

    def _failure(message: str) -> Dict[str, Any]:
        # Shared shape for both early-exit paths (no "visualization" key).
        return {
            "sql_query": generated_sql,
            "results": [],
            "summary": message,
        }

    # The generator signals failure by embedding "Error:" in the SQL it returns.
    if "Error:" in generated_sql:
        return _failure("Could not generate a valid query for your request.")

    rows = execute_audit_query(generated_sql)

    # execute_audit_query reports failure as a single row carrying an "error" key.
    if rows and "error" in rows[0]:
        return _failure(f"Error executing query: {rows[0]['error']}")

    digest = summarize_audit_results(question, generated_sql, rows)

    return {
        "sql_query": generated_sql,
        "results": rows,
        "summary": digest.summary,
        "visualization": digest.visualization,
    }
112 changes: 112 additions & 0 deletions backend/services/gemini_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
from google import genai
from google.genai import types
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
from pydantic import BaseModel, Field
from typing import Optional, List


class VisualizationConfig(BaseModel):
    # Chart recommendation that accompanies an audit summary.
    #
    # Every field except `available` defaults to None. Under pydantic v2 an
    # Optional[...] annotation without a default is still a required field, so
    # without these defaults `VisualizationConfig(available=False)` (the
    # "no chart" case) would fail validation.
    available: bool = Field(description="Whether a chart is recommended for this data")
    type: Optional[str] = Field(
        default=None,
        description="Type of chart: 'bar', 'line', 'pie', 'scatter'",
    )
    x_key: Optional[str] = Field(default=None, description="Key for X-axis data")
    y_key: Optional[str] = Field(default=None, description="Key for Y-axis data")
    title: Optional[str] = Field(default=None, description="Title for the chart")
    data_keys: Optional[List[str]] = Field(
        default=None,
        description="Keys to include in the chart data points (e.g. ['count', 'date'])",
    )


class AuditSummaryResponse(BaseModel):
    # Structured result of summarize_audit_results; the same class is passed to
    # Gemini as the response_schema for the summary request.
    summary: str = Field(description="Markdown summary of the results")
    # Required (no default): callers must always supply a chart configuration.
    visualization: VisualizationConfig = Field(
        description="Configuration for data visualization"
    )


PROMPT_TEMPLATE_QUERY = """
You are an assistant that converts user requests into MongoDB query code.
Expand Down Expand Up @@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
response.text.strip() if hasattr(response, "text") else str(response).strip
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fallback branch of this conditional expression reads `str(response).strip` without the call parentheses. When `response` has no `text` attribute, `suggestion` is therefore bound to the `strip` method object itself rather than to a stripped string, and that method object is what gets passed to `DebugSuggestionResponse(suggestion=suggestion)`. The fallback branch should be `str(response).strip()`.

Suggested change
response.text.strip() if hasattr(response, "text") else str(response).strip
response.text.strip() if hasattr(response, "text") else str(response).strip()

Copilot uses AI. Check for mistakes.
)
return DebugSuggestionResponse(suggestion=suggestion)


PROMPT_TEMPLATE_AUDIT_SQL = """
You are a PostgreSQL expert. Convert the user's natural language question into a read-only SQL query for the `write_audit_log` table.
Table Schema:
- user_email (text): Email of the user who performed the operation.
- operation (text): 'insert', 'update', or 'delete'.
- database_name (text): Name of the database (format: account.database).
- collection_name (text): Name of the collection.
- document_id (text): ID of the affected document.
- diff_data (jsonb): JSON containing the changes (for updates, it has 'before' and 'after' fields).
- timestamp_utc (timestamptz): When the operation occurred.

User Question: "{user_input}"

Rules:
1. Return ONLY the SQL query. No markdown, no explanations.
2. The query MUST be a SELECT statement.
3. Use LIMIT 100 if no limit is specified.
4. If the user asks for "recent", order by timestamp_utc DESC.
"""

PROMPT_TEMPLATE_AUDIT_SUMMARY = """
You are a data analyst. Analyze the following SQL query and its results.

User Question: "{user_input}"
SQL Query: "{sql_query}"
Results:
{results}

Tasks:
1. Provide a concise markdown summary identifying patterns or answering the specific question.
2. Determine if the data is suitable for visualization (e.g., time series, counts, comparisons).
3. If suitable, structure a visualization configuration (type, keys, title).
- For time series, prefer 'line' or 'bar'.
- For categorical counts, use 'bar' or 'pie'.
"""


def generate_audit_sql(user_input: str) -> str:
    """
    Ask Gemini to turn a natural-language question into a single SELECT query.

    Args:
        user_input: The user's question; it is interpolated into
            PROMPT_TEMPLATE_AUDIT_SQL.

    Returns:
        The generated SQL with surrounding whitespace removed. If the model
        output fails a basic safety check, a placeholder query whose text
        contains "Error:" is returned instead, so callers can detect the
        failure without running anything against the audit table.
    """
    full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
    client = genai.Client()
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=full_prompt,
        config=types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0)
        ),
    )
    # `response.text` may be None when the model returns no text part; fall
    # back to an empty string so the checks below reject it cleanly.
    # NOTE(review): assumes extract_python_code accepts an empty string - confirm.
    sql = extract_python_code(response.text or "").strip()

    # Basic safety check. Stripping first keeps leading whitespace or newlines
    # from turning a valid SELECT into a false rejection.
    if not sql.lower().startswith("select"):
        return "SELECT 'Error: Generated query was not a SELECT statement' as error;"

    # A prefix check alone lets stacked statements through, e.g.
    # "SELECT 1; DROP TABLE write_audit_log". Allow only a trailing terminator.
    # This fails closed: a query with ';' inside a string literal is rejected too.
    if ";" in sql.rstrip("; \t\r\n"):
        return "SELECT 'Error: Generated query contained multiple statements' as error;"

    return sql
Comment on lines +208 to +211
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

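The SQL injection protection relies solely on checking that the query starts with "select". This is insufficient, because a statement that begins with SELECT can still be harmful: for example, a second statement stacked after a semicolon (`SELECT 1; DROP TABLE ...`, where the driver allows it), `SELECT ... INTO` creating a new table, calls to side-effecting functions, or reads from tables other than `write_audit_log`. Since the SQL text itself is produced by the model, parameterized queries alone do not address this; consider validating the query structure with a SQL parser and executing it in a read-only transaction or under a database role that only has SELECT on the audit table.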

Copilot uses AI. Check for mistakes.


def summarize_audit_results(
user_input: str, sql_query: str, results: list
) -> AuditSummaryResponse:
# Truncate results if too large to avoid token limits
results_str = str(results)[:10000]
Comment on lines +217 to +218
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The results are truncated to 10000 characters which may cut off in the middle of a data structure, potentially causing JSON parsing issues or incomplete data representation. Consider implementing a smarter truncation strategy that respects data boundaries (e.g., limiting the number of result rows rather than string length) or using a sampling approach.

Copilot uses AI. Check for mistakes.
full_prompt = PROMPT_TEMPLATE_AUDIT_SUMMARY.format(
user_input=user_input, sql_query=sql_query, results=results_str
)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
contents=full_prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=AuditSummaryResponse,
thinking_config=types.ThinkingConfig(thinking_budget=0),
),
)

if hasattr(response, "parsed") and response.parsed:
return response.parsed

import json

try:
data = json.loads(response.text)
return AuditSummaryResponse(**data)
except Exception as e:
print(f"Error parsing Gemini response: {e}")
return AuditSummaryResponse(
summary="Could not generate summary due to parsing error.",
visualization=VisualizationConfig(available=False),
)
Comment on lines +197 to +246
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test coverage for the newly added audit functionality. While there is a test file for audit_service.py, there are no tests for the new functions in gemini_service.py (generate_audit_sql, summarize_audit_results) or the audit route. Consider adding comprehensive tests for these components to ensure reliability.

Copilot uses AI. Check for mistakes.
Comment on lines +242 to +246
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error handling prints to stdout but returns a generic fallback response. This makes debugging difficult in production. Consider using proper logging (with appropriate log levels) instead of print statements, and include more context about the error in the returned response or raise an appropriate exception.

Copilot uses AI. Check for mistakes.
56 changes: 56 additions & 0 deletions backend/tests/test_audit_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
from unittest.mock import patch, MagicMock
from services.audit_service import process_audit_question


class TestAuditService(unittest.TestCase):
    """Tests for process_audit_question with Gemini and the database mocked out."""

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    @patch("services.audit_service.summarize_audit_results")
    def test_process_audit_question_success(
        self, mock_summarize, mock_get_conn, mock_generate_sql
    ):
        """A valid SELECT is executed, serialized into row dicts and summarized."""
        question = "Show me inserts"
        sql = "SELECT * FROM write_audit_log LIMIT 5"
        # Reserved example address; the same value is used for the mocked row
        # and for the assertion below.
        email = "user@example.com"

        # Setup mocks (decorators apply bottom-up, hence the argument order).
        mock_generate_sql.return_value = sql

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_get_conn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor

        # Mock DB results: two columns, one row.
        mock_cursor.description = [("user_email",), ("operation",)]
        mock_cursor.fetchall.return_value = [(email, "insert")]

        # Mock summary response.
        mock_response = MagicMock()
        mock_response.summary = "Summary of results."
        mock_response.visualization = None
        mock_summarize.return_value = mock_response

        # Execute
        result = process_audit_question(question)

        # Returned payload
        self.assertEqual(result["sql_query"], sql)
        self.assertEqual(len(result["results"]), 1)
        self.assertEqual(result["results"][0]["user_email"], email)
        self.assertEqual(result["results"][0]["operation"], "insert")
        self.assertEqual(result["summary"], "Summary of results.")
        self.assertIsNone(result["visualization"])

        # Collaborators are each called exactly once with the expected arguments.
        mock_generate_sql.assert_called_once_with(question)
        mock_cursor.execute.assert_called_once_with(sql)
        mock_summarize.assert_called_once_with(question, sql, result["results"])

        # The cursor and connection must be released after a successful query.
        mock_cursor.close.assert_called_once()
        mock_conn.close.assert_called_once()

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    @patch("services.audit_service.summarize_audit_results")
    def test_process_audit_question_invalid_sql(
        self, mock_summarize, mock_get_conn, mock_generate_sql
    ):
        """A non-SELECT statement is rejected before any DB or summary call."""
        mock_generate_sql.return_value = "DELETE FROM write_audit_log"

        result = process_audit_question("Delete everything")

        self.assertIn("Error executing query", result["summary"])
        self.assertIn("Only SELECT", result["summary"])
        self.assertEqual(result["results"], [])

        # Patching these keeps a regression in the guard from reaching a real
        # database or the Gemini API, and proves neither was touched.
        mock_get_conn.assert_not_called()
        mock_summarize.assert_not_called()


if __name__ == "__main__":
unittest.main()
Loading
Loading