Release #19
Changes from all commits: a8f13ea, d57c968, 495a85f, 254e838
New file (audit API route), @@ -0,0 +1,33 @@:

```python
from fastapi import APIRouter, Header, Body, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from services.audit_service import process_audit_question
from services.gemini_service import VisualizationConfig

router = APIRouter()


class AuditQueryRequest(BaseModel):
    question: str


class AuditQueryResponse(BaseModel):
    sql_query: str
    results: List[Dict[str, Any]]
    summary: str
    visualization: Optional[VisualizationConfig] = None


@router.post("/query", response_model=AuditQueryResponse)
def query_audit_log(
    body: AuditQueryRequest = Body(...), authorization: str = Header(...)
):
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid token format")

    # We might want to validate the token here even if we don't use it for the pg connection directly yet
    # user_token = authorization.replace("Bearer ", "")
    # access_token = exchange_token_obo(user_token)

    response = process_audit_question(body.question)
    return AuditQueryResponse(**response)
```
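For reference, a call to this endpoint could look like the sketch below; the host, port, and absence of a router prefix are assumptions, since the app wiring is not part of this diff.

```python
# Hypothetical client call; host, port, and mount path are assumptions.
# Only the /query route, the Bearer header, and the request body shape
# come from the route above.
import requests

resp = requests.post(
    "http://localhost:8000/query",
    headers={"Authorization": "Bearer <token>"},
    json={"question": "How many documents were updated last week?"},
)
print(resp.json()["sql_query"], resp.json()["summary"])
```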
New file (services/audit_service.py), @@ -0,0 +1,73 @@:

```python
from typing import Dict, Any
from services.pg_connection import get_connection
from services.gemini_service import generate_audit_sql, summarize_audit_results


def execute_audit_query(sql_query: str) -> list:
    """
    Executes a read-only SQL query against the audit database.
    """
    if not sql_query.lower().strip().startswith("select"):
        return [{"error": "Only SELECT queries are allowed."}]

    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(sql_query)

        # Get column names
        columns = [desc[0] for desc in cur.description]
        results = [dict(zip(columns, row)) for row in cur.fetchall()]

        # Serialize datetime and json objects
        for row in results:
            for key, value in row.items():
                if hasattr(value, "isoformat"):
                    row[key] = value.isoformat()
                elif isinstance(value, dict):
                    # Ensure dicts (like diff_data) are kept as dicts for the frontend
                    pass

        cur.close()
        conn.close()
        return results
    except Exception as e:
        print(f"Error executing audit query: {e}")
        return [{"error": str(e)}]


def process_audit_question(question: str) -> Dict[str, Any]:
    """
    Orchestrates the process of answering a user's audit question:
    1. Generate SQL from NL question (via Gemini)
    2. Execute SQL
    3. Summarize results (via Gemini)
    """
    sql_query = generate_audit_sql(question)

    # If the generator returned an error query or invalid SQL, return it
    if "Error:" in sql_query:
        return {
            "sql_query": sql_query,
            "results": [],
            "summary": "Could not generate a valid query for your request.",
        }

    results = execute_audit_query(sql_query)

    # If execution failed
    if results and "error" in results[0]:
        return {
            "sql_query": sql_query,
            "results": [],
            "summary": f"Error executing query: {results[0]['error']}",
        }

    summary_response = summarize_audit_results(question, sql_query, results)

    return {
        "sql_query": sql_query,
        "results": results,
        "summary": summary_response.summary,
        "visualization": summary_response.visualization,
    }
```

(A review comment is anchored on lines +13 to +36 of this file, the body of execute_audit_query; see the Copilot comments below.)
Modified file (services/gemini_service.py), @@ -1,6 +1,29 @@:

```python
from google import genai
from google.genai import types
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
from pydantic import BaseModel, Field
from typing import Optional, List


class VisualizationConfig(BaseModel):
    available: bool = Field(description="Whether a chart is recommended for this data")
    type: Optional[str] = Field(
        description="Type of chart: 'bar', 'line', 'pie', 'scatter'"
    )
    x_key: Optional[str] = Field(description="Key for X-axis data")
    y_key: Optional[str] = Field(description="Key for Y-axis data")
    title: Optional[str] = Field(description="Title for the chart")
    data_keys: Optional[List[str]] = Field(
        description="Keys to include in the chart data points (e.g. ['count', 'date'])"
    )


class AuditSummaryResponse(BaseModel):
    summary: str = Field(description="Markdown summary of the results")
    visualization: VisualizationConfig = Field(
        description="Configuration for data visualization"
    )


PROMPT_TEMPLATE_QUERY = """
You are an assistant that converts user requests into MongoDB query code.
```

(The hunk ends inside the pre-existing PROMPT_TEMPLATE_QUERY string.)
@@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str: | |||||
| response.text.strip() if hasattr(response, "text") else str(response).strip | ||||||
|
||||||
| response.text.strip() if hasattr(response, "text") else str(response).strip | |
| response.text.strip() if hasattr(response, "text") else str(response).strip() |
Copilot AI (Jan 12, 2026): The SQL injection protection relies solely on checking if the query starts with "select". This is insufficient as it doesn't prevent malicious SELECT statements (e.g., SELECT queries that use subqueries to modify data or access unauthorized information). Consider using parameterized queries or a SQL parser to validate the query structure more thoroughly.
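A sketch of a stricter guard along those lines is below. It assumes psycopg2-style connections from get_connection and the sqlparse package as a new dependency; neither assumption is confirmed by this PR.

```python
# Hedged sketch: sqlparse is an assumed new dependency, and set_session(readonly=True)
# assumes get_connection() returns a psycopg2 connection.
import sqlparse
from services.pg_connection import get_connection


def is_single_select(sql_query: str) -> bool:
    statements = [s for s in sqlparse.parse(sql_query) if str(s).strip()]
    # Reject stacked statements ("SELECT ...; DELETE ...") and anything
    # whose parsed statement type is not SELECT.
    return len(statements) == 1 and statements[0].get_type() == "SELECT"


def execute_audit_query_readonly(sql_query: str) -> list:
    if not is_single_select(sql_query):
        return [{"error": "Only single SELECT statements are allowed."}]
    conn = get_connection()
    # Defence in depth: in a read-only transaction, data-modifying CTEs or
    # function calls fail at the database level as well.
    conn.set_session(readonly=True)
    with conn.cursor() as cur:
        cur.execute(sql_query)
        columns = [desc[0] for desc in cur.description]
        rows = [dict(zip(columns, row)) for row in cur.fetchall()]
    conn.close()
    return rows
```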
Copilot AI (Jan 12, 2026): The results are truncated to 10000 characters, which may cut off in the middle of a data structure, potentially causing JSON parsing issues or incomplete data representation. Consider implementing a smarter truncation strategy that respects data boundaries (e.g., limiting the number of result rows rather than string length) or using a sampling approach.
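One way to respect data boundaries is to cap the number of rows serialized into the prompt rather than slicing the JSON string. A rough sketch follows; the constant and helper name are illustrative, and how summarize_audit_results actually builds its prompt is not shown in this excerpt.

```python
# Hedged sketch: row-based truncation instead of cutting the serialized
# string at 10000 characters. MAX_ROWS_FOR_SUMMARY and this helper are illustrative.
import json

MAX_ROWS_FOR_SUMMARY = 50


def build_results_snippet(results: list) -> str:
    sample = results[:MAX_ROWS_FOR_SUMMARY]
    snippet = json.dumps(sample, default=str)
    omitted = len(results) - len(sample)
    if omitted > 0:
        # Keep the JSON intact and state how much was left out.
        snippet += f"\n({omitted} additional rows omitted from the prompt)"
    return snippet
```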
Copilot AI (Jan 12, 2026): Missing test coverage for the newly added audit functionality. While there is a test file for audit_service.py, there are no tests for the new functions in gemini_service.py (generate_audit_sql, summarize_audit_results) or the audit route. Consider adding comprehensive tests for these components to ensure reliability.
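A route-level test could look like the sketch below. It assumes the router lives in routes/audit.py, is mounted without a prefix, and that the FastAPI app is importable from main; none of these paths are confirmed by the diff.

```python
# Hedged sketch of a test for the audit route; module paths and the mount
# prefix are assumptions, only the handler behaviour comes from the diff.
import unittest
from unittest.mock import patch
from fastapi.testclient import TestClient
from main import app


class TestAuditRoute(unittest.TestCase):
    def setUp(self):
        self.client = TestClient(app)

    def test_rejects_non_bearer_token(self):
        response = self.client.post(
            "/query",
            json={"question": "Show me inserts"},
            headers={"Authorization": "Basic abc"},
        )
        self.assertEqual(response.status_code, 401)

    @patch("routes.audit.process_audit_question")
    def test_returns_summary(self, mock_process):
        mock_process.return_value = {
            "sql_query": "SELECT 1",
            "results": [],
            "summary": "ok",
            "visualization": None,
        }
        response = self.client.post(
            "/query",
            json={"question": "Show me inserts"},
            headers={"Authorization": "Bearer token"},
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["summary"], "ok")
```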
Copilot AI (Jan 12, 2026): The error handling prints to stdout but returns a generic fallback response. This makes debugging difficult in production. Consider using proper logging (with appropriate log levels) instead of print statements, and include more context about the error in the returned response or raise an appropriate exception.
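A sketch of that change applied to execute_audit_query; the logger naming is illustrative, and the datetime-serialization loop from the original function is omitted for brevity.

```python
# Hedged sketch of replacing print() with the logging module in audit_service.py.
import logging

from services.pg_connection import get_connection

logger = logging.getLogger(__name__)


def execute_audit_query(sql_query: str) -> list:
    if not sql_query.lower().strip().startswith("select"):
        return [{"error": "Only SELECT queries are allowed."}]
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(sql_query)
        columns = [desc[0] for desc in cur.description]
        results = [dict(zip(columns, row)) for row in cur.fetchall()]
        cur.close()
        conn.close()
        return results
    except Exception as e:
        # logger.exception logs at ERROR level and records the traceback,
        # which the original print() call did not.
        logger.exception("Error executing audit query: %s", sql_query)
        return [{"error": str(e)}]
```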
New file (unit tests for services/audit_service.py), @@ -0,0 +1,56 @@:
```python
import unittest
from unittest.mock import patch, MagicMock
from services.audit_service import process_audit_question


class TestAuditService(unittest.TestCase):

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    @patch("services.audit_service.summarize_audit_results")
    def test_process_audit_question_success(
        self, mock_summarize, mock_get_conn, mock_generate_sql
    ):
        # Setup mocks
        mock_generate_sql.return_value = "SELECT * FROM write_audit_log LIMIT 5"

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_get_conn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor

        # Mock DB results
        mock_cursor.description = [("user_email",), ("operation",)]
        mock_cursor.fetchall.return_value = [("[email protected]", "insert")]

        # Mock Summary Response
        mock_response = MagicMock()
        mock_response.summary = "Summary of results."
        mock_response.visualization = None
        mock_summarize.return_value = mock_response

        # Execute
        result = process_audit_question("Show me inserts")

        # Assertions
        self.assertEqual(result["sql_query"], "SELECT * FROM write_audit_log LIMIT 5")
        self.assertEqual(len(result["results"]), 1)
        self.assertEqual(result["results"][0]["user_email"], "[email protected]")
        self.assertEqual(result["summary"], "Summary of results.")

        mock_generate_sql.assert_called_once()
        mock_cursor.execute.assert_called_with("SELECT * FROM write_audit_log LIMIT 5")
        mock_summarize.assert_called_once()

    @patch("services.audit_service.generate_audit_sql")
    def test_process_audit_question_invalid_sql(self, mock_generate_sql):
        mock_generate_sql.return_value = "DELETE FROM write_audit_log"

        result = process_audit_question("Delete everything")

        self.assertIn("Error executing query", result["summary"])
        self.assertIn("Only SELECT", result["summary"])


if __name__ == "__main__":
    unittest.main()
```
Review comment: The token validation only checks if the authorization header starts with "Bearer " but doesn't actually validate the token itself. The commented code suggests token exchange should happen, but it's not implemented. This leaves the endpoint vulnerable to unauthorized access with any malformed bearer token. Either implement proper token validation or add a TODO comment explaining why it's deferred.
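One minimal way to address this, sketched below, is to move the check into a reusable dependency and make the deferred validation an explicit TODO. The exchange_token_obo helper is only referenced, since its module is not shown in this PR.

```python
# Hedged sketch of a reusable auth dependency for the audit route.
from fastapi import Header, HTTPException


def require_bearer_token(authorization: str = Header(...)) -> str:
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid token format")
    user_token = authorization.removeprefix("Bearer ")
    # TODO: actually validate the token, e.g. via exchange_token_obo(user_token),
    # before trusting this request; checking the header format alone is not authentication.
    return user_token
```

The route handler could then declare `user_token: str = Depends(require_bearer_token)` instead of reading the raw Authorization header itself.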