-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
337 lines (285 loc) · 13.5 KB
/
app.py
File metadata and controls
337 lines (285 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import streamlit as st
import os
from google.cloud.alloydb.connector import Connector, IPTypes
import pg8000
import sqlalchemy
from sqlalchemy import text
import vertexai
from vertexai.language_models import TextEmbeddingModel
from google.cloud import storage
import pandas as pd
from hybrid_search import hybrid_search
from agent_assist import AgentAssist
import chat_utils
# --- Configuration ---
# Google Cloud project and AlloyDB coordinates. get_db_connection() interpolates
# PROJECT_ID / REGION / CLUSTER / INSTANCE into the instance URI it connects to.
PROJECT_ID = "alloydbtest-374215"
REGION = "us-central1"
CLUSTER = "hybridsearch"
INSTANCE = "hybridsearch-primary"
DATABASE = "ecom"
DB_USER = "postgres"
# Password taken from the DB_PASSWORD environment variable (empty string when
# unset); used as the pre-filled value of the sidebar password field.
DEFAULT_DB_PASS = os.environ.get("DB_PASSWORD", "")
# Vertex AI Model
# Name passed to TextEmbeddingModel.from_pretrained() in get_embedding().
EMBEDDING_MODEL_NAME = "text-embedding-005"

# --- Setup ---
# Page configuration is applied before any other Streamlit call in this file.
st.set_page_config(page_title="Virtual Retail Sales Rep", layout="wide")

# Chat history lives in session state; each entry is a dict with at least
# "role" and "content" keys (see handle_query()).
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize Vertex AI
# The flag is only set on success, so a failed init is attempted again on the
# next script run instead of being remembered as done.
if "vertex_initialized" not in st.session_state:
    try:
        vertexai.init(project=PROJECT_ID, location=REGION)
        st.session_state.vertex_initialized = True
    except Exception as e:
        st.error(f"Failed to initialize Vertex AI: {e}")
# --- Helpers ---
@st.cache_resource
def get_db_connection(password):
    """Build a SQLAlchemy engine whose connections come from the AlloyDB connector.

    Args:
        password: Database password for DB_USER.

    Returns:
        A SQLAlchemy engine. Every new connection is opened by the AlloyDB
        ``Connector`` with the pg8000 driver over the instance's public IP.
        The function is wrapped in ``st.cache_resource``.
    """
    instance_uri = (
        f"projects/{PROJECT_ID}/locations/{REGION}"
        f"/clusters/{CLUSTER}/instances/{INSTANCE}"
    )
    alloydb_connector = Connector()

    def _open_connection():
        # Called by SQLAlchemy each time the engine needs a fresh DBAPI connection.
        return alloydb_connector.connect(
            instance_uri,
            "pg8000",
            user=DB_USER,
            password=password,
            db=DATABASE,
            ip_type=IPTypes.PUBLIC,
        )

    # The URL only selects dialect and driver; the creator supplies the connection.
    return sqlalchemy.create_engine("postgresql+pg8000://", creator=_open_connection)
@st.cache_resource
def get_storage_client():
    """Return a Google Cloud Storage client for PROJECT_ID.

    The function is wrapped in ``st.cache_resource``.
    """
    gcs_client = storage.Client(project=PROJECT_ID)
    return gcs_client
import urllib.parse
def _placeholder_url(product_name):
    """Return a placehold.co image URL labelled with the product name."""
    # Missing or empty names fall back to "Product"; str() guards against
    # non-string names, which urllib.parse.quote() would reject with TypeError.
    label = str(product_name) if product_name else "Product"
    return f"https://placehold.co/400x400?text={urllib.parse.quote(label)}"


def get_image_url(gcs_uri, product_name="Product"):
    """Resolve a product image reference into something ``st.image`` can display.

    Args:
        gcs_uri: Image location; a ``gs://`` URI, any other URL/path, or a
            falsy value when the product has no image.
        product_name: Label used on the placeholder image. ``None`` or an
            empty value is treated as ``"Product"``.

    Returns:
        Either a ``str`` URL or the raw image ``bytes``:
        * a placeholder URL when ``gcs_uri`` is falsy, malformed, or the
          download fails;
        * a public ``https://storage.googleapis.com/...`` URL for the two
          known buckets;
        * ``gcs_uri`` unchanged when it is not a ``gs://`` URI;
        * the downloaded blob bytes for any other ``gs://`` URI.
    """
    if not gcs_uri:
        return _placeholder_url(product_name)

    # Rewrite inaccessible bucket to public mirror
    if "genwealth-gen-vid/product-images" in gcs_uri:
        return gcs_uri.replace(
            "gs://genwealth-gen-vid/product-images/",
            "https://storage.googleapis.com/pr-public-demo-data/alloydb-retail-demo/product-images-branded/",
        )

    if "pr-public-demo-data/alloydb-retail-demo" in gcs_uri:
        return gcs_uri.replace("gs://", "https://storage.googleapis.com/")

    if not gcs_uri.startswith("gs://"):
        return gcs_uri

    # Validate "gs://<bucket>/<blob>" before touching the storage client so a
    # malformed URI never costs a client construction or a network call.
    parts = gcs_uri[5:].split("/", 1)
    if len(parts) != 2 or not all(parts):
        return _placeholder_url(product_name)
    bucket_name, blob_name = parts

    try:
        blob = get_storage_client().bucket(bucket_name).blob(blob_name)
        return blob.download_as_bytes()
    except Exception as e:
        # Best effort: any GCS failure degrades to the placeholder image.
        print(f"Failed to load {gcs_uri}: {e}")
        return _placeholder_url(product_name)
def get_embedding(text_input):
    """Embed ``text_input`` with the EMBEDDING_MODEL_NAME Vertex AI model.

    Progress messages are written to the Streamlit sidebar before and after
    the call. Returns the ``values`` of the first (only) embedding.
    """
    sidebar = st.sidebar
    sidebar.write("Generating embedding...")
    embedder = TextEmbeddingModel.from_pretrained(EMBEDDING_MODEL_NAME)
    response = embedder.get_embeddings([text_input])
    sidebar.write("Embedding generated.")
    first_embedding = response[0]
    return first_embedding.values
def search_products(query_text, db_pool):
    """Return the five products whose description embedding is closest to the query.

    Args:
        query_text: Free-text search string; embedded via get_embedding().
        db_pool: Engine-like object exposing ``connect()``.

    Returns:
        The rows fetched from the query (id, name, description, image URI,
        retail price, distance, unsold stock count), or an empty list when
        anything fails; the failure is reported with ``st.error``.
    """
    try:
        query_vector = get_embedding(query_text)
        st.sidebar.write("Executing SQL query...")
        # SQL text is kept verbatim (hence the column-0 layout inside the literal).
        statement = text("""
SELECT
p.id,
p.name,
p.product_description,
p.product_image_uri,
p.retail_price,
(p.product_description_embedding <=> CAST(:embedding AS vector(768))) as distance,
COALESCE(s.qty, 0) as stock_qty
FROM public.products p
LEFT JOIN LATERAL (
SELECT COUNT(*) as qty
FROM public.inventory_items i
WHERE i.product_id = p.id AND i.sold_at IS NULL
) s ON true
ORDER BY distance ASC
LIMIT 5
""")
        # The vector is bound as its string form and cast to vector(768) in SQL.
        bind_values = {"embedding": str(query_vector)}
        with db_pool.connect() as connection:
            st.sidebar.write("Connection acquired.")
            outcome = connection.execute(statement, bind_values)
            st.sidebar.write("Query executed.")
            matches = outcome.fetchall()
            st.sidebar.write(f"Found {len(matches)} rows.")
    except Exception as exc:
        st.error(f"Search failed: {exc}")
        return []
    return matches
# --- UI Layout ---
def _run_explain(db_password, sql_text, params, missing_params_msg, failure_label):
    """Run EXPLAIN ANALYZE for a previously executed statement and render the plan.

    Shared by both "Run EXPLAIN" buttons so the two code paths cannot drift.

    Args:
        db_password: Password handed to get_db_connection().
        sql_text: The SQL statement to explain (EXPLAIN options are prepended).
        params: Bind parameters for the statement; when empty nothing is run
            and ``missing_params_msg`` is shown instead.
        missing_params_msg: Error text displayed when ``params`` is empty.
        failure_label: Prefix of the error shown when execution raises.
    """
    try:
        if not params:
            st.error(missing_params_msg)
            return
        db_pool = get_db_connection(db_password)
        explain_sql = text(
            "EXPLAIN (ANALYZE, VERBOSE, COSTS, SETTINGS, BUFFERS, WAL, TIMING, SUMMARY, FORMAT JSON) "
            + sql_text
        )
        with db_pool.connect() as conn:
            rows = conn.execute(explain_sql, params).fetchall()
        if rows:
            # The plan is read from the first column of the first row.
            st.json(rows[0][0])
        else:
            st.warning("EXPLAIN returned no rows.")
    except Exception as e:
        st.error(f"{failure_label}: {e}")


st.title("🛍️ Virtual Retail Sales Rep")

# Sidebar for Credentials & Debug
# NOTE(review): the flattened source did not preserve nesting. Everything below
# is placed inside the sidebar, and the rerank expander is a sibling of (not
# nested in) the SQL expander because Streamlit versions that reject nested
# expanders would raise otherwise. Confirm against the intended layout.
with st.sidebar:
    st.title("Settings")

    # Database Password Logic: a non-empty DB_PASSWORD env var wins; otherwise
    # the user types the password into the sidebar.
    env_password = os.environ.get("DB_PASSWORD")
    if env_password:
        db_pass_input = env_password
        st.success("Database Configured (Env)")
    else:
        db_pass_input = st.text_input("Database Password", value=DEFAULT_DB_PASS, type="password")

    if st.button("Clear Chat"):
        st.session_state.messages = []

    show_sql = st.checkbox("Show Debug Info (SQL/Explain)", value=False)

    # Initialize AgentAssist (Fresh every time to pick up code changes), i.e.
    # deliberately NOT guarded by `if "agent_assist" not in st.session_state`.
    st.session_state.agent_assist = AgentAssist(PROJECT_ID, REGION)

    if show_sql and "last_sql" in st.session_state:
        with st.expander("Debug: Last Executed SQL"):
            st.code(st.session_state.last_sql, language="sql")
            # Default to display params if available, else empty
            st.write("Params:", st.session_state.get("last_params_display", {}))
            if "last_rerank_debug" in st.session_state:
                st.write("Rerank Debug:", st.session_state.last_rerank_debug)
            if "last_rerank_error" in st.session_state:
                st.error(f"Rerank Error: {st.session_state.last_rerank_error}")

        if "last_rerank_sql" in st.session_state:
            with st.expander("Show Rerank SQL"):
                st.code(st.session_state.last_rerank_sql, language="sql")
                if st.button("Run EXPLAIN (ANALYZE, JSON) for Rerank"):
                    _run_explain(
                        db_pass_input,
                        st.session_state.last_rerank_sql,
                        st.session_state.get("last_rerank_params", {}),
                        "No parameters found for Rerank EXPLAIN.",
                        "Rerank Explain failed",
                    )

        if st.button("Run EXPLAIN (ANALYZE, JSON)"):
            # Use the real params for execution (not the display copy).
            _run_explain(
                db_pass_input,
                st.session_state.last_sql,
                st.session_state.get("last_params_real", {}),
                "No parameters found for EXPLAIN. Please run a search first.",
                "Explain failed",
            )
# Chat Interface
# --- Logic ---
def handle_query(prompt):
    """Process one user query: record it, search products, store the reply.

    The user message is appended to ``st.session_state.messages`` first. On
    success an assistant message with keys "role", "content", "results" and
    "suggestions" is appended as well. On failure only the user message is
    added and the problem is reported with ``st.error``.

    Args:
        prompt: The raw text the user typed.
    """
    # Must be evaluated BEFORE the current message is appended; afterwards the
    # history is never empty and the "only rewrite with history" check below
    # would always pass.
    has_prior_turns = bool(st.session_state.messages)
    st.session_state.messages.append({"role": "user", "content": prompt})

    if not db_pass_input:
        st.error("Please enter the Database Password in the sidebar.")
        return

    try:
        db_pool = get_db_connection(db_pass_input)

        # 0. Contextual Query Rewriting
        rewritten_prompt = prompt
        # Only rewrite if we have history (earlier turns to take context from).
        if has_prior_turns:
            with st.spinner("Refining search..."):
                # Fall back to the original prompt if the rewriter returns
                # nothing, instead of failing on .lower() below.
                rewritten_prompt = (
                    st.session_state.agent_assist.rewrite_query(prompt, st.session_state.messages)
                    or prompt
                )
            # Display debug info if changed
            if rewritten_prompt.lower() != prompt.lower():
                st.toast(f"Rewrote query: '{prompt}' -> '{rewritten_prompt}'")
                print(f"Rewrote query: '{prompt}' -> '{rewritten_prompt}'")

        # 1. Search (uses the rewritten prompt)
        with st.spinner("Analyzing your request..."):
            results = hybrid_search(rewritten_prompt, db_pool, top_k=6)

        # 2. Chat Response (uses the user's original wording)
        with st.spinner("Thinking..."):
            bot_response = chat_utils.generate_response(prompt, results)

        # 3. Suggestions: generation via agent_assist.analyze_intent(prompt,
        # results, messages) is currently disabled; an empty list is stored so
        # the "suggestions" key is still present on every assistant message.
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": bot_response,
                "results": results,
                "suggestions": [],
            }
        )
    except Exception as e:
        # UI boundary: anything raised while connecting, searching or
        # generating the reply is surfaced to the user rather than crashing.
        st.error(f"Connection error: {e}")
# Chat Interface
# Display history: every stored message is re-rendered on each script run;
# assistant messages additionally show their product results as cards.
for msg_idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if message["role"] != "assistant":
            continue

        # Results
        results = message.get("results")
        if results:
            with st.container():
                st.markdown("### Finding the best matches for you:")
                # Use 2 columns instead of 3 for larger images
                cols = st.columns(2)
                for i, row in enumerate(results):
                    with cols[i % 2]:
                        p_name = row.name
                        p_img = row.product_image_uri
                        p_price = row.retail_price
                        with st.container(border=True):
                            if p_img:
                                # get_image_url() returns either a URL string or
                                # raw image bytes. st.image takes both, so no
                                # str-only check such as .startswith("http") is
                                # made here (it raises TypeError on bytes).
                                st.image(get_image_url(p_img, p_name), width=400)
                            else:
                                st.image("https://placehold.co/400x400?text=No+Image", width=400)
                            st.markdown(f"**{p_name}**")
                            st.markdown(f"${p_price}")
                            # Key is unique per message and per result so
                            # repeated products do not collide.
                            st.button("Add to cart", key=f"add_{msg_idx}_{i}")

        # Suggestions: the "Suggested searches" button row (one button per
        # entry in message["suggestions"], each calling handle_query() and then
        # st.rerun()) is currently disabled.

# Input
if prompt := st.chat_input("What are you looking for today?"):
    handle_query(prompt)
    # Rerun only when a reply was stored (handle_query appends the assistant
    # message last on success). Rerunning after a failure would immediately
    # discard the st.error() that handle_query just displayed.
    history = st.session_state.messages
    if history and history[-1]["role"] == "assistant":
        st.rerun()