diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py index 0b905457a50..13044623cca 100644 --- a/python/python/lance/__init__.py +++ b/python/python/lance/__init__.py @@ -33,6 +33,17 @@ bytes_read_counter, iops_counter, ) +from .mem_wal import ( + ExecutionPlan, + LsmPointLookupPlanner, + LsmScanner, + LsmVectorSearchPlanner, + MergedGeneration, + RegionField, + RegionSnapshot, + RegionSpec, + RegionWriter, +) from .namespace import ( DescribeTableRequest, LanceNamespace, @@ -83,6 +94,15 @@ "set_logger", "write_dataset", "FFILanceTableProvider", + "ExecutionPlan", + "LsmPointLookupPlanner", + "LsmScanner", + "LsmVectorSearchPlanner", + "MergedGeneration", + "RegionField", + "RegionSpec", + "RegionSnapshot", + "RegionWriter", ] diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index c73a7b5a724..7c90c77df9b 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -77,6 +77,7 @@ from lance.namespace import LanceNamespace + from . import mem_wal from .commit import CommitLock from .io import StorageOptionsProvider from .lance.indices import IndexDescription @@ -411,6 +412,25 @@ def analyze_plan( reader = _coerce_reader(data_obj, schema) return super(MergeInsertBuilder, self).analyze_plan(reader) + def mark_generations_as_merged( + self, generations: "List[mem_wal.MergedGeneration]" + ) -> "MergeInsertBuilder": + """Mark MemWAL generations as merged into the base table. + + Call this before executing the merge_insert when the source data + includes rows from MemWAL flushed generations. + + Parameters + ---------- + generations : list of MergedGeneration + Generations to mark as merged. 
+ """ + from .mem_wal import _to_raw_merged_generations + + raw_gens = _to_raw_merged_generations(generations) + super(MergeInsertBuilder, self).mark_generations_as_merged(raw_gens) + return self + class LanceDataset(pa.dataset.Dataset): """A Lance Dataset in Lance format where the data is stored at the given uri.""" @@ -4021,6 +4041,187 @@ def centroids( return ivf.centroids + def initialize_mem_wal( + self, + *, + maintained_indexes: Optional[List[str]] = None, + region_spec: Optional["mem_wal.RegionSpec"] = None, + ) -> None: + """Initialize MemWAL on this dataset. + + Must be called once before any calls to `mem_wal_writer`. + The dataset schema must have at least one field annotated with + the ``lance-schema:unenforced-primary-key`` Arrow field metadata. + + Parameters + ---------- + maintained_indexes : list of str, optional + Names of existing vector indexes to keep updated as data is + written through the MemWAL. Must reference indexes that + already exist on the dataset. + region_spec : RegionSpec, optional + Partitioning specification for automatic region routing. + When provided, Lance will derive a region identifier from each + written row according to the spec and route writes to the + correct `~lance.mem_wal.RegionWriter` automatically. + When ``None`` (default), the caller must manage region IDs + manually by passing them to `mem_wal_writer`. + + Raises + ------ + IOError + - Dataset has no ``lance-schema:unenforced-primary-key`` field. + - An entry in *maintained_indexes* does not exist on the dataset. + - MemWAL has already been initialized on this dataset. + + Examples + -------- + Without region spec (manual region management): + + import lance + import pyarrow as pa + import tempfile + schema = pa.schema([ + ... pa.field("id", pa.int64(), nullable=False, + ... metadata={"lance-schema:unenforced-primary-key": "true"}), + ... pa.field("val", pa.float32()), + ... 
]) + table = pa.table({"id": [1], "val": [0.1]}, schema=schema) + with tempfile.TemporaryDirectory() as tmpdir: + ds = lance.write_dataset(table, tmpdir) + ds.initialize_mem_wal() + + With a region spec for automatic routing by ``tenant_id``: + + from lance.mem_wal import RegionField, RegionSpec + spec = RegionSpec( + ... spec_id=1, + ... fields=[RegionField(field_id="tenant_id", source_ids=[0], + ... result_type="int64")], + ... ) + ds.initialize_mem_wal(region_spec=spec) + """ + self._ds.initialize_mem_wal( + maintained_indexes=maintained_indexes, + region_spec=region_spec, + ) + + def mem_wal_writer( + self, + region_id: str, + *, + durable_write: Optional[bool] = None, + sync_indexed_write: Optional[bool] = None, + max_wal_buffer_size: Optional[int] = None, + max_wal_flush_interval_ms: Optional[int] = None, + max_memtable_size: Optional[int] = None, + max_memtable_rows: Optional[int] = None, + max_memtable_batches: Optional[int] = None, + max_unflushed_memtable_bytes: Optional[int] = None, + ivf_index_partition_capacity_safety_factor: Optional[int] = None, + manifest_scan_batch_size: Optional[int] = None, + async_index_buffer_rows: Optional[int] = None, + async_index_interval_ms: Optional[int] = None, + backpressure_log_interval_ms: Optional[int] = None, + stats_log_interval_ms: Optional[int] = None, + ) -> "mem_wal.RegionWriter": + """Get a RegionWriter for the specified region. + + `initialize_mem_wal` must be called before using this method. + Each *region* is an independent write shard; use different region IDs + to achieve parallel ingestion without writer contention. + + Parameters + ---------- + region_id : str + UUID string identifying the write region (e.g. + ``str(uuid.uuid4())``). + durable_write : bool, optional + Whether to fsync WAL writes (default: ``True``). + sync_indexed_write : bool, optional + Whether index updates are synchronous (default: ``True``). + max_wal_buffer_size : int, optional + Maximum WAL buffer size in bytes (default: 10 MB). 
+ max_wal_flush_interval_ms : int, optional + Maximum WAL flush interval in milliseconds (default: 100). + max_memtable_size : int, optional + Maximum MemTable size in bytes (default: 256 MB). + max_memtable_rows : int, optional + Maximum rows per MemTable (default: 100 000). + max_memtable_batches : int, optional + Maximum batches per MemTable (default: 8 000). + max_unflushed_memtable_bytes : int, optional + Maximum unflushed bytes before backpressure (default: 1 GB). + ivf_index_partition_capacity_safety_factor : int, optional + Safety factor for IVF partition capacity (default: 8). + manifest_scan_batch_size : int, optional + Batch size for manifest scans (default: 2). + async_index_buffer_rows : int, optional + Buffer rows for async index updates (default: 10 000). + async_index_interval_ms : int, optional + Interval for async index updates in milliseconds (default: 1000). + backpressure_log_interval_ms : int, optional + Interval for backpressure log messages in milliseconds + (default: 30 000). + stats_log_interval_ms : int, optional + Interval for statistics log messages in milliseconds + (default: 60 000). Pass ``0`` to disable. + + Returns + ------- + RegionWriter + A context-manager-compatible writer for the specified region. + + Examples + -------- + >>> import lance + >>> import pyarrow as pa + >>> import tempfile + >>> import uuid + >>> schema = pa.schema([ + ... pa.field("id", pa.int64(), nullable=False, + ... metadata={"lance-schema:unenforced-primary-key": "true"}), + ... pa.field("val", pa.float32()), + ... ]) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... ds = lance.write_dataset( + ... pa.table({"id": [1], "val": [0.1]}, schema=schema), + ... tmpdir, + ... ) + ... ds.initialize_mem_wal() + ... region_id = str(uuid.uuid4()) + ... new_data = pa.table({"id": [2], "val": [0.2]}, schema=schema) + ... with ds.mem_wal_writer(region_id) as writer: + ... 
writer.put(new_data) + """ + import lance.mem_wal as _mw + + kwargs = { + name: val + for name, val in [ + ("durable_write", durable_write), + ("sync_indexed_write", sync_indexed_write), + ("max_wal_buffer_size", max_wal_buffer_size), + ("max_wal_flush_interval_ms", max_wal_flush_interval_ms), + ("max_memtable_size", max_memtable_size), + ("max_memtable_rows", max_memtable_rows), + ("max_memtable_batches", max_memtable_batches), + ("max_unflushed_memtable_bytes", max_unflushed_memtable_bytes), + ( + "ivf_index_partition_capacity_safety_factor", + ivf_index_partition_capacity_safety_factor, + ), + ("manifest_scan_batch_size", manifest_scan_batch_size), + ("async_index_buffer_rows", async_index_buffer_rows), + ("async_index_interval_ms", async_index_interval_ms), + ("backpressure_log_interval_ms", backpressure_log_interval_ms), + ("stats_log_interval_ms", stats_log_interval_ms), + ] + if val is not None + } + raw = self._ds.mem_wal_writer(region_id, **kwargs) + return _mw.RegionWriter(raw) + class SqlQuery: """ diff --git a/python/python/lance/mem_wal.py b/python/python/lance/mem_wal.py new file mode 100644 index 00000000000..87609435e53 --- /dev/null +++ b/python/python/lance/mem_wal.py @@ -0,0 +1,532 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +""" +Python wrappers for Lance MemWAL functionality. + +The MemWAL feature enables high-throughput, low-latency writes to a Lance +dataset via an LSM-tree structure. Data flows through four levels: + +1. **WAL** – append-only durable log (raw writes) +2. **Active MemTable** – in-memory write buffer +3. **Flushed MemTable** – Lance files written to object store +4. 
**Base table** – canonical Lance dataset files (after merge_insert) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional + +import pyarrow as pa + +from .lance import ( + _ExecutionPlan, + _LsmPointLookupPlanner, + _LsmScanner, + _LsmVectorSearchPlanner, + _MergedGeneration, + _RegionSnapshot, + _RegionWriter, +) +from .types import _coerce_reader + +if TYPE_CHECKING: + import lance + +__all__ = [ + "RegionField", + "RegionSpec", + "MergedGeneration", + "RegionSnapshot", + "RegionWriter", + "LsmScanner", + "ExecutionPlan", + "LsmPointLookupPlanner", + "LsmVectorSearchPlanner", +] + + +# --------------------------------------------------------------------------- +# RegionSpec +# --------------------------------------------------------------------------- + + +@dataclass +class RegionField: + """Defines one derived field used in region partitioning. + + Parameters + ---------- + field_id : str + Identifier for the derived field. + source_ids : list of int + Source field IDs used as inputs. + transform : str, optional + Optional transform name applied to the source fields. + expression : str, optional + Optional expression used to derive the field value. + result_type : str + Output type name for the derived field. + parameters : dict of str to str, optional + Extra transform parameters. + """ + + field_id: str + source_ids: List[int] + transform: Optional[str] = None + expression: Optional[str] = None + result_type: str = "" + parameters: Dict[str, str] = field(default_factory=dict) + + +@dataclass +class RegionSpec: + """Partitioning specification for deriving MemWAL region IDs.""" + + spec_id: int + fields: List[RegionField] + + +@dataclass +class MergedGeneration: + """Identifies a flushed MemWAL generation that has been merged. + + Pass a list of these to mark_generations_as_merged + so Lance knows which generations are now in the base table. 
+ + Parameters + ---------- + region_id : str + UUID string for the write region. + generation : int + Generation number (from + :attr:`RegionSnapshot.flushed_generations`). + """ + + region_id: str + generation: int + + +class RegionSnapshot: + """Snapshot of a MemWAL region's state, used when constructing scanners. + + Parameters + ---------- + region_id : str + UUID string for the write region. + """ + + def __init__(self, region_id: str) -> None: + self._raw = _RegionSnapshot(region_id) + + @property + def region_id(self) -> str: + """UUID string for this region.""" + return self._raw.region_id + + def with_spec_id(self, spec_id: int) -> "RegionSnapshot": + """Set the RegionSpec ID.""" + self._raw = self._raw.with_spec_id(spec_id) + return self + + def with_current_generation(self, generation: int) -> "RegionSnapshot": + """Set the current (active) generation number.""" + self._raw = self._raw.with_current_generation(generation) + return self + + def with_flushed_generation(self, generation: int, path: str) -> "RegionSnapshot": + """Add a flushed generation with its storage path.""" + self._raw = self._raw.with_flushed_generation(generation, path) + return self + + def __repr__(self) -> str: + return repr(self._raw) + + +class RegionWriter: + """Stateful writer for one MemWAL region. + + Obtain an instance via mem_wal_writer. + Use as a context manager so the writer is closed automatically:: + + with dataset.mem_wal_writer(region_id) as writer: + writer.put(batch) + + Parameters + ---------- + _raw : _RegionWriter + Internal PyO3 object — do not construct directly. + """ + + def __init__(self, _raw: _RegionWriter) -> None: + self._raw = _raw + + @property + def region_id(self) -> str: + """UUID string for this writer's region.""" + return self._raw.region_id + + def put(self, data, *, schema: Optional[pa.Schema] = None) -> None: + """Write data to the MemWAL. 
+ + Parameters + ---------- + data : ReaderLike + Any Arrow-compatible data `pyarrow.Table`, + `pyarrow.RecordBatch`, ``RecordBatchReader``, pandas + DataFrame, etc. + schema : pa.Schema, optional + Schema hint, needed when *data* is a generator. + + Raises + ------ + IOError + On WAL flush failure. + RuntimeError + If the writer has already been closed. + """ + reader = _coerce_reader(data, schema) + self._raw.put(reader) + + def close(self) -> None: + """Flush and close the writer. + + After ``close()``, calling :meth:`put` raises an error. + Automatically called when used as a context manager. + """ + self._raw.close() + + def stats(self) -> dict: + """Return a snapshot of write statistics. + + Returns + ------- + dict + Keys: ``put_count``, ``put_time_ms``, ``wal_flush_count``, + ``wal_flush_bytes``, ``wal_flush_time_ms``, + ``memtable_flush_count``, ``memtable_flush_rows``, + ``memtable_flush_time_ms``. + """ + return self._raw.stats() + + def memtable_stats(self) -> dict: + """Return current MemTable statistics. + + Returns + ------- + dict + Keys: ``row_count``, ``batch_count``, ``estimated_size_bytes``, + ``generation``. + """ + return self._raw.memtable_stats() + + def lsm_scanner( + self, region_snapshots: Optional[List[RegionSnapshot]] = None + ) -> "LsmScanner": + """Create an LSM scanner that includes the active MemTable. + + This scanner covers the base table, the given flushed generations, + and the current active MemTable — providing strong read-your-writes + consistency. + + Parameters + ---------- + region_snapshots : list of RegionSnapshot, optional + Snapshots of other regions to include. This writer's own region + is automatically included. 
+ + Returns + ------- + LsmScanner + """ + raw_snaps = [s._raw for s in (region_snapshots or [])] + return LsmScanner(self._raw.lsm_scanner(raw_snaps)) + + def __enter__(self) -> "RegionWriter": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.close() + return False + + +class LsmScanner: + """LSM-aware scanner covering all data levels. + + Deduplicates by primary key, always returning the newest version of + each row across base table, flushed MemTables, and the active MemTable. + + Obtain an instance from `RegionWriter.lsm_scanner` (includes + active MemTable) or `LsmScanner.from_snapshots` (flushed only). + + The builder methods (`project`, `filter`, `limit`) + return ``self`` for chaining. + + Examples + -------- + scanner = LsmScanner.from_snapshots(dataset, [snapshot]) + scanner.project(["id", "value"]).filter("value > 0.5") + table = scanner.to_table() + """ + + def __init__(self, _raw: _LsmScanner) -> None: + self._raw = _raw + + @staticmethod + def from_snapshots( + dataset: "lance.LanceDataset", + region_snapshots: List[RegionSnapshot], + ) -> "LsmScanner": + """Create a scanner from dataset and region snapshots. + + Does **not** include the active MemTable; use + `RegionWriter.lsm_scanner` for that. + + Parameters + ---------- + dataset : LanceDataset + The base dataset to scan. + region_snapshots : list of RegionSnapshot + Region snapshots specifying flushed generations to include. + """ + raw = _LsmScanner.from_snapshots( + dataset._ds, [s._raw for s in region_snapshots] + ) + return LsmScanner(raw) + + def project(self, columns: List[str]) -> "LsmScanner": + """Select specific columns to return.""" + self._raw = self._raw.project(columns) + return self + + def filter(self, expr: str) -> "LsmScanner": + """Set a SQL filter expression (e.g. 
``"value > 0.5"``).""" + self._raw = self._raw.filter(expr) + return self + + def limit(self, n: int, offset: Optional[int] = None) -> "LsmScanner": + """Limit rows returned, optionally with an offset.""" + self._raw = self._raw.limit(n, offset) + return self + + def with_row_address(self) -> "LsmScanner": + """Include the ``_rowaddr`` internal column in results.""" + self._raw = self._raw.with_row_address() + return self + + def with_memtable_gen(self) -> "LsmScanner": + """Include the ``_memtable_gen`` internal column in results.""" + self._raw = self._raw.with_memtable_gen() + return self + + def to_batch(self) -> pa.RecordBatch: + """Execute the scan and return a single merged :class:`~pyarrow.RecordBatch`.""" + return self._raw.to_batch() + + def to_batches(self) -> List[pa.RecordBatch]: + """Execute the scan and return a list of :class:`~pyarrow.RecordBatch`.""" + return list(self._raw.to_batches()) + + def to_table(self) -> pa.Table: + """Execute the scan and return a :class:`~pyarrow.Table`.""" + batch = self.to_batch() + return pa.Table.from_batches([batch]) + + def count_rows(self) -> int: + """Return the row count without loading all column data.""" + return self._raw.count_rows() + + +class ExecutionPlan: + """Executable physical plan returned by MemWAL planners. + + This wraps the Rust/DataFusion physical plan object. Planner classes only + construct plans; execution happens through this class. 
+ + Parameters + ---------- + _raw : _ExecutionPlan + """ + + def __init__(self, _raw: _ExecutionPlan) -> None: + self._raw = _raw + + @property + def schema(self) -> pa.Schema: + """Output schema of this physical plan.""" + return self._raw.schema + + @property + def dataset_schema(self) -> pa.Schema: + """Base dataset schema used to construct this plan.""" + return self._raw.dataset_schema + + def explain(self) -> str: + """Return the physical plan as an indented string.""" + return self._raw.explain() + + def to_reader(self) -> pa.RecordBatchReader: + """Execute the plan and return a streaming reader.""" + return self._raw.to_reader() + + def to_batches(self) -> List[pa.RecordBatch]: + """Execute the plan and return all record batches.""" + return list(self._raw.to_batches()) + + def to_table(self) -> pa.Table: + """Execute the plan and return a table.""" + return self.to_reader().read_all() + + +class LsmPointLookupPlanner: + """Plans primary-key point lookups across all LSM levels. + + More efficient than `LsmScanner` for known-PK lookups due to + bloom filter optimizations and short-circuit evaluation. + + Parameters + ---------- + dataset : LanceDataset + The base dataset. + region_snapshots : list of RegionSnapshot + Region snapshots specifying flushed generations to include. + pk_columns : list of str, optional + Primary key column names. Inferred from schema metadata if omitted. 
+ + Examples + -------- + planner = LsmPointLookupPlanner(dataset, [snapshot]) + plan = planner.plan_lookup(pa.array([42], type=pa.int64())) + result = plan.to_table() + """ + + def __init__( + self, + dataset: "lance.LanceDataset", + region_snapshots: List[RegionSnapshot], + pk_columns: Optional[List[str]] = None, + ) -> None: + self._raw = _LsmPointLookupPlanner( + dataset._ds, + [s._raw for s in region_snapshots], + pk_columns, + ) + + def plan_lookup( + self, + pk_value: pa.Array, + columns: Optional[List[str]] = None, + ) -> ExecutionPlan: + """Plan a point lookup by primary key value. + + Parameters + ---------- + pk_value : pa.Array + For single-column primary keys, a PyArrow array with exactly one + element. For composite primary keys, a single-row + `pyarrow.StructArray` with one field per primary-key column. + columns : list of str, optional + Columns to project. Returns all columns if omitted. + + Returns + ------- + ExecutionPlan + Physical plan for the lookup. Execute it via `to_table`, + `to_reader`, or `to_batches`. + """ + return ExecutionPlan(self._raw.plan_lookup(pk_value, columns)) + + +class LsmVectorSearchPlanner: + """Plans IVF-PQ vector KNN search across all LSM levels. + + Results include staleness filtering to return only the latest version + of each row. The output schema includes the ``_distance`` column. + + Parameters + ---------- + dataset : LanceDataset + The base dataset. + region_snapshots : list of RegionSnapshot + Region snapshots specifying flushed generations to include. + vector_column : str + Name of the ``FixedSizeList`` vector column. + pk_columns : list of str, optional + Primary key columns. Inferred from schema metadata if omitted. + distance_type : str, optional + Distance metric — one of ``"l2"`` (default), ``"cosine"``, + ``"dot"``, ``"hamming"``. 
+ + Examples + -------- + import numpy as np + planner = LsmVectorSearchPlanner(dataset, [snapshot], "vector") + query = pa.array(np.random.rand(128).astype("float32")) + plan = planner.plan_search(query, k=10) + result = plan.to_table() + """ + + def __init__( + self, + dataset: "lance.LanceDataset", + region_snapshots: List[RegionSnapshot], + vector_column: str, + pk_columns: Optional[List[str]] = None, + distance_type: Optional[str] = None, + ) -> None: + kwargs = {} + if pk_columns is not None: + kwargs["pk_columns"] = pk_columns + if distance_type is not None: + kwargs["distance_type"] = distance_type + self._raw = _LsmVectorSearchPlanner( + dataset._ds, + [s._raw for s in region_snapshots], + vector_column, + **kwargs, + ) + + def plan_search( + self, + query: pa.Array, + k: int = 10, + nprobes: int = 20, + columns: Optional[List[str]] = None, + ) -> ExecutionPlan: + """Plan a KNN vector search. + + Parameters + ---------- + query : pa.Array + A flat ``Float32Array`` of length ``vector_dim``. + k : int, optional + Number of nearest neighbours to return (default: 10). + nprobes : int, optional + Number of IVF partitions to probe (default: 20). + columns : list of str, optional + Columns to project. Returns all columns + ``_distance`` if + omitted. + + Returns + ------- + ExecutionPlan + Physical plan for the vector search. Execute it via + `to_table`, `to_reader`, or `to_batches`. 
+ """ + return ExecutionPlan(self._raw.plan_search(query, k, nprobes, columns)) + + +def _unwrap_region_id(region_id: str) -> str: + """Validate region_id is a UUID string.""" + import uuid as _uuid + + _uuid.UUID(region_id) # raises ValueError if invalid + return region_id + + +def _to_raw_merged_generations( + generations: Iterable[MergedGeneration], +) -> list: + """Convert Python MergedGeneration list to PyO3 _MergedGeneration list.""" + return [_MergedGeneration(g.region_id, g.generation) for g in generations] diff --git a/python/python/tests/test_mem_wal.py b/python/python/tests/test_mem_wal.py new file mode 100644 index 00000000000..e7a6f51b67a --- /dev/null +++ b/python/python/tests/test_mem_wal.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors +import math +import os +import time +import uuid + +import lance +import pyarrow as pa +from lance.mem_wal import ( + LsmPointLookupPlanner, + LsmScanner, + RegionSnapshot, +) + +_PK_META = {"lance-schema:unenforced-primary-key": "true"} +_LOOKUP_SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False, metadata=_PK_META), + pa.field("name", pa.utf8()), + ] +) + + +def _lookup_table(ids, prefix: str) -> pa.Table: + """Build a table where name = '{prefix}_{id}' for each id.""" + return pa.table( + { + "id": pa.array(ids, pa.int64()), + "name": pa.array([f"{prefix}_{i}" for i in ids], pa.utf8()), + }, + schema=_LOOKUP_SCHEMA, + ) + + +def _write_flushed_gen(base_path: str, region_id: str, gen_folder: str, data: pa.Table): + """Write a flushed-generation Lance dataset at the expected sub-path. 
+ + The collector resolves flushed generation paths as: + {base_dataset_path}/_mem_wal/{region_id}/{gen_folder} + """ + gen_path = os.path.join(base_path, "_mem_wal", region_id, gen_folder) + lance.write_dataset(data, gen_path, schema=_LOOKUP_SCHEMA) + + +def test_point_lookup_with_memtables(tmp_path): + """ + Lookup against a base table that has one flushed generation containing an + update. The flushed version must win over the base table version. + + Setup + ----- + base : ids [1, 2, 3] names ["base_1", "base_2", "base_3"] + gen_1 : ids [2] names ["gen1_2"] ← update to id=2 + + RegionSnapshot: flushed_generation(gen=1, path="gen_1"), current_generation=2 + """ + ds_path = str(tmp_path / "base") + region_id = str(uuid.uuid4()) + + # --- Base dataset --- + base_ds = lance.write_dataset( + _lookup_table([1, 2, 3], "base"), ds_path, schema=_LOOKUP_SCHEMA + ) + base_ds.initialize_mem_wal() + + # --- Flushed generation: overwrites id=2 --- + _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + + # --- RegionSnapshot describing the flushed state --- + snap = ( + RegionSnapshot(region_id) + .with_flushed_generation(1, "gen_1") + .with_current_generation(2) + ) + + planner = LsmPointLookupPlanner(base_ds, [snap]) + assert not hasattr(planner, "lookup") + + # id=2 must return the flushed version + plan = planner.plan_lookup(pa.array([2], type=pa.int64())) + assert plan.schema.names == ["id", "name"] + assert plan.dataset_schema.names == ["id", "name"] + assert "Take" in plan.explain() or "Scan" in plan.explain() + result = plan.to_table() + assert len(result) == 1, "Expected exactly one row for id=2" + assert result.column("name")[0].as_py() == "gen1_2", ( + "Flushed generation must win over base table" + ) + + # id=1 is only in the base table + result_base = planner.plan_lookup(pa.array([1], type=pa.int64())).to_table() + assert len(result_base) == 1 + assert result_base.column("name")[0].as_py() == "base_1" + + # id=99 does not exist anywhere + 
result_miss = planner.plan_lookup(pa.array([99], type=pa.int64())).to_table() + assert len(result_miss) == 0, "Non-existent key must return empty result" + + +def test_lsm_scanner_with_memtables(tmp_path): + """ + Full-scan via LsmScanner.from_snapshots deduplicates rows across LSM + levels: each primary key appears exactly once, from its newest level. + + base : ids [1, 2, 3] names ["base_1", "base_2", "base_3"] + gen_1 : ids [2] names ["gen1_2"] ← overwrites id=2 + + Expected result: 3 unique rows — id=2 from gen_1, id=1 and id=3 from base. + """ + ds_path = str(tmp_path / "base") + region_id = str(uuid.uuid4()) + + base_ds = lance.write_dataset( + _lookup_table([1, 2, 3], "base"), ds_path, schema=_LOOKUP_SCHEMA + ) + base_ds.initialize_mem_wal() + + _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + + snap = ( + RegionSnapshot(region_id) + .with_flushed_generation(1, "gen_1") + .with_current_generation(2) + ) + + scanner = LsmScanner.from_snapshots(base_ds, [snap]) + table = scanner.to_table() + + assert len(table) == 3, f"Expected 3 deduplicated rows, got {len(table)}" + name_by_id = {row["id"]: row["name"] for row in table.to_pylist()} + + assert name_by_id[1] == "base_1" + assert name_by_id[2] == "gen1_2", "Flushed gen must overwrite base for id=2" + assert name_by_id[3] == "base_3" + + +_VDIM = 4 # matches Rust test fixture dimension + + +def _vector_search_schema(): + """Schema for vector-search tests: id (int32 PK) + vector column.""" + pk_meta = {"lance-schema:unenforced-primary-key": "true"} + return pa.schema( + [ + pa.field("id", pa.int32(), nullable=False, metadata=pk_meta), + pa.field("vector", pa.list_(pa.float32(), _VDIM)), + ] + ) + + +def _vector_search_table(ids): + """Build a table with deterministic vectors matching the Rust fixture. + + For id=N the vector is [N*0.1, N*0.1+0.1, N*0.1+0.2, N*0.1+0.3]. 
+ """ + flat = [] + for i in ids: + base = i * 0.1 + flat.extend([base, base + 0.1, base + 0.2, base + 0.3]) + storage = pa.array(flat, type=pa.float32()) + vectors = pa.FixedSizeListArray.from_arrays(storage, _VDIM) + return pa.table( + {"id": pa.array(ids, pa.int32()), "vector": vectors}, + schema=_vector_search_schema(), + ) + + +VECTOR_DIM = 32 +ROWS_PER_BATCH = 50 +NUM_WRITE_ROUNDS = 3 +BATCHES_PER_ROUND = 3 + + +def _e2e_schema(): + """Schema for the e2e test: id (PK), vector, text.""" + pk_meta = {"lance-schema:unenforced-primary-key": "true"} + return pa.schema( + [ + pa.field("id", pa.int64(), nullable=False, metadata=pk_meta), + pa.field("vector", pa.list_(pa.float32(), VECTOR_DIM)), + pa.field("text", pa.utf8()), + ] + ) + + +def _e2e_batch(schema, start_id: int, num_rows: int) -> pa.RecordBatch: + """Deterministic RecordBatch for the e2e test.""" + ids = list(range(start_id, start_id + num_rows)) + flat_vecs = [ + math.sin((start_id + i) * 0.1 + d * 0.01) + for i in range(num_rows) + for d in range(VECTOR_DIM) + ] + storage = pa.array(flat_vecs, type=pa.float32()) + vector_array = pa.FixedSizeListArray.from_arrays(storage, VECTOR_DIM) + texts = [f"Sample text for row {start_id + i}" for i in range(num_rows)] + return pa.record_batch( + { + "id": pa.array(ids, pa.int64()), + "vector": vector_array, + "text": pa.array(texts, pa.utf8()), + }, + schema=schema, + ) + + +def test_region_writer_e2e_correctness(tmp_path): + """ + End-to-end correctness test for RegionWriter covering: + - Multi-round writes that trigger WAL and MemTable flushes + - File-system layout verification (_mem_wal//wal/ and manifest/) + - Flushed generation data readable via LsmScanner + - New writer created after close can write and scan correctly + + Mirrors Rust test: region_writer_tests::test_region_writer_e2e_correctness + """ + schema = _e2e_schema() + ds_path = str(tmp_path / "ds") + + # 500 seed rows so BTree index training succeeds + initial_batch = _e2e_batch(schema, 
start_id=0, num_rows=500) + ds = lance.write_dataset( + pa.Table.from_batches([initial_batch]), ds_path, schema=schema + ) + + ds.create_scalar_index("id", "BTREE", name="id_btree") + ds.initialize_mem_wal(maintained_indexes=["id_btree"]) + + # Small buffers to trigger WAL and MemTable flushes during the test + region_id = str(uuid.uuid4()) + writer = ds.mem_wal_writer( + region_id, + durable_write=True, + sync_indexed_write=True, + max_wal_buffer_size=10 * 1024, # 10 KB + max_wal_flush_interval_ms=50, + max_memtable_size=80, # flush after ~80 rows + ) + + total_rows_written = 0 + for _round in range(NUM_WRITE_ROUNDS): + start_id = 500 + total_rows_written + for i in range(BATCHES_PER_ROUND): + batch = _e2e_batch(schema, start_id + i * ROWS_PER_BATCH, ROWS_PER_BATCH) + writer.put(pa.Table.from_batches([batch])) + total_rows_written += BATCHES_PER_ROUND * ROWS_PER_BATCH + time.sleep(0.15) # allow async WAL/memtable flush + + writer.close() + + # === Stats === + stats = writer.stats() + assert stats["put_count"] == NUM_WRITE_ROUNDS * BATCHES_PER_ROUND + assert stats["wal_flush_count"] >= 1, "Expected at least one WAL flush" + + closed_memtable_stats = writer.memtable_stats() + assert closed_memtable_stats["row_count"] == 0 + assert closed_memtable_stats["batch_count"] == 0 + assert closed_memtable_stats["generation"] >= 1 + + # === File-system layout === + mem_wal_dir = os.path.join(ds_path, "_mem_wal", region_id) + assert os.path.isdir(mem_wal_dir), f"MemWAL directory missing: {mem_wal_dir}" + + wal_dir = os.path.join(mem_wal_dir, "wal") + assert os.path.isdir(wal_dir), "WAL sub-directory missing" + wal_files = os.listdir(wal_dir) + assert len(wal_files) >= 1 + assert all(f.endswith(".arrow") for f in wal_files), ( + f"All WAL files should have .arrow extension, got: {wal_files}" + ) + + manifest_dir = os.path.join(mem_wal_dir, "manifest") + assert os.path.isdir(manifest_dir), "Manifest sub-directory missing" + assert len(os.listdir(manifest_dir)) >= 1 + + # === 
Generation counter advanced === + mt_stats = writer.memtable_stats() + assert mt_stats["generation"] >= 1 + + # === New writer: write and read back via active MemTable scanner === + ds2 = lance.dataset(ds_path) + region_id2 = str(uuid.uuid4()) + with ds2.mem_wal_writer( + region_id2, durable_write=False, sync_indexed_write=True + ) as writer2: + verify_batch = _e2e_batch(schema, start_id=10000, num_rows=10) + writer2.put(pa.Table.from_batches([verify_batch])) + result = writer2.lsm_scanner().to_table() + + assert len(result) >= 10 + new_ids = result.column("id").to_pylist() + assert 10000 in new_ids + assert 10009 in new_ids diff --git a/python/src/dataset.rs b/python/src/dataset.rs index d1dcfe5c515..093da2826d8 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -112,6 +112,14 @@ pub mod stats; const DEFAULT_NPROBES: usize = 1; const LANCE_COMMIT_MESSAGE_KEY: &str = "__lance_commit_message"; +fn stats_log_interval_from_millis(ms: u64) -> Option { + if ms == 0 { + None + } else { + Some(std::time::Duration::from_millis(ms)) + } +} + fn convert_reader(reader: &Bound) -> PyResult> { let py = reader.py(); if reader.is_instance_of::() { @@ -321,6 +329,25 @@ impl MergeInsertBuilder { rt().block_on(None, job.analyze_plan(new_data_stream))? .map_err(|err| PyIOError::new_err(err.to_string())) } + + /// Mark MemWAL generations as merged into the base table. + /// + /// Call this when executing a merge_insert that incorporates MemWAL + /// flushed generation data. This updates the MemWAL generation tracking + /// to prevent duplicate merges. 
+ pub fn mark_generations_as_merged<'a>( + mut slf: PyRefMut<'a, Self>, + generations: Vec>, + ) -> PyResult> { + use lance_index::mem_wal::MergedGeneration; + + let gens: Vec = generations + .iter() + .map(|g| g.borrow().to_lance()) + .collect::>()?; + slf.builder.mark_generations_as_merged(gens); + Ok(slf) + } } impl MergeInsertBuilder { @@ -2692,6 +2719,165 @@ impl Dataset { let builder = ds.delta(); Ok(DatasetDeltaBuilder { builder }) } + + /// Initialize MemWAL on this dataset. + /// + /// Must be called once before any `mem_wal_writer()` calls. + /// Requires the dataset schema to have at least one field with + /// the `lance-schema:unenforced-primary-key` metadata. + #[pyo3(signature=(maintained_indexes=None, region_spec=None))] + fn initialize_mem_wal( + &mut self, + py: Python<'_>, + maintained_indexes: Option>, + region_spec: Option>, + ) -> PyResult<()> { + use lance::dataset::mem_wal::DatasetMemWalExt; + use lance_index::mem_wal::{RegionField, RegionSpec}; + use std::collections::HashMap; + + let region_spec_rust = if let Some(spec) = region_spec { + let spec_id: u32 = spec.getattr("spec_id")?.extract()?; + let fields_py: Vec> = spec.getattr("fields")?.extract()?; + let fields = fields_py + .iter() + .map(|f| -> PyResult { + Ok(RegionField { + field_id: f.getattr("field_id")?.extract()?, + source_ids: f.getattr("source_ids")?.extract()?, + transform: f.getattr("transform")?.extract()?, + expression: f.getattr("expression")?.extract()?, + result_type: f.getattr("result_type")?.extract()?, + parameters: f + .getattr("parameters")? 
+ .extract::>()?, + }) + }) + .collect::>>()?; + Some(RegionSpec { spec_id, fields }) + } else { + None + }; + + let config = lance::dataset::mem_wal::MemWalConfig { + region_spec: region_spec_rust, + maintained_indexes: maintained_indexes.unwrap_or_default(), + }; + let mut ds = Arc::clone(&self.ds); + let new_ds = rt() + .block_on(Some(py), async move { + Arc::make_mut(&mut ds).initialize_mem_wal(config).await?; + Ok::, lance_core::Error>(ds) + })? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + self.ds = new_ds; + Ok(()) + } + + /// Get a RegionWriter for the specified region. + /// + /// `initialize_mem_wal()` must be called before using this method. + #[allow(clippy::too_many_arguments)] + #[pyo3(signature=( + region_id, + *, + durable_write=None, + sync_indexed_write=None, + max_wal_buffer_size=None, + max_wal_flush_interval_ms=None, + max_memtable_size=None, + max_memtable_rows=None, + max_memtable_batches=None, + max_unflushed_memtable_bytes=None, + ivf_index_partition_capacity_safety_factor=None, + manifest_scan_batch_size=None, + async_index_buffer_rows=None, + async_index_interval_ms=None, + backpressure_log_interval_ms=None, + stats_log_interval_ms=None, + ))] + fn mem_wal_writer( + &self, + py: Python<'_>, + region_id: String, + durable_write: Option, + sync_indexed_write: Option, + max_wal_buffer_size: Option, + max_wal_flush_interval_ms: Option, + max_memtable_size: Option, + max_memtable_rows: Option, + max_memtable_batches: Option, + max_unflushed_memtable_bytes: Option, + ivf_index_partition_capacity_safety_factor: Option, + manifest_scan_batch_size: Option, + async_index_buffer_rows: Option, + async_index_interval_ms: Option, + backpressure_log_interval_ms: Option, + stats_log_interval_ms: Option, + ) -> PyResult { + use lance::dataset::mem_wal::{DatasetMemWalExt, RegionWriterConfig}; + + let uuid = uuid::Uuid::parse_str(®ion_id) + .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + + let mut config = 
RegionWriterConfig::default(); + if let Some(v) = durable_write { + config = config.with_durable_write(v); + } + if let Some(v) = sync_indexed_write { + config = config.with_sync_indexed_write(v); + } + if let Some(v) = max_wal_buffer_size { + config = config.with_max_wal_buffer_size(v); + } + if let Some(v) = max_wal_flush_interval_ms { + config = config.with_max_wal_flush_interval(std::time::Duration::from_millis(v)); + } + if let Some(v) = max_memtable_size { + config = config.with_max_memtable_size(v); + } + if let Some(v) = max_memtable_rows { + config = config.with_max_memtable_rows(v); + } + if let Some(v) = max_memtable_batches { + config = config.with_max_memtable_batches(v); + } + if let Some(v) = max_unflushed_memtable_bytes { + config = config.with_max_unflushed_memtable_bytes(v); + } + if let Some(v) = ivf_index_partition_capacity_safety_factor { + config = config.with_ivf_index_partition_capacity_safety_factor(v); + } + if let Some(v) = manifest_scan_batch_size { + config = config.with_manifest_scan_batch_size(v); + } + if let Some(v) = async_index_buffer_rows { + config = config.with_async_index_buffer_rows(v); + } + if let Some(v) = async_index_interval_ms { + config = config.with_async_index_interval(std::time::Duration::from_millis(v)); + } + if let Some(v) = backpressure_log_interval_ms { + config = config.with_backpressure_log_interval(std::time::Duration::from_millis(v)); + } + if let Some(v) = stats_log_interval_ms { + config = config.with_stats_log_interval(stats_log_interval_from_millis(v)); + } + + let ds = self.ds.clone(); + let writer = rt() + .block_on( + Some(py), + async move { ds.mem_wal_writer(uuid, config).await }, + )? 
+ .map_err(|e| PyIOError::new_err(e.to_string()))?; + + Ok(crate::mem_wal::PyRegionWriter::new( + writer, + uuid, + self.ds.clone(), + )) + } } #[pyclass(name = "SqlQuery", module = "_lib", subclass)] diff --git a/python/src/lib.rs b/python/src/lib.rs index 8c6f7c186ed..d34a73c0fdf 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -68,6 +68,7 @@ pub(crate) mod executor; pub(crate) mod file; pub(crate) mod fragment; pub(crate) mod indices; +pub(crate) mod mem_wal; pub(crate) mod namespace; pub(crate) mod reader; pub(crate) mod scanner; @@ -282,6 +283,14 @@ fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // MemWAL classes + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(bfloat16_array))?; m.add_wrapped(wrap_pyfunction!(write_dataset))?; m.add_wrapped(wrap_pyfunction!(write_fragments))?; diff --git a/python/src/mem_wal.rs b/python/src/mem_wal.rs new file mode 100644 index 00000000000..82584e6c921 --- /dev/null +++ b/python/src/mem_wal.rs @@ -0,0 +1,858 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow::ffi_stream::ArrowArrayStreamReader; +use arrow::pyarrow::*; +use arrow_array::{ + Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, + StructArray, make_array, +}; +use arrow_data::ArrayData; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use datafusion::common::ScalarValue; +use datafusion::physical_plan::{ExecutionPlan, collect, displayable}; +use datafusion::prelude::SessionContext; +use futures::TryStreamExt; +use lance::dataset::Dataset as LanceDataset; +use lance::dataset::mem_wal::scanner::{ + LsmDataSourceCollector, LsmPointLookupPlanner, LsmVectorSearchPlanner, +}; +use lance::dataset::mem_wal::write::{MemTableStats, 
WriteStatsSnapshot}; +use lance::dataset::mem_wal::{LsmScanner, RegionSnapshot, RegionWriter}; +use lance_index::mem_wal::MergedGeneration as LanceMergedGeneration; +use lance_linalg::distance::DistanceType; +use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use tokio::sync::Mutex as TokioMutex; +use uuid::Uuid; + +use crate::dataset::Dataset as PyDataset; +use crate::rt; + +/// Represents a single generation of a MemWAL region that has been merged +/// into the base table. Used with `MergeInsertBuilder.mark_generations_as_merged()`. +#[pyclass(name = "_MergedGeneration", module = "_lib")] +pub struct PyMergedGeneration { + pub region_id: String, + pub generation: u64, +} + +#[pymethods] +impl PyMergedGeneration { + #[new] + pub fn new(region_id: String, generation: u64) -> Self { + Self { + region_id, + generation, + } + } + + #[getter] + pub fn region_id(&self) -> &str { + &self.region_id + } + + #[getter] + pub fn generation(&self) -> u64 { + self.generation + } + + pub fn __repr__(&self) -> String { + format!( + "_MergedGeneration(region_id='{}', generation={})", + self.region_id, self.generation + ) + } +} + +impl PyMergedGeneration { + pub fn to_lance(&self) -> PyResult { + let uuid = Uuid::parse_str(&self.region_id) + .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + Ok(LanceMergedGeneration::new(uuid, self.generation)) + } +} + +/// Snapshot of a MemWAL region's state at a point in time. +/// +/// Used to specify which flushed generations to include when creating an +/// `_LsmScanner`. Supports a builder pattern for adding generations. 
+#[pyclass(name = "_RegionSnapshot", module = "_lib")] +#[derive(Clone)] +pub struct PyRegionSnapshot { + pub inner: RegionSnapshot, +} + +#[pymethods] +impl PyRegionSnapshot { + #[new] + pub fn new(region_id: String) -> PyResult { + let uuid = Uuid::parse_str(®ion_id) + .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + Ok(Self { + inner: RegionSnapshot::new(uuid), + }) + } + + /// Set the RegionSpec ID for this snapshot. + pub fn with_spec_id(mut slf: PyRefMut<'_, Self>, spec_id: u32) -> PyRefMut<'_, Self> { + slf.inner = slf.inner.clone().with_spec_id(spec_id); + slf + } + + /// Set the current (active) generation number. + pub fn with_current_generation( + mut slf: PyRefMut<'_, Self>, + generation: u64, + ) -> PyRefMut<'_, Self> { + slf.inner = slf.inner.clone().with_current_generation(generation); + slf + } + + /// Add a flushed generation by its generation number and storage path. + pub fn with_flushed_generation( + mut slf: PyRefMut<'_, Self>, + generation: u64, + path: String, + ) -> PyRefMut<'_, Self> { + slf.inner = slf.inner.clone().with_flushed_generation(generation, path); + slf + } + + #[getter] + pub fn region_id(&self) -> String { + self.inner.region_id.to_string() + } + + pub fn __repr__(&self) -> String { + format!( + "_RegionSnapshot(region_id='{}', current_gen={}, flushed_gens={})", + self.inner.region_id, + self.inner.current_generation, + self.inner.flushed_generations.len() + ) + } +} + +/// Long-lived stateful writer for a MemWAL region. +/// +/// Supports writing batches, querying statistics, creating LSM scanners, +/// and graceful shutdown. Supports the Python context manager protocol. 
+#[pyclass(name = "_RegionWriter", module = "_lib")] +pub struct PyRegionWriter { + inner: Arc>>, + closed_state: Arc>>, + region_id: Uuid, + dataset: Arc, +} + +#[derive(Clone)] +struct ClosedRegionWriterState { + stats: WriteStatsSnapshot, + memtable_stats: MemTableStats, +} + +#[pymethods] +impl PyRegionWriter { + /// Write data batches to the MemWAL. + /// + /// Accepts any PyArrow-compatible data source (RecordBatch, Table, + /// or an Arrow stream reader). + pub fn put(&self, py: Python<'_>, data: &Bound<'_, PyAny>) -> PyResult<()> { + let reader = ArrowArrayStreamReader::from_pyarrow_bound(data) + .map_err(|e| PyValueError::new_err(format!("Cannot read data as Arrow: {}", e)))?; + let batches: Vec = reader + .collect::>() + .map_err(|e| PyIOError::new_err(format!("Failed to read batches: {}", e)))?; + + if batches.is_empty() { + return Ok(()); + } + + let inner = self.inner.clone(); + rt().block_on(Some(py), async move { + let guard = inner.lock().await; + match guard.as_ref() { + Some(writer) => writer.put(batches).await.map(|_| ()), + None => Err(lance_core::Error::invalid_input( + "RegionWriter is already closed", + )), + } + })? + .map_err(|e: lance::Error| PyIOError::new_err(e.to_string())) + } + + /// Flush pending data and close the writer. + /// + /// After close(), calling put() will raise an error. + /// This is called automatically when using the context manager. + pub fn close(&self, py: Python<'_>) -> PyResult<()> { + let inner = self.inner.clone(); + let closed_state = self.closed_state.clone(); + rt().block_on(Some(py), async move { + let mut guard = inner.lock().await; + if let Some(writer) = guard.take() { + let stats_handle = writer.stats_handle(); + // Snapshot stats before close so the captured state reflects + // what was written, not any internal bookkeeping done by close(). 
+ let stats_snapshot = stats_handle.snapshot(); + let memtable_stats_before_close = writer.memtable_stats().await; + writer.close().await?; + let closed_memtable_stats = closed_memtable_stats(memtable_stats_before_close); + let mut closed_guard = closed_state.lock().await; + *closed_guard = Some(ClosedRegionWriterState { + stats: stats_snapshot, + memtable_stats: closed_memtable_stats, + }); + Ok(()) + } else { + Ok(()) + } + })? + .map_err(|e: lance::Error| PyIOError::new_err(e.to_string())) + } + + /// Return a snapshot of current write statistics. + /// + /// Returns a dict with keys: put_count, put_time_ms, wal_flush_count, + /// wal_flush_bytes, memtable_flush_count, memtable_flush_rows. + pub fn stats(&self, py: Python<'_>) -> PyResult> { + let inner = self.inner.clone(); + let closed_state = self.closed_state.clone(); + let stats = rt() + .block_on(Some(py), async move { + let guard = inner.lock().await; + if let Some(writer) = guard.as_ref() { + Ok(writer.stats()) + } else { + let closed_guard = closed_state.lock().await; + closed_guard + .as_ref() + .map(|state| state.stats.clone()) + .ok_or_else(|| { + lance_core::Error::invalid_input("RegionWriter is already closed") + }) + } + })? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + write_stats_to_pydict(py, &stats) + } + + /// Return current MemTable statistics. + /// + /// Returns a dict with keys: row_count, batch_count, estimated_size_bytes, + /// generation. + pub fn memtable_stats(&self, py: Python<'_>) -> PyResult> { + let inner = self.inner.clone(); + let closed_state = self.closed_state.clone(); + let stats = rt() + .block_on(Some(py), async move { + let guard = inner.lock().await; + match guard.as_ref() { + Some(w) => Ok(w.memtable_stats().await), + None => { + let closed_guard = closed_state.lock().await; + closed_guard + .as_ref() + .map(|state| state.memtable_stats.clone()) + .ok_or_else(|| { + lance_core::Error::invalid_input("RegionWriter is already closed") + }) + } + } + })? 
+ .map_err(|e| PyIOError::new_err(e.to_string()))?; + + memtable_stats_to_pydict(py, &stats) + } + + /// Create an LSM scanner that includes the active MemTable for strong consistency. + /// + /// The scanner covers: base table + given flushed generations + current active MemTable. + #[pyo3(signature = (region_snapshots=vec![]))] + pub fn lsm_scanner( + &self, + py: Python<'_>, + region_snapshots: Vec>, + ) -> PyResult { + let snapshots: Vec = region_snapshots + .iter() + .map(|s| s.borrow().inner.clone()) + .collect(); + + let pk_columns = get_pk_columns(&self.dataset)?; + let inner = self.inner.clone(); + let dataset = self.dataset.clone(); + let region_id = self.region_id; + + let active_ref = rt() + .block_on(Some(py), async move { + let guard = inner.lock().await; + match guard.as_ref() { + Some(w) => Ok(w.active_memtable_ref().await), + None => Err(lance_core::Error::invalid_input( + "RegionWriter is already closed", + )), + } + })? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + let scanner = LsmScanner::new(dataset, snapshots, pk_columns) + .with_active_memtable(region_id, active_ref); + + Ok(PyLsmScanner { + inner: Some(scanner), + }) + } + + /// Return the region ID as a UUID string. + #[getter] + pub fn region_id(&self) -> String { + self.region_id.to_string() + } + + pub fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + pub fn __exit__( + &self, + py: Python<'_>, + _exc_type: &Bound<'_, PyAny>, + _exc_val: &Bound<'_, PyAny>, + _exc_tb: &Bound<'_, PyAny>, + ) -> PyResult { + self.close(py)?; + Ok(false) + } +} + +impl PyRegionWriter { + /// Create from a Rust RegionWriter and dataset reference. + pub fn new(writer: RegionWriter, region_id: Uuid, dataset: Arc) -> Self { + Self { + inner: Arc::new(TokioMutex::new(Some(writer))), + closed_state: Arc::new(TokioMutex::new(None)), + region_id, + dataset, + } + } +} + +/// Python wrapper around a DataFusion physical execution plan. 
+#[pyclass(name = "_ExecutionPlan", module = "_lib")] +#[derive(Clone)] +pub struct PyExecutionPlan { + plan: Arc, + dataset_schema: Arc, +} + +impl PyExecutionPlan { + pub fn new(plan: Arc, dataset_schema: Arc) -> Self { + Self { + plan, + dataset_schema, + } + } +} + +#[pymethods] +impl PyExecutionPlan { + #[getter] + fn schema<'py>(&self, py: Python<'py>) -> PyResult> { + self.plan.schema().to_pyarrow(py) + } + + #[getter] + fn dataset_schema<'py>(&self, py: Python<'py>) -> PyResult> { + self.dataset_schema.to_pyarrow(py) + } + + fn explain(&self) -> String { + format!("{}", displayable(self.plan.as_ref()).indent(true)) + } + + fn to_batches<'py>(&self, py: Python<'py>) -> PyResult> { + let plan = self.plan.clone(); + let batches = rt() + .block_on(Some(py), async move { + let ctx = SessionContext::new(); + collect(plan, ctx.task_ctx()).await + })? + .map_err(|e| PyIOError::new_err(format!("Plan execution failed: {}", e)))?; + + let py_batches: Vec> = batches + .into_iter() + .map(|batch| { + PyArrowType(batch) + .into_pyobject(py) + .map(|batch| batch.into_any()) + .map_err(|e| PyIOError::new_err(e.to_string())) + }) + .collect::>()?; + PyList::new(py, py_batches) + } + + fn to_reader<'py>(&self, py: Python<'py>) -> PyResult> { + let plan = self.plan.clone(); + let batches = rt() + .block_on(Some(py), async move { + let ctx = SessionContext::new(); + collect(plan, ctx.task_ctx()).await + })? + .map_err(|e| PyIOError::new_err(format!("Plan execution failed: {}", e)))?; + + let schema = self.plan.schema().clone(); + let reader: Box = Box::new(RecordBatchIterator::new( + batches.into_iter().map(Ok), + schema, + )); + reader.into_pyarrow(py) + } +} + +/// LSM-aware scanner covering base table, flushed MemTables, and active MemTable. +/// +/// Provides deduplication by primary key, always returning the newest version +/// of each row across all LSM levels. 
+#[pyclass(name = "_LsmScanner", module = "_lib")] +pub struct PyLsmScanner { + inner: Option, +} + +#[pymethods] +impl PyLsmScanner { + /// Create a scanner from dataset and region snapshots (without active MemTable). + #[staticmethod] + pub fn from_snapshots( + dataset: &Bound<'_, PyDataset>, + region_snapshots: Vec>, + ) -> PyResult { + let ds = dataset.borrow().ds.clone(); + let snapshots: Vec = region_snapshots + .iter() + .map(|s| s.borrow().inner.clone()) + .collect(); + let pk_columns = get_pk_columns(&ds)?; + Ok(Self { + inner: Some(LsmScanner::new(ds, snapshots, pk_columns)), + }) + } + + /// Select specific columns to return. + pub fn project( + mut slf: PyRefMut<'_, Self>, + columns: Vec, + ) -> PyResult> { + let scanner = slf + .inner + .take() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect(); + slf.inner = Some(scanner.project(&cols)); + Ok(slf) + } + + /// Set a SQL filter expression. + pub fn filter(mut slf: PyRefMut<'_, Self>, expr: String) -> PyResult> { + let scanner = slf + .inner + .take() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + slf.inner = Some( + scanner + .filter(&expr) + .map_err(|e| PyValueError::new_err(e.to_string()))?, + ); + Ok(slf) + } + + /// Limit the number of rows returned. + #[pyo3(signature = (n, offset=None))] + pub fn limit( + mut slf: PyRefMut<'_, Self>, + n: usize, + offset: Option, + ) -> PyResult> { + let scanner = slf + .inner + .take() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + slf.inner = Some(scanner.limit(n, offset)); + Ok(slf) + } + + /// Include the `_rowaddr` internal column in results. 
+ pub fn with_row_address(mut slf: PyRefMut<'_, Self>) -> PyResult> { + let scanner = slf + .inner + .take() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + slf.inner = Some(scanner.with_row_address()); + Ok(slf) + } + + /// Include the `_memtable_gen` internal column in results. + pub fn with_memtable_gen(mut slf: PyRefMut<'_, Self>) -> PyResult> { + let scanner = slf + .inner + .take() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + slf.inner = Some(scanner.with_memtable_gen()); + Ok(slf) + } + + /// Execute the scan and return a single PyArrow RecordBatch. + pub fn to_batch<'py>(&self, py: Python<'py>) -> PyResult> { + let scanner = self + .inner + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + let batch = rt() + .block_on(Some(py), scanner.try_into_batch()) + .map_err(|e| PyIOError::new_err(e.to_string()))? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + PyArrowType(batch).into_pyobject(py).map(|b| b.into_any()) + } + + /// Execute the scan and return all batches as a Python list. + pub fn to_batches<'py>(&self, py: Python<'py>) -> PyResult> { + let scanner = self + .inner + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + let stream = rt() + .block_on(Some(py), scanner.try_into_stream()) + .map_err(|e| PyIOError::new_err(e.to_string()))? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + let batches: Vec = rt() + .block_on(Some(py), stream.try_collect()) + .map_err(|e| PyIOError::new_err(e.to_string()))? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + let py_batches: Vec> = batches + .into_iter() + .map(|b| { + PyArrowType(b) + .into_pyobject(py) + .map(|b| b.into_any()) + .map_err(|e| PyIOError::new_err(e.to_string())) + }) + .collect::>()?; + PyList::new(py, py_batches) + } + + /// Return the row count without loading all data. 
+ pub fn count_rows(&self, py: Python<'_>) -> PyResult { + let scanner = self + .inner + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("Scanner has already been consumed"))?; + rt().block_on(Some(py), scanner.count_rows()) + .map_err(|e| PyIOError::new_err(e.to_string()))? + .map_err(|e| PyIOError::new_err(e.to_string())) + } +} + +/// Plans and executes primary key point lookups across all LSM levels. +/// +/// More efficient than `_LsmScanner` for known-PK lookups due to bloom filter +/// optimizations and short-circuit evaluation. +#[pyclass(name = "_LsmPointLookupPlanner", module = "_lib")] +pub struct PyLsmPointLookupPlanner { + planner: LsmPointLookupPlanner, + dataset_schema: Arc, + pk_columns: Vec, +} + +#[pymethods] +impl PyLsmPointLookupPlanner { + #[new] + #[pyo3(signature = (dataset, region_snapshots, pk_columns=None))] + pub fn new( + dataset: &Bound<'_, PyDataset>, + region_snapshots: Vec>, + pk_columns: Option>, + ) -> PyResult { + let ds = dataset.borrow().ds.clone(); + let snapshots: Vec = region_snapshots + .iter() + .map(|s| s.borrow().inner.clone()) + .collect(); + let pk_cols = match pk_columns { + Some(cols) => cols, + None => get_pk_columns(&ds)?, + }; + let base_schema = Arc::new(ArrowSchema::from(ds.schema())); + let collector = LsmDataSourceCollector::new(ds.clone(), snapshots); + let planner = LsmPointLookupPlanner::new(collector, pk_cols.clone(), base_schema.clone()); + Ok(Self { + planner, + dataset_schema: base_schema, + pk_columns: pk_cols, + }) + } + + /// Plan a single-row point lookup by primary key. + /// + /// For single-column primary keys, `pk_value` should be a PyArrow array + /// with exactly one element. For composite primary keys, `pk_value` must + /// be a StructArray with exactly one row and one field per PK column. 
+ #[pyo3(signature = (pk_value, columns=None))] + pub fn plan_lookup( + &self, + py: Python<'_>, + pk_value: PyArrowType, + columns: Option>, + ) -> PyResult { + let array = make_array(pk_value.0); + let pk_values = scalar_values_from_pk_value(array.as_ref(), &self.pk_columns)?; + let proj: Option> = columns; + let planner_ref = &self.planner; + let plan = rt() + .block_on(Some(py), async { + planner_ref.plan_lookup(&pk_values, proj.as_deref()).await + })? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + Ok(PyExecutionPlan::new(plan, self.dataset_schema.clone())) + } +} + +/// Plans and executes vector KNN search across all LSM levels. +/// +/// Only supports IVF-PQ vector indexes maintained in MemWAL. +/// Results include staleness filtering to return only the latest version of each row. +#[pyclass(name = "_LsmVectorSearchPlanner", module = "_lib")] +pub struct PyLsmVectorSearchPlanner { + planner: LsmVectorSearchPlanner, + vector_dim: usize, + dataset_schema: Arc, +} + +#[pymethods] +impl PyLsmVectorSearchPlanner { + #[new] + #[pyo3(signature = (dataset, region_snapshots, vector_column, pk_columns=None, distance_type=None))] + pub fn new( + dataset: &Bound<'_, PyDataset>, + region_snapshots: Vec>, + vector_column: String, + pk_columns: Option>, + distance_type: Option, + ) -> PyResult { + let ds = dataset.borrow().ds.clone(); + let snapshots: Vec = region_snapshots + .iter() + .map(|s| s.borrow().inner.clone()) + .collect(); + let pk_cols = match pk_columns { + Some(cols) => cols, + None => get_pk_columns(&ds)?, + }; + let base_schema = Arc::new(ArrowSchema::from(ds.schema())); + + let dist_type = parse_distance_type(distance_type.as_deref().unwrap_or("l2"))?; + + let vector_dim = get_vector_dim(&ds, &vector_column)?; + + let collector = LsmDataSourceCollector::new(ds, snapshots); + let planner = LsmVectorSearchPlanner::new( + collector, + pk_cols, + base_schema.clone(), + vector_column, + dist_type, + ); + + Ok(Self { + planner, + vector_dim, + 
dataset_schema: base_schema, + }) + } + + /// Plan a KNN vector search. + /// + /// `query` should be a flat PyArrow Float32Array with `vector_dim` elements. + #[pyo3(signature = (query, k=10, nprobes=20, columns=None))] + pub fn plan_search( + &self, + py: Python<'_>, + query: PyArrowType, + k: usize, + nprobes: usize, + columns: Option>, + ) -> PyResult { + let query_array = make_array(query.0); + let float32_array = query_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + PyValueError::new_err( + "query must be a Float32Array. Use pa.array(values, type=pa.float32())", + ) + })?; + + if float32_array.len() != self.vector_dim { + return Err(PyValueError::new_err(format!( + "Query vector has {} dimensions, expected {}", + float32_array.len(), + self.vector_dim + ))); + } + + // Wrap the flat array into a FixedSizeListArray with one row + let field = Arc::new(Field::new("item", DataType::Float32, true)); + let fsl = FixedSizeListArray::try_new( + field, + self.vector_dim as i32, + Arc::new(float32_array.clone()), + None, + ) + .map_err(|e| PyValueError::new_err(format!("Cannot create query vector: {}", e)))?; + + let planner_ref = &self.planner; + let plan = rt() + .block_on(Some(py), async { + planner_ref + .plan_search(&fsl, k, nprobes, columns.as_deref()) + .await + })? + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + Ok(PyExecutionPlan::new(plan, self.dataset_schema.clone())) + } +} + +/// Extract primary key column names from dataset schema. +pub(crate) fn get_pk_columns(ds: &LanceDataset) -> PyResult> { + let pk_fields = ds.schema().unenforced_primary_key(); + if pk_fields.is_empty() { + return Err(PyValueError::new_err( + "Dataset has no primary key. Set 'lance-schema:unenforced-primary-key' metadata \ + on the primary key field(s).", + )); + } + Ok(pk_fields.iter().map(|f| f.name.clone()).collect()) +} + +/// Parse distance type string to DistanceType enum. 
+fn parse_distance_type(s: &str) -> PyResult { + match s.to_lowercase().as_str() { + "l2" | "euclidean" => Ok(DistanceType::L2), + "cosine" => Ok(DistanceType::Cosine), + "dot" | "inner_product" => Ok(DistanceType::Dot), + "hamming" => Ok(DistanceType::Hamming), + _ => Err(PyValueError::new_err(format!( + "Unknown distance_type '{}'. Valid values: 'l2', 'cosine', 'dot', 'hamming'", + s + ))), + } +} + +/// Get the vector dimension from the dataset schema for a given column. +fn get_vector_dim(ds: &LanceDataset, column: &str) -> PyResult { + let schema = ArrowSchema::from(ds.schema()); + let field = schema.field_with_name(column).map_err(|_| { + PyValueError::new_err(format!("Column '{}' not found in dataset schema", column)) + })?; + match field.data_type() { + DataType::FixedSizeList(_, size) => Ok(*size as usize), + other => Err(PyValueError::new_err(format!( + "Column '{}' is not a FixedSizeList (got {:?}). \ + Vector columns must be FixedSizeList.", + column, other + ))), + } +} + +fn write_stats_to_pydict(py: Python<'_>, stats: &WriteStatsSnapshot) -> PyResult> { + let dict = PyDict::new(py); + dict.set_item("put_count", stats.put_count)?; + dict.set_item("put_time_ms", stats.put_time.as_millis() as u64)?; + dict.set_item("wal_flush_count", stats.wal_flush_count)?; + dict.set_item("wal_flush_bytes", stats.wal_flush_bytes)?; + dict.set_item("wal_flush_time_ms", stats.wal_flush_time.as_millis() as u64)?; + dict.set_item("memtable_flush_count", stats.memtable_flush_count)?; + dict.set_item("memtable_flush_rows", stats.memtable_flush_rows)?; + dict.set_item( + "memtable_flush_time_ms", + stats.memtable_flush_time.as_millis() as u64, + )?; + Ok(dict.into_any().unbind()) +} + +fn memtable_stats_to_pydict(py: Python<'_>, stats: &MemTableStats) -> PyResult> { + let dict = PyDict::new(py); + dict.set_item("row_count", stats.row_count)?; + dict.set_item("batch_count", stats.batch_count)?; + dict.set_item("estimated_size_bytes", stats.estimated_size)?; + 
dict.set_item("generation", stats.generation)?; + Ok(dict.into_any().unbind()) +} + +fn scalar_values_from_pk_value( + pk_value: &dyn Array, + pk_columns: &[String], +) -> PyResult> { + if pk_value.len() != 1 { + return Err(PyValueError::new_err(format!( + "pk_value must contain exactly one row, got {}", + pk_value.len() + ))); + } + + if pk_columns.len() == 1 { + let scalar = ScalarValue::try_from_array(pk_value, 0) + .map_err(|e| PyValueError::new_err(format!("Cannot convert pk_value: {}", e)))?; + return Ok(vec![scalar]); + } + + let struct_array = pk_value.as_any().downcast_ref::().ok_or_else(|| { + PyValueError::new_err(format!( + "Composite primary key lookup requires a StructArray with exactly one row and {} fields", + pk_columns.len() + )) + })?; + + if struct_array.num_columns() != pk_columns.len() { + return Err(PyValueError::new_err(format!( + "Composite primary key lookup expected {} fields, got {}", + pk_columns.len(), + struct_array.num_columns() + ))); + } + + let mut pk_values = Vec::with_capacity(pk_columns.len()); + for column_name in pk_columns { + let column = struct_array.column_by_name(column_name).ok_or_else(|| { + PyValueError::new_err(format!( + "Composite primary key lookup requires field '{}' in pk_value", + column_name + )) + })?; + let scalar = ScalarValue::try_from_array(column.as_ref(), 0).map_err(|e| { + PyValueError::new_err(format!("Cannot convert composite pk_value: {}", e)) + })?; + pk_values.push(scalar); + } + Ok(pk_values) +} + +fn closed_memtable_stats(stats_before_close: MemTableStats) -> MemTableStats { + if stats_before_close.batch_count == 0 { + return stats_before_close; + } + + MemTableStats { + row_count: 0, + batch_count: 0, + estimated_size: 0, + generation: stats_before_close.generation.saturating_add(1), + } +}