Skip to content

Commit 0483f3a

Browse files
authored
Merge pull request #9 from KillrVideo/patrick/merge-local-changes
Add data generator, loaders, vector compatibility tests, and documentation
2 parents e1a8685 + bf4b848 commit 0483f3a

5 files changed

Lines changed: 3430 additions & 22 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ __pycache__/
5656
*.pyc
5757

5858
# Editor files
59-
.vscode
59+
.vscode/
6060
.idea/
6161
*.swp
6262
*.swo

loaders/astra-tables/load_data_cql.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ def load_user_credentials(session):
212212

213213
def load_videos(session):
214214
"""Load videos_cleaned.csv or videos.csv."""
215-
#csv_file = DATA_DIR / "videos_cleaned.csv"
216-
#if not csv_file.exists():
217215
csv_file = DATA_DIR / "videos.csv"
218216

219217
logger.info(f"csv_file: {csv_file}")
@@ -365,7 +363,7 @@ def main():
365363
""")
366364
return
367365

368-
table = sys.argv[1]
366+
table = sys.argv[1] if len(sys.argv) > 1 else None
369367

370368
try:
371369
session, cluster = get_session()

loaders/astra-tables/load_with_embeddings.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,16 @@
2020

2121
import argparse
2222
import csv
23-
from curses import raw
2423
import json
2524
import sys
2625
import time
2726
import yaml
2827
import pickle
28+
from decimal import Decimal
2929
from pathlib import Path
3030
from typing import List, Dict, Any, Optional
3131
from datetime import datetime
3232
from uuid import UUID
33-
from decimal import Decimal
3433

3534
try:
3635
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
@@ -271,10 +270,9 @@ def generate_embedding_for_text(self, text: str, target_dimensions: int) -> Opti
271270
embedding = self.embedder.generate_embedding(text)
272271

273272
# Reduce to target dimensions
274-
#reduced = self.embedder.reduce_dimensions(embedding, target_dimensions)
273+
reduced = self.embedder.reduce_dimensions(embedding, target_dimensions)
275274

276-
#return reduced.flatten().tolist()
277-
return embedding
275+
return reduced.flatten().tolist()
278276

279277
except Exception as e:
280278
log_warning(f"Failed to generate embedding: {e}")
@@ -389,8 +387,8 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
389387
rows_failed = 0
390388
embeddings_generated = 0
391389

392-
#try:
393-
with open(csv_path, 'r', encoding='utf-8') as f:
390+
try:
391+
with open(csv_path, 'r', encoding='utf-8') as f:
394392
reader = csv.DictReader(f)
395393
columns = list(reader.fieldnames)
396394

@@ -414,7 +412,7 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
414412
batch = []
415413

416414
for row in reader:
417-
#try:
415+
try:
418416
parsed_values = []
419417
text_for_embedding = None
420418

@@ -441,12 +439,10 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
441439

442440
# Generate embedding
443441
if text_for_embedding:
444-
#print("Got here")
445442
embedding = self.generate_embedding_for_text(
446443
text_for_embedding,
447444
vector_mappings[table_name]['dimensions']
448445
)
449-
#print("Got here2")
450446
parsed_values.append(embedding)
451447
if embedding is not None:
452448
embeddings_generated += 1
@@ -463,11 +459,9 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
463459
# numeric
464460
value = row.get(col, '')
465461
if '.' in value:
466-
#float
467462
raw_value = float(Decimal(value))
468463
parsed_values.append(raw_value)
469464
else:
470-
#integer
471465
raw_value = int(value)
472466
parsed_values.append(raw_value)
473467
elif row.get(col, '').startswith("[") and row.get(col, '').endswith("]"):
@@ -502,9 +496,9 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
502496
status += f" ({embeddings_generated} embeddings)"
503497
print(status + "...", end='\r')
504498

505-
#except Exception as e:
506-
# log_warning(f"Failed to process row: {e}")
507-
# rows_failed += 1
499+
except Exception as e:
500+
log_warning(f"Failed to process row: {e}")
501+
rows_failed += 1
508502

509503
# Execute remaining batch
510504
if batch:
@@ -522,9 +516,9 @@ def load_table_with_embeddings(self, table_name: str, csv_path: Path) -> tuple[i
522516

523517
return rows_loaded, rows_failed
524518

525-
#except Exception as e:
526-
# log_error(f"Failed to load {table_name}: {e}")
527-
# return rows_loaded, rows_failed
519+
except Exception as e:
520+
log_error(f"Failed to load {table_name}: {e}")
521+
return rows_loaded, rows_failed
528522

529523
def load_counter_table(self, table_name: str, csv_path: Path, update_query: str) -> tuple[int, int]:
530524
"""Load counter table (same as before)"""

loaders/astra-tables/setup_embeddings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def main():
198198
return 1
199199

200200

201+
201202
print()
202203

203204
# Verify schema compatibility

0 commit comments

Comments
 (0)