# app.py: Bird Search Streamlit app
import streamlit as st
import json
import os
from PIL import Image
import pandas as pd
import altair as chart
from datetime import datetime
import nltk
from query_db import *
from search_metrics import *
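# The wildcard imports above are assumed (inferred from usage below) to provide:
#   query_db:       query_integrated_inference, query_bm25, conduct_cascading_retrieval
#   search_metrics: get_unique_relevant_birds, calculate_mean_average_precision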
# Load metadata for images
with open("parsed_birds/parsing_metadata.json", 'r') as f:
parsing_metadata = json.load(f)
# App intro: what the app does and how to use it
st.set_page_config(layout="wide")  # set the app to wide
st.title("Bird Search")
st.logo("Pinecone-Primary-Logo-Black.png")
st.markdown("""
## Welcome to Bird Search!
This app allows you to search for birds using natural language queries across different search methodologies.
""")
with st.expander("Why Different Search Methods?"):
st.markdown("""
Each search method works in a unique way:
- **Keyword Search**: Finds birds based on exact word matches (like Google in the early days)
- **Dense Search**: Understands meaning even with different words (like asking "large birds" and finding "massive avians")
- **Sparse Search**: Balances exact keywords and some understanding of meaning
- **Cascading Retrieval**: Our advanced approach that combines methods for optimal results
""")
with st.expander("How to Use This App:"):
st.markdown("""
1. Enter your query about birds in the search box above
2. Browse results across the different tabs
    3. Mark results as relevant by checking the boxes; unmarked results are treated as irrelevant by default
    4. Click "Log Annotations" on each tab to record your evaluations
5. View the Metrics tab to see which search method performed best
""")
with st.expander("Why Annotate Results?"):
st.markdown("""
Your annotations help us understand which search methods work best for different types of questions! The app calculates:
- **Mean Average Precision**: How well each method ranks relevant results higher than irrelevant results, across queries.
    - **Relevant Birds**: How many unique relevant bird species each method finds. This is an example of a "business" metric: while not objective, it is a good proxy for how novel the search results are per query
""")
with st.expander("Example Queries to Try:"):
st.markdown("""
- **"Birds that live in Illinois"**, a query that works great for keyword, sparse searches.
- **"Birds that are bad at flying"**, a query that works great for dense searches, but not so much for keyword.
- **"descriptions of really large birds"**, a query that works great for cascading, exposing the strengths of the reranker in re-prioritizing results.
- **"Big bird red head black wings that pecks wood"**, a good example for finding "woodpeckers" without using that word!
- **"Colorful birds that live in the Midwestern United States"**, a fun query that works alright across all methods, and exposes the weaknesses of a simple chunking method.
""")
st.markdown("Have fun exploring!")
# Search input
query = st.text_input("Enter your search query")
# Initialize session state variables for annotations
if 'annotations_df' not in st.session_state:
st.session_state.annotations_df = pd.DataFrame(columns=[
'timestamp', 'query', 'method', 'bird', 'rank', 'is_relevant', 'score',
'chunk_id', 'chunk_text'
])
# Coerce column dtypes and persist the edited annotations dataframe
def update_annotations_df(edited_df):
# Ensure all columns have consistent types
if 'is_relevant' in edited_df.columns:
edited_df['is_relevant'] = edited_df['is_relevant'].astype(bool)
if 'rank' in edited_df.columns:
edited_df['rank'] = edited_df['rank'].astype(int)
if 'score' in edited_df.columns:
edited_df['score'] = edited_df['score'].astype(float)
st.session_state.annotations_df = edited_df
st.success("Annotations updated!")
def calculate_metrics(annotations_df, query=None, method=None):
"""
Calculate search metrics from annotations
Args:
annotations_df: DataFrame with annotations
query: If provided, filter by this query, otherwise use all queries
        method: If provided, filter by this search method ("Keyword", "Dense", "Sparse", or "Cascading")
"""
# Filter by query and method if specified
filtered_df = annotations_df.copy()
if query:
filtered_df = filtered_df[filtered_df['query'] == query]
if method:
filtered_df = filtered_df[filtered_df['method'] == method]
# Return zeros if no data after filtering
if filtered_df.empty:
return {
"unique_relevant_birds": 0,
"mean_average_precision": 0.0
}
return {
"unique_relevant_birds": get_unique_relevant_birds(filtered_df),
"mean_average_precision": calculate_mean_average_precision(filtered_df)
}
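# For reference, a minimal sketch of how calculate_mean_average_precision
# (imported from search_metrics, whose source is not shown here) might work.
# This is a hypothetical illustration under assumed column names, not the
# actual implementation:
def _mean_average_precision_sketch(df: pd.DataFrame) -> float:
    """Mean over queries of average precision (precision@k at each relevant rank)."""
    ap_values = []
    for _, group in df.groupby('query'):
        ranked = group.sort_values('rank')
        hits, precisions = 0, []
        for k, is_rel in enumerate(ranked['is_relevant'], start=1):
            if is_rel:
                hits += 1
                precisions.append(hits / k)
        ap_values.append(sum(precisions) / len(precisions) if precisions else 0.0)
    return sum(ap_values) / len(ap_values) if ap_values else 0.0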
def visualize_metrics(query=None):
if st.session_state.annotations_df.empty:
st.info("No annotations available yet. Mark search results as relevant and log scores.")
return
# Get unique queries and methods for the selector
all_queries = st.session_state.annotations_df['query'].unique()
all_methods = st.session_state.annotations_df['method'].unique()
# Add selector for which query to visualize metrics for
col1, col2 = st.columns([3, 1])
with col1:
metric_query = st.selectbox(
"Select query to view metrics for:",
["All Queries"] + list(all_queries),
key="metric_query_selector"
)
with col2:
st.write("") # Spacer
st.write("") # Spacer
        show_details = st.checkbox("Show details", value=True, key="show_metrics_details")  # NOTE: currently unused downstream
# Determine which query to use
selected_query = None if metric_query == "All Queries" else metric_query
# Get metrics for all available methods
method_metrics = {}
for method in all_methods:
method_metrics[method] = calculate_metrics(st.session_state.annotations_df, selected_query, method)
# Create metrics DataFrame dynamically based on available methods
methods_list = []
metrics_list = []
values_list = []
for method in all_methods:
# Add relevant birds
methods_list.append(method)
metrics_list.append('Relevant Birds')
values_list.append(int(method_metrics[method]['unique_relevant_birds']))
# Add avg precision
methods_list.append(method)
metrics_list.append('Mean Average Precision')
values_list.append(float(method_metrics[method]['mean_average_precision']))
metrics_data = pd.DataFrame({
'Method': methods_list,
'Metric': metrics_list,
'Value': values_list
})
# Create bar chart
title = f"Search Performance Metrics - {metric_query}"
st.subheader(title)
c = chart.Chart(metrics_data).mark_bar().encode(
x=chart.X('Method:N'),
y=chart.Y('Value:Q'),
color='Method:N',
column='Metric:N'
).properties(width=100)
st.altair_chart(c)
# show results in a table
st.subheader("Metrics Comparison")
# Create metrics comparison table dynamically
metrics_comparison_data = {
'Metric': ['Relevant Birds', 'Mean Average Precision']
}
for method in all_methods:
metrics_comparison_data[method] = [
int(method_metrics[method]['unique_relevant_birds']),
f"{float(method_metrics[method]['mean_average_precision']):.3f}"
]
metrics_comparison = pd.DataFrame(metrics_comparison_data)
st.table(metrics_comparison)
# Show query distribution
if metric_query == "All Queries":
st.subheader("Annotations per Query")
query_counts = st.session_state.annotations_df.groupby(['query', 'method']).size().reset_index(name='count')
# Ensure count column is integer
query_counts['count'] = query_counts['count'].astype(int)
# Create a grouped bar chart of queries
query_chart = chart.Chart(query_counts).mark_bar().encode(
x=chart.X('query:N', title='Query'),
y=chart.Y('count:Q', title='Number of Annotations'),
color='method:N',
column='method:N'
).properties(width=min(80 * len(all_queries), 400))
st.altair_chart(query_chart)
def highlight_matching_words(text: str, query: str) -> str:
"""
Returns text with words that match the query highlighted in markdown bold.
Args:
text (str): The original text to process
query (str): The search query
Returns:
str: Markdown formatted text with matching words in bold
"""
# Split query into words and create a set of lowercase words
query_words = set(query.lower().split())
# Split the text into words, keeping track of original words
words = text.split()
# Create markdown by bolding matching words
markdown_words = []
for word in words:
# Compare lowercase versions for matching, stripping punctuation
if word.lower().strip('.,!?()[]{};"\'') in query_words:
markdown_words.append(f"**{word}**")
else:
markdown_words.append(word)
# Join back into text
return ' '.join(markdown_words)
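# Example (illustrative): highlight_matching_words("Red head, black wings", "red wings")
# returns "**Red** head, black **wings**"; matching is case-insensitive and ignores
# surrounding punctuation, while the original word (punctuation included) is kept.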
def display_search_results(results, query, title, container, method):
with container:
st.header(title)
unique_birds = set()
# Add a "Log Annotations" button at the top
log_key = f"log_{method.lower()}"
# Dictionary to store checkbox states for this method
if f"{method}_relevance" not in st.session_state:
st.session_state[f"{method}_relevance"] = {}
for i, hit in enumerate(results):
bird = hit['fields']['bird']
text = hit['fields']['chunk_text']
score = hit['_score']
chunk_id = hit.get('id', f"{bird}_chunk_{i}")
# Use index and bird name for the checkbox key instead of hit_id
checkbox_key = f"{method}_{bird}_{i}"
unique_birds.add(bird)
with st.expander(f"{bird} (Score: {score:.2f})"):
st.write(highlight_matching_words(text, query))
if bird in parsing_metadata:
bird_data = parsing_metadata[bird]
if bird_data['images']:
image_path = os.path.join("parsed_birds/images", bird_data['images'][0]['local_path'])
if os.path.exists(image_path):
image = Image.open(image_path)
st.image(image, caption=bird)
# Add relevance checkbox for this result
st.session_state[f"{method}_relevance"][checkbox_key] = st.checkbox(
"Mark as relevant",
key=checkbox_key,
value=st.session_state.get(checkbox_key, False)
)
# Log Annotations button at the bottom
if st.button(f"Log {method} Annotations", key=log_key):
# Collect all annotations for this query and method
new_annotations = []
for i, hit in enumerate(results):
bird = hit['fields']['bird']
text = hit['fields']['chunk_text']
score = hit['_score']
chunk_id = hit.get('id', f"{bird}_chunk_{i}")
# Use the same checkbox key format
checkbox_key = f"{method}_{bird}_{i}"
is_relevant = st.session_state[f"{method}_relevance"].get(checkbox_key, False)
new_annotations.append({
'timestamp': datetime.now(),
'query': query,
'method': method,
'bird': bird,
'rank': i+1, # 1-based rank
'is_relevant': is_relevant,
'score': score,
'chunk_id': chunk_id,
'chunk_text': text
})
# Add to annotations dataframe
if new_annotations:
new_df = pd.DataFrame(new_annotations)
# Remove any previous annotations for this query and method
mask = ~((st.session_state.annotations_df['query'] == query) &
(st.session_state.annotations_df['method'] == method))
st.session_state.annotations_df = pd.concat([
st.session_state.annotations_df[mask],
new_df
]).reset_index(drop=True)
st.success(f"Logged {len(new_annotations)} annotations for {method} search!")
st.subheader(f"Unique Birds Found in {title}")
st.write(", ".join(unique_birds))
return unique_birds
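# The query_db helpers are assumed (inferred from the field accesses above) to
# return hits shaped like this hypothetical example; the real schema lives in query_db:
#
#     {
#         "id": "American Robin_chunk_0",
#         "_score": 0.87,
#         "fields": {"bird": "American Robin", "chunk_text": "A migratory songbird..."}
#     }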
## Core Mechanisms for App
if query:
dense_results = query_integrated_inference(query, "dense-bird-search")
sparse_results = query_integrated_inference(query, "sparse-bird-search")
    try:
        # nltk tokenizer data may be missing in cloud deployments; download it and retry
        bm25_results = query_bm25(query, "bm25-bird-search")
    except LookupError:
        nltk.download('punkt_tab')
        st.rerun()
cascading_results = conduct_cascading_retrieval(query)
# Tabs for dense vs sparse results
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Keyword Search Results", "Dense Search Results", "Sparse Search Results", "Cascading Retrieval Results", "Metrics & Annotations"])
with tab1:
st.write("This method uses BM25 over the entire corpus to retrieve results.")
with tab2:
st.write("This method uses a dense embedding model called multilingual-e5-large to retrieve results.")
with tab3:
st.write("This method uses sparse retrieval using a proprietary model called pinecone-sparse-english-v0 to retrieve results.")
with tab4:
st.write("This method uses cascading retrieval, reranking over the dense and sparse results to retrieve results.")
st.write("The reranker used is Cohere's Rerank 3.5 Model, which supports reasoning over reranked results.")
with tab5:
st.write("This tab shows the metrics and annotations for the search results.")
st.write("This will calculate two metrics:")
st.write("1. Mean Average Precision (MAP) - This is a measure of how many relevant results are returned, as a function of the rank of the result.")
st.write("2. Relevant Birds - This is the number of unique relevant birds in the search results. This is cool, as it tells us how many new birds we learn about!")
st.write("You can also look at the annotations for each method, and modify, download them as needed.")
st.write("Try annotating a set of results to see which methods are most effective!")
unique_bm25_birds = display_search_results(bm25_results, query, "Keyword Search Results", tab1, "Keyword")
unique_dense_birds = display_search_results(dense_results, query, "Dense Search Results", tab2, "Dense")
unique_sparse_birds = display_search_results(sparse_results, query, "Sparse Search Results", tab3, "Sparse")
unique_cascading_birds = display_search_results(cascading_results, query, "Cascading Retrieval Results", tab4, "Cascading")
# Metrics and annotations tab
with tab5:
visualize_metrics()
st.subheader("Annotation History")
if not st.session_state.annotations_df.empty:
# Use data_editor instead of dataframe for interactive editing
st.subheader("Edit Annotations")
# Make a copy to avoid direct reference issues
edited_df = st.data_editor(
st.session_state.annotations_df,
num_rows="dynamic",
column_config={
"timestamp": st.column_config.DatetimeColumn(
"Timestamp",
help="When the annotation was created",
format="D MMM YYYY, h:mm a",
),
"is_relevant": st.column_config.CheckboxColumn(
"Relevant?",
help="Check if this result is relevant to the query",
),
"score": st.column_config.NumberColumn(
"Score",
help="Search score",
format="%.3f",
),
"chunk_text": st.column_config.TextColumn(
"Chunk Text",
help="The text content of this chunk",
width="large",
),
},
disabled=["timestamp", "chunk_text"],
key="annotation_editor",
            )
            # Sync the edited dataframe back into session state (with type coercion);
            # the previous on_change callback only ever received the pre-edit dataframe
            if not edited_df.equals(st.session_state.annotations_df):
                update_annotations_df(edited_df)
                st.rerun()  # Rerun to update metrics
# Option to download annotations
csv = st.session_state.annotations_df.to_csv(index=False)
st.download_button(
"Download Annotations CSV",
csv,
"bird_search_annotations.csv",
"text/csv",
key="download-csv"
)
else:
st.info("No annotations logged yet. Use the checkboxes and 'Log Annotations' buttons to evaluate search results.")