Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,23 @@ def create_diskann_embedding_server(
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
enable_warmup: bool = True,
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.

Args:
passages_file: Path to the metadata file (.meta.json)
zmq_port: Port for ZMQ server
model_name: Name of the embedding model to use
embedding_mode: Embedding backend mode
distance_metric: Distance metric (l2, mips, cosine)
enable_warmup: If True, pre-load model and run warmup embedding on startup
"""
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
logger.info(f"Warmup enabled: {enable_warmup}")

# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
Expand All @@ -72,6 +82,24 @@ def create_diskann_embedding_server(
finally:
sys.path.pop(0)

# Warmup: Pre-load the embedding model by computing a dummy embedding
# This ensures the model is cached and ready for fast subsequent queries
if enable_warmup:
warmup_start = time.time()
logger.info("Starting model warmup...")
try:
# Compute a dummy embedding to trigger model loading and caching
_ = compute_embeddings(
["warmup query for model preloading"],
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
warmup_time = time.time() - warmup_start
logger.info(f"Model warmup completed in {warmup_time:.2f}s")
except Exception as e:
logger.warning(f"Model warmup failed (non-fatal): {e}")

# Check port availability
import socket

Expand Down Expand Up @@ -479,14 +507,29 @@ def signal_handler(sig, frame):
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
parser.add_argument(
"--enable-warmup",
action="store_true",
default=True,
help="Pre-load embedding model on startup for faster first query (default: True)",
)
parser.add_argument(
"--no-warmup",
action="store_true",
help="Disable warmup (lazy model loading)",
)

args = parser.parse_args()

# Determine warmup setting (--no-warmup takes precedence)
enable_warmup = not args.no_warmup

# Create and start the DiskANN embedding server
create_diskann_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
enable_warmup=enable_warmup,
)
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,23 @@ def create_hnsw_embedding_server(
model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips",
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = True,
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module.

Args:
passages_file: Path to the metadata file (.meta.json)
zmq_port: Port for ZMQ server
model_name: Name of the embedding model to use
distance_metric: Distance metric (mips, l2, cosine)
embedding_mode: Embedding backend mode
enable_warmup: If True, pre-load model and run warmup embedding on startup
"""
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
logger.info(f"Warmup enabled: {enable_warmup}")

# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
Expand All @@ -85,6 +95,24 @@ def create_hnsw_embedding_server(
finally:
sys.path.pop(0)

# Warmup: Pre-load the embedding model by computing a dummy embedding
# This ensures the model is cached and ready for fast subsequent queries
if enable_warmup:
warmup_start = time.time()
logger.info("Starting model warmup...")
try:
# Compute a dummy embedding to trigger model loading and caching
_ = compute_embeddings(
["warmup query for model preloading"],
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
warmup_time = time.time() - warmup_start
logger.info(f"Model warmup completed in {warmup_time:.2f}s")
except Exception as e:
logger.warning(f"Model warmup failed (non-fatal): {e}")

# Check port availability
import socket

Expand Down Expand Up @@ -492,14 +520,29 @@ def signal_handler(sig, frame):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--enable-warmup",
action="store_true",
default=True,
help="Pre-load embedding model on startup for faster first query (default: True)",
)
parser.add_argument(
"--no-warmup",
action="store_true",
help="Disable warmup (lazy model loading)",
)

args = parser.parse_args()

# Determine warmup setting (--no-warmup takes precedence)
enable_warmup = not args.no_warmup

# Create and start the HNSW embedding server
create_hnsw_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
distance_metric=args.distance_metric,
embedding_mode=args.embedding_mode,
enable_warmup=enable_warmup,
)
75 changes: 57 additions & 18 deletions packages/leann-core/src/leann/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,25 +995,10 @@ def __init__(
)
self.bm25_scorer: Optional[BM25Scorer] = None

# Optional one-shot warmup at construction time to hide cold-start latency.
# Auto-warmup if requested - this pre-loads the embedding model
# to avoid cold-start latency on the first search
if self._warmup:
try:
_ = self.backend_impl.compute_query_embedding(
"__LEANN_WARMUP__",
use_server_if_available=self.recompute_embeddings,
)
except Exception as exc:
logger.warning(f"Warmup embedding failed (ignored): {exc}")

# Optional one-shot warmup at construction time to hide cold-start latency.
if self._warmup:
try:
_ = self.backend_impl.compute_query_embedding(
"__LEANN_WARMUP__",
use_server_if_available=self.recompute_embeddings,
)
except Exception as exc:
logger.warning(f"Warmup embedding failed (ignored): {exc}")
self.warmup()

def search(
self,
Expand Down Expand Up @@ -1376,6 +1361,60 @@ def __del__(self):
pass


def warmup(self, port: int = 5557) -> float:
    """Pre-warm the embedding server and model for faster subsequent searches.

    Asks the backend to make sure its embedding server is running with
    warmup enabled, then issues a single dummy query embedding. Call this
    before your first search to avoid cold-start latency. Failures are
    logged and never raised, so a failed warmup does not prevent searching.

    Args:
        port: ZMQ port for the embedding server (default: 5557)

    Returns:
        Time taken for warmup in seconds, or 0.0 if the embedding server
        could not be brought up. A failure of the dummy query alone is
        only logged and the elapsed time is still returned.

    Example:
        >>> searcher = LeannSearcher("path/to/index.leann")
        >>> warmup_time = searcher.warmup()
        >>> print(f"Warmup completed in {warmup_time:.2f}s")
        >>> # Subsequent searches will be faster
        >>> results = searcher.search("my query")
    """
    import time

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-warmup.
    start_time = time.perf_counter()
    logger.info("Starting LeannSearcher warmup...")

    try:
        # Ask the backend to bring the server up with warmup enabled.
        # NOTE(review): the value returned here (presumably the port actually
        # in use) is discarded and the dummy query below is not told about
        # it - confirm whether compute_query_embedding should receive it.
        self.backend_impl._ensure_server_running(
            self.meta_path_str,
            port=port,
            enable_warmup=True,
        )
    except Exception as e:
        logger.error(f"Warmup failed: {e}")
        # Don't raise, we want to allow search even if warmup fails
        return 0.0

    try:
        # Dummy query so the regular query-embedding path is exercised once
        # before the first real search.
        _ = self.backend_impl.compute_query_embedding(
            "__LEANN_WARMUP__",
            use_server_if_available=self.recompute_embeddings,
        )
    except Exception as exc:
        logger.warning(f"Warmup embedding failed during dummy query (soft fail): {exc}")

    elapsed = time.perf_counter() - start_time
    logger.info(f"Warmup completed in {elapsed:.4f}s")
    return elapsed


class LeannChat:
def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions packages/leann-core/src/leann/embedding_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def _build_server_command(
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
# Control warmup behavior - default is enabled, use --no-warmup to disable
if not kwargs.get("enable_warmup", True):
command.append("--no-warmup")

return command

Expand Down
Loading
Loading