
Commit a85d0ad

ww2283 and yichuan-w authored
Feature/optimize ollama batching (#152)
* feat: add metadata output to search results

  - Add --show-metadata flag to display file paths in search results
  - Preserve document metadata (file_path, file_name, timestamps) during chunking
  - Update MCP tool schema to support show_metadata parameter
  - Enhance CLI search output to display metadata when requested
  - Fix pre-existing bug: args.backend -> args.backend_name

  Resolves #144

* fix: resolve ZMQ linking issues in Python extension

  - Use pkg_check_modules IMPORTED_TARGET to create PkgConfig::ZMQ
  - Set PKG_CONFIG_PATH to prioritize ARM64 Homebrew on Apple Silicon
  - Override macOS -undefined dynamic_lookup to force proper symbol resolution
  - Use PUBLIC linkage for ZMQ in the faiss library for transitive linking
  - Mark cppzmq includes as SYSTEM to suppress warnings

  Fixes editable-install ZMQ symbol errors while maintaining compatibility across Linux, macOS Intel, and macOS ARM64 platforms.

* style: apply ruff formatting

* chore: update faiss submodule to use ww2283 fork

  Use the ww2283/faiss fork with the fix/zmq-linking branch to resolve CI checkout failures. The ZMQ linking fixes are not yet merged upstream.

* feat: implement true batch processing for Ollama embeddings

  Migrate from the deprecated /api/embeddings endpoint to the modern /api/embed endpoint, which supports batch inputs. This reduces HTTP overhead by sending 32 texts per request instead of making individual API calls.

  Changes:
  - Update endpoint from /api/embeddings to /api/embed
  - Change parameter from 'prompt' (single) to 'input' (array)
  - Update response parsing for the batch embeddings array
  - Increase timeout to 60s for batch processing
  - Improve error handling for batch requests

  Performance:
  - Reduces API calls by 32x (batch size)
  - Eliminates per-text HTTP connection overhead
  - Note: Ollama still processes batch items sequentially internally

  Related: #151

* fall back to original faiss as I merge the PR

---------

Co-authored-by: yichuan520030910320 <[email protected]>
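For illustration, a minimal sketch of the endpoint change described above, assuming a local Ollama server at the default http://localhost:11434 and an illustrative embedding model name:

import requests

OLLAMA_HOST = "http://localhost:11434"  # assumed default; adjust for your setup
MODEL = "nomic-embed-text"              # illustrative model name

# Before: one request per text via the deprecated /api/embeddings endpoint
old = requests.post(
    f"{OLLAMA_HOST}/api/embeddings",
    json={"model": MODEL, "prompt": "hello world"},
    timeout=30,
).json()["embedding"]  # a single vector

# After: one request per batch via /api/embed, which accepts a list of inputs
new = requests.post(
    f"{OLLAMA_HOST}/api/embed",
    json={"model": MODEL, "input": ["hello world", "second text", "third text"]},
    timeout=60,
).json()["embeddings"]  # a list of vectors, one per input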
1 parent dbb5f4d commit a85d0ad

File tree

4 files changed: +126 −64 lines changed


packages/leann-backend-hnsw/CMakeLists.txt

Lines changed: 16 additions & 3 deletions
@@ -29,12 +29,25 @@ if(APPLE)
   set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
 endif()
 
-# Use system ZeroMQ instead of building from source
+# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
 find_package(PkgConfig REQUIRED)
-pkg_check_modules(ZMQ REQUIRED libzmq)
+
+# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
+if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
+    set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
+endif()
+
+pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
+
+# This creates PkgConfig::ZMQ target automatically with correct properties
+if(TARGET PkgConfig::ZMQ)
+    message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
+else()
+    message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
+endif()
 
 # Add cppzmq headers
-include_directories(third_party/cppzmq)
+include_directories(SYSTEM third_party/cppzmq)
 
 # Configure msgpack-c - disable boost dependency
 set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)

packages/leann-core/src/leann/cli.py

Lines changed: 49 additions & 11 deletions
@@ -255,6 +255,11 @@ def create_parser(self) -> argparse.ArgumentParser:
             action="store_true",
             help="Non-interactive mode: automatically select index without prompting",
         )
+        search_parser.add_argument(
+            "--show-metadata",
+            action="store_true",
+            help="Display file paths and metadata in search results",
+        )
 
         # Ask command
         ask_parser = subparsers.add_parser("ask", help="Ask questions")
@@ -1263,7 +1268,7 @@ def file_filter(
             from .chunking_utils import create_text_chunks
 
             # Use enhanced chunking with AST support
-            all_texts = create_text_chunks(
+            chunk_texts = create_text_chunks(
                 documents,
                 chunk_size=self.node_parser.chunk_size,
                 chunk_overlap=self.node_parser.chunk_overlap,
@@ -1274,6 +1279,14 @@ def file_filter(
                 ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
             )
 
+            # Note: AST chunking currently returns plain text chunks without metadata
+            # We preserve basic file info by associating chunks with their source documents
+            # For better metadata preservation, documents list order should be maintained
+            for chunk_text in chunk_texts:
+                # TODO: Enhance create_text_chunks to return metadata alongside text
+                # For now, we store chunks with empty metadata
+                all_texts.append({"text": chunk_text, "metadata": {}})
+
         except ImportError as e:
             print(
                 f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
@@ -1285,17 +1298,27 @@ def file_filter(
             for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
                 # Check if this is a code file based on source path
                 source_path = doc.metadata.get("source", "")
+                file_path = doc.metadata.get("file_path", "")
                 is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
 
+                # Extract metadata to preserve with chunks
+                chunk_metadata = {
+                    "file_path": file_path or source_path,
+                    "file_name": doc.metadata.get("file_name", ""),
+                }
+
+                # Add optional metadata if available
+                if "creation_date" in doc.metadata:
+                    chunk_metadata["creation_date"] = doc.metadata["creation_date"]
+                if "last_modified_date" in doc.metadata:
+                    chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
+
                 # Use appropriate parser based on file type
                 parser = self.code_parser if is_code_file else self.node_parser
                 nodes = parser.get_nodes_from_documents([doc])
 
                 for node in nodes:
-                    text_with_source = (
-                        "Chunk source:" + source_path + "\n" + node.get_content().replace("\n", " ")
-                    )
-                    all_texts.append(text_with_source)
+                    all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
 
         print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
         return all_texts
@@ -1370,7 +1393,7 @@ async def build_index(self, args):
 
         index_dir.mkdir(parents=True, exist_ok=True)
 
-        print(f"Building index '{index_name}' with {args.backend} backend...")
+        print(f"Building index '{index_name}' with {args.backend_name} backend...")
 
         embedding_options: dict[str, Any] = {}
         if args.embedding_mode == "ollama":
@@ -1382,7 +1405,7 @@ async def build_index(self, args):
            embedding_options["api_key"] = resolved_embedding_key
 
         builder = LeannBuilder(
-            backend_name=args.backend,
+            backend_name=args.backend_name,
             embedding_model=args.embedding_model,
             embedding_mode=args.embedding_mode,
             embedding_options=embedding_options or None,
@@ -1393,10 +1416,8 @@ async def build_index(self, args):
             num_threads=args.num_threads,
         )
 
-        for chunk_text_with_source in all_texts:
-            chunk_source = chunk_text_with_source.split("\n")[0].split(":")[1]
-            chunk_text = chunk_text_with_source.split("\n")[1]
-            builder.add_text(chunk_text, {"source": chunk_source})
+        for chunk in all_texts:
+            builder.add_text(chunk["text"], metadata=chunk["metadata"])
 
         builder.build_index(index_path)
         print(f"Index built at {index_path}")
@@ -1517,6 +1538,23 @@ async def search_documents(self, args):
         print(f"Search results for '{query}' (top {len(results)}):")
         for i, result in enumerate(results, 1):
             print(f"{i}. Score: {result.score:.3f}")
+
+            # Display metadata if flag is set
+            if args.show_metadata and result.metadata:
+                file_path = result.metadata.get("file_path", "")
+                if file_path:
+                    print(f"   📄 File: {file_path}")
+
+                file_name = result.metadata.get("file_name", "")
+                if file_name and file_name != file_path:
+                    print(f"   📝 Name: {file_name}")
+
+                # Show timestamps if available
+                if "creation_date" in result.metadata:
+                    print(f"   🕐 Created: {result.metadata['creation_date']}")
+                if "last_modified_date" in result.metadata:
+                    print(f"   🕑 Modified: {result.metadata['last_modified_date']}")
+
             print(f"   {result.text[:200]}...")
             print(f"   Source: {result.metadata.get('source', '')}")
             print()
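In short, this file switches the in-memory chunk representation from a prefixed string to a dict, so build time no longer re-parses the source out of the text (the old split(":")-based parsing was also fragile for any path containing a colon). A minimal sketch of the new shape, with illustrative values; builder is the LeannBuilder instance from the diff above:

# New chunk shape produced by the chunking loop (values illustrative)
chunk = {
    "text": "def add(a, b): return a + b",
    "metadata": {"file_path": "src/math_utils.py", "file_name": "math_utils.py"},
}

# Consumed directly at build time, no string parsing required
builder.add_text(chunk["text"], metadata=chunk["metadata"])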

packages/leann-core/src/leann/embedding_compute.py

Lines changed: 54 additions & 50 deletions
@@ -574,9 +574,10 @@ def compute_embeddings_ollama(
     host: Optional[str] = None,
 ) -> np.ndarray:
     """
-    Compute embeddings using Ollama API with simplified batch processing.
+    Compute embeddings using Ollama API with true batch processing.
 
-    Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
+    Uses the /api/embed endpoint which supports batch inputs.
+    Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
 
     Args:
         texts: List of texts to compute embeddings for
@@ -681,11 +682,11 @@ def compute_embeddings_ollama(
         logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
         model_name = resolved_model_name
 
-    # Verify the model supports embeddings by testing it
+    # Verify the model supports embeddings by testing it with /api/embed
     try:
         test_response = requests.post(
-            f"{resolved_host}/api/embeddings",
-            json={"model": model_name, "prompt": "test"},
+            f"{resolved_host}/api/embed",
+            json={"model": model_name, "input": "test"},
             timeout=10,
         )
         if test_response.status_code != 200:
@@ -717,56 +718,55 @@ def compute_embeddings_ollama(
         # If torch is not available, use conservative batch size
         batch_size = 32
 
-    logger.info(f"Using batch size: {batch_size}")
+    logger.info(f"Using batch size: {batch_size} for true batch processing")
 
     def get_batch_embeddings(batch_texts):
-        """Get embeddings for a batch of texts."""
-        all_embeddings = []
-        failed_indices = []
+        """Get embeddings for a batch of texts using /api/embed endpoint."""
+        max_retries = 3
+        retry_count = 0
 
-        for i, text in enumerate(batch_texts):
-            max_retries = 3
-            retry_count = 0
+        # Truncate very long texts to avoid API issues
+        truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
 
-            # Truncate very long texts to avoid API issues
-            truncated_text = text[:8000] if len(text) > 8000 else text
-            while retry_count < max_retries:
-                try:
-                    response = requests.post(
-                        f"{resolved_host}/api/embeddings",
-                        json={"model": model_name, "prompt": truncated_text},
-                        timeout=30,
-                    )
-                    response.raise_for_status()
+        while retry_count < max_retries:
+            try:
+                # Use /api/embed endpoint with "input" parameter for batch processing
+                response = requests.post(
+                    f"{resolved_host}/api/embed",
+                    json={"model": model_name, "input": truncated_texts},
+                    timeout=60,  # Increased timeout for batch processing
+                )
+                response.raise_for_status()
+
+                result = response.json()
+                batch_embeddings = result.get("embeddings")
 
-                    result = response.json()
-                    embedding = result.get("embedding")
+                if batch_embeddings is None:
+                    raise ValueError("No embeddings returned from API")
 
-                    if embedding is None:
-                        raise ValueError(f"No embedding returned for text {i}")
+                if not isinstance(batch_embeddings, list):
+                    raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
 
-                    if not isinstance(embedding, list) or len(embedding) == 0:
-                        raise ValueError(f"Invalid embedding format for text {i}")
+                if len(batch_embeddings) != len(batch_texts):
+                    raise ValueError(
+                        f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
+                    )
+
+                return batch_embeddings, []
 
-                    all_embeddings.append(embedding)
-                    break
+            except requests.exceptions.Timeout:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.warning(f"Timeout for batch after {max_retries} retries")
+                    return None, list(range(len(batch_texts)))
 
-                except requests.exceptions.Timeout:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.warning(f"Timeout for text {i} after {max_retries} retries")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
+            except Exception as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    logger.error(f"Failed to get embeddings for batch: {e}")
+                    return None, list(range(len(batch_texts)))
 
-                except Exception as e:
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        logger.error(f"Failed to get embedding for text {i}: {e}")
-                        failed_indices.append(i)
-                        all_embeddings.append(None)
-                        break
-        return all_embeddings, failed_indices
+        return None, list(range(len(batch_texts)))
 
     # Process texts in batches
     all_embeddings = []
@@ -784,7 +784,7 @@ def get_batch_embeddings(batch_texts):
     num_batches = (len(texts) + batch_size - 1) // batch_size
 
     if show_progress:
-        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
+        batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
     else:
         batch_iterator = range(num_batches)
 
@@ -795,10 +795,14 @@ def get_batch_embeddings(batch_texts):
 
         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
 
-        # Adjust failed indices to global indices
-        global_failed = [start_idx + idx for idx in batch_failed]
-        all_failed_indices.extend(global_failed)
-        all_embeddings.extend(batch_embeddings)
+        if batch_embeddings is not None:
+            all_embeddings.extend(batch_embeddings)
+        else:
+            # Entire batch failed, add None placeholders
+            all_embeddings.extend([None] * len(batch_texts))
+            # Adjust failed indices to global indices
+            global_failed = [start_idx + idx for idx in batch_failed]
+            all_failed_indices.extend(global_failed)
 
     # Handle failed embeddings
     if all_failed_indices:
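A quick sanity check of the batching arithmetic above, using the ceil-division formula from the diff and an illustrative corpus size:

texts = ["some chunk"] * 1000   # illustrative corpus of 1,000 chunks
batch_size = 32                 # MPS/CPU default from the diff (128 on CUDA)

num_batches = (len(texts) + batch_size - 1) // batch_size  # ceil division
print(num_batches)              # 32 -> 32 HTTP requests instead of 1,000 per-text calls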

packages/leann-core/src/leann/mcp.py

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,11 @@ def handle_request(request):
                         "maximum": 128,
                         "description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
                     },
+                    "show_metadata": {
+                        "type": "boolean",
+                        "default": False,
+                        "description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
+                    },
                 },
                 "required": ["index_name", "query"],
             },
@@ -104,6 +109,8 @@ def handle_request(request):
             f"--complexity={args.get('complexity', 32)}",
             "--non-interactive",
         ]
+        if args.get("show_metadata", False):
+            cmd.append("--show-metadata")
         result = subprocess.run(cmd, capture_output=True, text=True)
 
     elif tool_name == "leann_list":
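To illustrate how the new schema field reaches the CLI, a minimal sketch of the flag handling above; the argument values are illustrative, and the earlier elements of cmd are elided because this diff does not show them:

# Hypothetical tool-call arguments for leann_search; "show_metadata" is the new optional field
args = {
    "index_name": "my-docs",                       # illustrative index name
    "query": "where is chunk metadata stored?",    # illustrative query
    "show_metadata": True,
}

cmd = [
    # ... earlier command elements not shown in this diff ...
    f"--complexity={args.get('complexity', 32)}",
    "--non-interactive",
]
# Mirrors the handler change above: the flag is only appended when requested
if args.get("show_metadata", False):
    cmd.append("--show-metadata")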
