-
Notifications
You must be signed in to change notification settings - Fork 0
feat: add AuditPage for analyzing audit logs with visualization and m… #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from fastapi import APIRouter, Header, Body, HTTPException | ||
| from pydantic import BaseModel | ||
| from typing import List, Dict, Any, Optional | ||
| from services.audit_service import process_audit_question | ||
| from services.gemini_service import VisualizationConfig | ||
|
|
||
| router = APIRouter() | ||
|
|
||
|
|
||
# Request body for the POST /query audit endpoint.
# `question` is the caller's free-text question; the route handler passes it
# unchanged to process_audit_question.
| class AuditQueryRequest(BaseModel): | ||
| question: str | ||
|
|
||
|
|
||
# Response body for the POST /query audit endpoint (used as its response_model).
#   sql_query     - the SQL text produced for the question
#   results       - result rows, one dict per row
#   summary       - textual summary, or an error message when the query failed
#   visualization - optional chart configuration; defaults to None
| class AuditQueryResponse(BaseModel): | ||
| sql_query: str | ||
| results: List[Dict[str, Any]] | ||
| summary: str | ||
| visualization: Optional[VisualizationConfig] = None | ||
|
|
||
|
|
||
@router.post("/query", response_model=AuditQueryResponse)
def query_audit_log(
    body: AuditQueryRequest = Body(...), authorization: str = Header(...)
):
    """
    Answer a natural-language question about the audit log.

    The Authorization header must use the "Bearer " scheme, otherwise the
    request is rejected with HTTP 401. The question is then handed to
    process_audit_question and its result dict is wrapped in an
    AuditQueryResponse.
    """
    uses_bearer_scheme = authorization.startswith("Bearer ")
    if not uses_bearer_scheme:
        raise HTTPException(status_code=401, detail="Invalid token format")

    # NOTE(review): only the header *format* is checked here; the token itself
    # is never validated. TODO: strip the "Bearer " prefix and validate the
    # token (e.g. via exchange_token_obo) even though it is not yet used for
    # the pg connection.

    answer = process_audit_question(body.question)
    return AuditQueryResponse(**answer)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| from typing import Dict, Any | ||
| from services.pg_connection import get_connection | ||
| from services.gemini_service import generate_audit_sql, summarize_audit_results | ||
|
|
||
|
|
||
def execute_audit_query(sql_query: str) -> list:
    """
    Executes a read-only SQL query against the audit database.

    Args:
        sql_query: SQL text to run. It must be a single SELECT statement;
            surrounding whitespace and trailing semicolons are tolerated.

    Returns:
        A list with one dict per result row, keyed by column name. Values
        exposing ``isoformat`` (dates, datetimes, ...) are converted to ISO
        strings; every other value, including dicts, is passed through as-is.
        On rejection or failure, a single-element list ``[{"error": msg}]``
        is returned instead of raising.
    """
    # Ignore trailing semicolons/whitespace so "SELECT 1;" is still accepted.
    statement = sql_query.strip().rstrip("; \t\r\n")
    if not statement.lower().startswith("select"):
        return [{"error": "Only SELECT queries are allowed."}]

    # A semicolon left inside the text means more than one statement could be
    # sent (e.g. "SELECT 1; DELETE ..."), which would defeat the SELECT-only
    # guard above. This is deliberately conservative: a ';' inside a string
    # literal is rejected as well.
    if ";" in statement:
        return [{"error": "Only a single SELECT statement is allowed."}]

    conn = None
    cur = None
    try:
        conn = get_connection()
        cur = conn.cursor()
        cur.execute(sql_query)

        # DB-API cursors report description=None when no result set exists.
        if cur.description is None:
            return []

        # Get column names
        columns = [desc[0] for desc in cur.description]

        # Build row dicts, serializing date/datetime-like values on the way.
        return [
            {
                column: value.isoformat() if hasattr(value, "isoformat") else value
                for column, value in zip(columns, row)
            }
            for row in cur.fetchall()
        ]
    except Exception as e:
        print(f"Error executing audit query: {e}")
        return [{"error": str(e)}]
    finally:
        # Always release the cursor and connection, even when the query
        # failed; previously they leaked on every error path.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_error:
                print(f"Error closing audit query resource: {close_error}")
|
|
||
|
|
||
def process_audit_question(question: str) -> Dict[str, Any]:
    """
    Orchestrates the process of answering a user's audit question:
    1. Generate SQL from NL question (via Gemini)
    2. Execute SQL
    3. Summarize results (via Gemini)

    Args:
        question: The user's natural-language question.

    Returns:
        A dict with the keys "sql_query", "results", "summary" and
        "visualization". Every branch returns the same set of keys; on
        failure "results" is empty, "visualization" is None and "summary"
        carries the error message.
    """
    sql_query = generate_audit_sql(question)

    # If the generator returned an error query or invalid SQL, return it
    if "Error:" in sql_query:
        return {
            "sql_query": sql_query,
            "results": [],
            "summary": "Could not generate a valid query for your request.",
            "visualization": None,
        }

    results = execute_audit_query(sql_query)

    # execute_audit_query signals failure with a single {"error": ...} row.
    if results and "error" in results[0]:
        return {
            "sql_query": sql_query,
            "results": [],
            "summary": f"Error executing query: {results[0]['error']}",
            "visualization": None,
        }

    summary_response = summarize_audit_results(question, sql_query, results)

    return {
        "sql_query": sql_query,
        "results": results,
        "summary": summary_response.summary,
        "visualization": summary_response.visualization,
    }
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,29 @@ | ||||||||||||||
| from google import genai | ||||||||||||||
| from google.genai import types | ||||||||||||||
| from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse | ||||||||||||||
| from pydantic import BaseModel, Field | ||||||||||||||
| from typing import Optional, List | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
# Pydantic model describing an optional chart for a set of query results.
# Each field's meaning is given by its `description` string below.
# NOTE(review): the Optional[...] fields declare no default value. Under
# Pydantic v2 that makes them required-but-nullable rather than omittable;
# confirm that is the intent for the installed Pydantic version.
| class VisualizationConfig(BaseModel): | ||||||||||||||
| available: bool = Field(description="Whether a chart is recommended for this data") | ||||||||||||||
| type: Optional[str] = Field( | ||||||||||||||
| description="Type of chart: 'bar', 'line', 'pie', 'scatter'" | ||||||||||||||
| ) | ||||||||||||||
| x_key: Optional[str] = Field(description="Key for X-axis data") | ||||||||||||||
| y_key: Optional[str] = Field(description="Key for Y-axis data") | ||||||||||||||
| title: Optional[str] = Field(description="Title for the chart") | ||||||||||||||
| data_keys: Optional[List[str]] = Field( | ||||||||||||||
| description="Keys to include in the chart data points (e.g. ['count', 'date'])" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
# Structured summary of audit query results: a markdown summary plus a
# (required) VisualizationConfig.
# NOTE(review): presumably this is the type returned by summarize_audit_results,
# whose `.summary` and `.visualization` attributes are read by the audit
# service; that function is not visible in this excerpt, so confirm.
| class AuditSummaryResponse(BaseModel): | ||||||||||||||
| summary: str = Field(description="Markdown summary of the results") | ||||||||||||||
| visualization: VisualizationConfig = Field( | ||||||||||||||
| description="Configuration for data visualization" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| PROMPT_TEMPLATE_QUERY = """ | ||||||||||||||
| You are an assistant that converts user requests into MongoDB query code. | ||||||||||||||
|
|
@@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str: | |||||||||||||
| response.text.strip() if hasattr(response, "text") else str(response).strip | ||||||||||||||
| ) | ||||||||||||||
| return DebugSuggestionResponse(suggestion=suggestion) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
# Prompt used by generate_audit_sql, which fills {user_input} via str.format.
# The user's question is interpolated verbatim, so the model's output has to be
# treated as untrusted SQL by whatever executes it.
# Any literal brace added to this template must be doubled ({{ }}) because of
# the str.format call.
| PROMPT_TEMPLATE_AUDIT_SQL = """ | ||||||||||||||
| You are a PostgreSQL expert. Convert the user's natural language question into a read-only SQL query for the `write_audit_log` table. | ||||||||||||||
| Table Schema: | ||||||||||||||
| - user_email (text): Email of the user who performed the operation. | ||||||||||||||
| - operation (text): 'insert', 'update', or 'delete'. | ||||||||||||||
| - database_name (text): Name of the database (format: account.database). | ||||||||||||||
| - collection_name (text): Name of the collection. | ||||||||||||||
| - document_id (text): ID of the affected document. | ||||||||||||||
| - diff_data (jsonb): JSON containing the changes (for updates, it has 'before' and 'after' fields). | ||||||||||||||
| - timestamp_utc (timestamptz): When the operation occurred. | ||||||||||||||
|
|
||||||||||||||
| User Question: "{user_input}" | ||||||||||||||
|
|
||||||||||||||
| Rules: | ||||||||||||||
| 1. Return ONLY the SQL query. No markdown, no explanations. | ||||||||||||||
| 2. The query MUST be a SELECT statement. | ||||||||||||||
| 3. Use LIMIT 100 if no limit is specified. | ||||||||||||||
| 4. If the user asks for "recent", order by timestamp_utc DESC. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
# Prompt template asking the model to summarize query results and propose a
# chart. It has three placeholders: {user_input}, {sql_query} and {results}.
# NOTE(review): the code that formats this template is not visible in this
# excerpt (presumably summarize_audit_results); confirm it supplies all three.
| PROMPT_TEMPLATE_AUDIT_SUMMARY = """ | ||||||||||||||
| You are a data analyst. Analyze the following SQL query and its results. | ||||||||||||||
|
|
||||||||||||||
| User Question: "{user_input}" | ||||||||||||||
| SQL Query: "{sql_query}" | ||||||||||||||
| Results: | ||||||||||||||
| {results} | ||||||||||||||
|
|
||||||||||||||
| Tasks: | ||||||||||||||
| 1. Provide a concise markdown summary identifying patterns or answering the specific question. | ||||||||||||||
| 2. Determine if the data is suitable for visualization (e.g., time series, counts, comparisons). | ||||||||||||||
| 3. If suitable, structure a visualization configuration (type, keys, title). | ||||||||||||||
| - For time series, prefer 'line' or 'bar'. | ||||||||||||||
| - For categorical counts, use 'bar' or 'pie'. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def generate_audit_sql(user_input: str) -> str: | ||||||||||||||
| full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input) | ||||||||||||||
| client = genai.Client() | ||||||||||||||
| response = client.models.generate_content( | ||||||||||||||
| model="gemini-2.5-flash", | ||||||||||||||
| contents=full_prompt, | ||||||||||||||
| config=types.GenerateContentConfig( | ||||||||||||||
| thinking_config=types.ThinkingConfig(thinking_budget=0) | ||||||||||||||
| ), | ||||||||||||||
| ) | ||||||||||||||
| sql = extract_python_code(response.text) | ||||||||||||||
| # Basic safety check | ||||||||||||||
| if not sql.lower().startswith("select"): | ||||||||||||||
| return "SELECT 'Error: Generated query was not a SELECT statement' as error;" | ||||||||||||||
|
Comment on lines
+208
to
+210
|
||||||||||||||
| # Basic safety check | |
| if not sql.lower().startswith("select"): | |
| return "SELECT 'Error: Generated query was not a SELECT statement' as error;" | |
| # Basic safety check: ensure we only return a valid SELECT statement | |
| if not isinstance(sql, str) or not sql.strip().lower().startswith("select"): | |
| raise ValueError("Generated query was not a valid SELECT statement.") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| import unittest | ||
| from unittest.mock import patch, MagicMock | ||
| from services.audit_service import process_audit_question | ||
|
|
||
|
|
||
class TestAuditService(unittest.TestCase):
    """Tests for process_audit_question with SQL generation, DB and summary patched."""

    @patch("services.audit_service.generate_audit_sql")
    @patch("services.audit_service.get_connection")
    @patch("services.audit_service.summarize_audit_results")
    def test_process_audit_question_success(
        self, mock_summarize, mock_get_conn, mock_generate_sql
    ):
        """Happy path: generated SELECT is executed and its rows are summarized."""
        expected_sql = "SELECT * FROM write_audit_log LIMIT 5"
        mock_generate_sql.return_value = expected_sql

        # Fake cursor yielding one row with two columns.
        fake_cursor = MagicMock()
        fake_cursor.description = [("user_email",), ("operation",)]
        fake_cursor.fetchall.return_value = [("[email protected]", "insert")]

        # Fake connection handing out that cursor.
        fake_conn = MagicMock()
        fake_conn.cursor.return_value = fake_cursor
        mock_get_conn.return_value = fake_conn

        # Fake summary object exposing the two attributes the service reads.
        mock_summarize.return_value = MagicMock(
            summary="Summary of results.", visualization=None
        )

        outcome = process_audit_question("Show me inserts")

        self.assertEqual(outcome["sql_query"], expected_sql)
        rows = outcome["results"]
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0]["user_email"], "[email protected]")
        self.assertEqual(outcome["summary"], "Summary of results.")

        mock_generate_sql.assert_called_once()
        fake_cursor.execute.assert_called_with(expected_sql)
        mock_summarize.assert_called_once()

    @patch("services.audit_service.generate_audit_sql")
    def test_process_audit_question_invalid_sql(self, mock_generate_sql):
        """A non-SELECT statement is reported as an execution error in the summary."""
        mock_generate_sql.return_value = "DELETE FROM write_audit_log"

        outcome = process_audit_question("Delete everything")

        summary_text = outcome["summary"]
        self.assertIn("Error executing query", summary_text)
        self.assertIn("Only SELECT", summary_text)
|
||
|
|
||
# Run the test suite when this module is executed directly as a script.
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The SQL query validation only checks if the query starts with "select" (case-insensitive). This is insufficient protection as malicious queries could still be crafted, such as "SELECT pg_sleep(100)" for denial of service, or queries using functions that could modify data. Consider implementing more robust SQL validation, such as using a SQL parser to ensure only safe SELECT operations are allowed, or using parameterized queries with a whitelist of allowed operations.