-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: app.py
More file actions
115 lines (93 loc) · 3.54 KB
/
app.py
File metadata and controls
115 lines (93 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import asyncio
import os
import streamlit as st
from dotenv import load_dotenv
from google.oauth2 import service_account
from google.cloud import bigquery
from agent import run_agent
from bq import manifest_to_prompt, dry_run_sql
load_dotenv()
GCP_PROJECT_ID = os.getenv("GCP_PROJECT_ID", "plfpl-production")
BQ_DATASET_PREFIX = os.getenv("BQ_DATASET_PREFIX", "ism_GW")
correction_made = False
# Create API client.
credentials = service_account.Credentials.from_service_account_info(
st.secrets["gcp_service_account"]
)
client = bigquery.Client(credentials=credentials)
@st.cache_data(ttl=600)
def get_schema_manifest(_client: bigquery.Client, gameweek: int) -> dict:
dataset = f"{GCP_PROJECT_ID}.{BQ_DATASET_PREFIX}{gameweek}"
sql = f"""
SELECT table_name, column_name, data_type
FROM `{dataset}.INFORMATION_SCHEMA.COLUMNS`
ORDER BY table_name, ordinal_position
"""
rows = client.query(sql).result()
manifest: dict[str, list[str]] = {}
for r in rows:
manifest.setdefault(r["table_name"], []).append(r["column_name"])
return {
"project": GCP_PROJECT_ID,
"dataset": f"{BQ_DATASET_PREFIX}{gameweek}",
"tables": manifest,
}
# Perform query.
# Uses st.cache_data to only rerun when the query changes or after 10 min.
@st.cache_data(ttl=600)
def run_query(query: str):
query_job = client.query(query)
rows_raw = query_job.result()
# Convert to list of dicts. Required for st.cache_data to hash the return value.
rows = [dict(row) for row in rows_raw]
return rows
st.set_page_config(page_title="BigQuery Assistant", page_icon=" ")
st.title("BigQuery Assistant")
st.write(
"Ask a question about your data. The assistant will generate BigQuery SQL and run it for the selected Gameweek."
)
gw = st.number_input("Gameweek", min_value=1, max_value=38, value=26, step=1)
query = st.text_input(
"What would you like to find?",
"Top 10 most transferred-in players and their counts",
)
sql = None
results_rows = None
if st.button("Run Query"):
with st.spinner("Generating SQL and querying BigQuery. Please wait..."):
try:
schema = get_schema_manifest(client, int(gw))
schema_text = manifest_to_prompt(schema)
sql = asyncio.run(run_agent(query, int(gw), schema_text))
ok, err = dry_run_sql(client, sql)
if not ok:
# Try a single self-correction round by telling the agent the error
correction_prompt = (
query
+ "\n\nNote: Prior SQL failed validation with this BigQuery error. "
"Please correct using the provided schema only. Error: "
+ (err or "")
)
sql = asyncio.run(run_agent(correction_prompt, int(gw), schema_text))
ok2, err2 = dry_run_sql(client, sql)
if not ok2:
raise RuntimeError(err2 or err)
correction_made = True
results_rows = run_query(sql)
except Exception as e:
error_msg = str(e)
st.error(error_msg)
if correction_made:
st.warning(
"⚠️ Initial SQL required correction after validation error. "
"The corrected version was executed successfully."
)
if sql:
with st.expander("Show generated SQL"):
st.code(sql, language="sql")
if results_rows is not None:
if len(results_rows) == 0:
st.info("Query returned no rows.")
else:
st.success(f"Returned {len(results_rows)} rows")
st.dataframe(results_rows)