332 changes: 299 additions & 33 deletions rust/lance/src/index.rs
@@ -7,11 +7,10 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock};

use arrow_schema::{DataType, Schema};
use arrow_schema::DataType;
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::{FutureExt, stream};
use futures::FutureExt;
use itertools::Itertools;
use lance_core::cache::{CacheKey, UnsizedCacheKey};
use lance_core::datatypes::Field;
@@ -85,6 +84,7 @@ use crate::index::frag_reuse::{load_frag_reuse_index_details, open_frag_reuse_in
use crate::index::mem_wal::open_mem_wal_index;
pub use crate::index::prefilter::{FilterLoader, PreFilter};
use crate::index::scalar::{IndexDetails, fetch_index_details, load_training_data};
pub use crate::index::vector::{LogicalIvfView, LogicalVectorIndex};
use crate::session::index_caches::{FragReuseIndexKey, IndexMetadataKey};
use crate::{Error, Result, dataset::Dataset};
pub use create::CreateIndexBuilder;
@@ -1076,34 +1076,13 @@ impl DatasetIndexExt for Dataset {
return Err(Error::index_not_found(format!("name={}", index_name)));
}
let column = self.schema().field_by_id(indices[0].fields[0]).unwrap();

let mut schema: Option<Arc<Schema>> = None;
let mut partition_streams = Vec::with_capacity(indices.len());
for index in indices {
let index = self
.open_vector_index(&column.name, &index.uuid.to_string(), &NoOpMetricsCollector)
.await?;

let stream = index
.partition_reader(partition_id, with_vector, &NoOpMetricsCollector)
.await?;
if schema.is_none() {
schema = Some(stream.schema());
}
partition_streams.push(stream);
}

match schema {
Some(schema) => {
let merged = stream::select_all(partition_streams);
let stream = RecordBatchStreamAdapter::new(schema, merged);
Ok(Box::pin(stream))
}
None => Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::new(Schema::empty()),
stream::empty(),
))),
}
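// A logical vector index stitches together every physical segment committed
// under this index name, so one call reads the partition across all segments.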
let logical_index = self
.open_logical_vector_index(&column.name, index_name)
.await?;
logical_index
.as_ivf()?
.read_partition(partition_id, with_vector)
.await
}
}

@@ -1363,6 +1342,12 @@ pub trait DatasetIndexInternalExt: DatasetIndexExt {
uuid: &str,
metrics: &dyn MetricsCollector,
) -> Result<Arc<dyn VectorIndex>>;
/// Opens all segments for one logical vector index and returns a materialized snapshot.
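///
/// A minimal usage sketch, assuming only the accessors exercised by the
/// tests in this module (`column`, `num_rows_per_segment`, `as_ivf`,
/// `partition_sizes`):
///
/// ```ignore
/// let logical = dataset.open_logical_vector_index("vector", "vector_idx").await?;
/// assert_eq!(logical.column(), "vector");
/// for (segment_id, num_rows) in logical.num_rows_per_segment() {
///     println!("segment {segment_id}: {num_rows} rows");
/// }
/// // Segment-level IVF statistics are exposed through the IVF view.
/// let ivf_view = logical.as_ivf()?;
/// let partition_sizes = ivf_view.partition_sizes();
/// ```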
async fn open_logical_vector_index(
&self,
column: &str,
name: &str,
) -> Result<LogicalVectorIndex>;

/// Opens the fragment reuse index
async fn open_frag_reuse_index(
@@ -1746,6 +1731,38 @@ impl DatasetIndexInternalExt for Dataset {
Ok(index)
}

async fn open_logical_vector_index(
&self,
column: &str,
name: &str,
) -> Result<LogicalVectorIndex> {
let metadatas = self.load_indices_by_name(name).await?;
if metadatas.is_empty() {
return Err(Error::index_not_found(format!("name={name}")));
}

let field_id = self.schema().field_id(column)?;
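// Every segment registered under this name must index the requested
// column; reject mismatched registrations before opening any segment.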
if let Some(invalid_metadata) = metadatas
.iter()
.find(|metadata| !metadata.fields.contains(&field_id))
{
return Err(Error::invalid_input(format!(
"Logical vector index '{}' contains segment {} that does not belong to column '{}'",
name, invalid_metadata.uuid, column
)));
}

let mut segments = Vec::with_capacity(metadatas.len());
for metadata in metadatas {
let index = self
.open_vector_index(column, &metadata.uuid.to_string(), &NoOpMetricsCollector)
.await?;
segments.push((metadata, index));
}

LogicalVectorIndex::try_new(name.to_string(), column.to_string(), segments)
}

async fn open_frag_reuse_index(
&self,
metrics: &dyn MetricsCollector,
@@ -2127,13 +2144,16 @@ mod tests {
BuiltinIndexType, FullTextSearchQuery, InvertedIndexParams, ScalarIndexParams,
};
use lance_index::vector::{
hnsw::builder::HnswBuildParams, ivf::IvfBuildParams, sq::builder::SQBuildParams,
hnsw::builder::HnswBuildParams,
ivf::IvfBuildParams,
kmeans::{KMeansParams, train_kmeans},
sq::builder::SQBuildParams,
};
use lance_io::{assert_io_eq, assert_io_lt};
use lance_linalg::distance::{DistanceType, MetricType};
use lance_testing::datagen::generate_random_array;
use rstest::rstest;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

async fn write_vector_segment_metadata(
dataset: &Dataset,
Expand Down Expand Up @@ -2169,6 +2189,252 @@ mod tests {
}
}

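/// Writes five 80-row fragments (400 rows total) so tests have multiple
/// physical fragments available for per-fragment index segments.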
async fn write_fragmented_vector_dataset(uri: &str, dimension: i32) -> Dataset {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dimension,
),
false,
),
]));
let batches = (0..5)
.map(|i| {
let vector_values: Float32Array = (0..dimension * 80)
.map(|value| value as f32 + (i * 1000) as f32)
.collect();
let vectors =
FixedSizeListArray::try_new_from_values(vector_values, dimension).unwrap();
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(i * 80..(i + 1) * 80)),
Arc::new(vectors),
],
)
})
.collect::<std::result::Result<Vec<_>, arrow_schema::ArrowError>>()
.unwrap();
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
Dataset::write(
reader,
uri,
Some(WriteParams {
max_rows_per_group: 10,
max_rows_per_file: 80,
..Default::default()
}),
)
.await
.unwrap()
}

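/// Trains shared IVF centroids once, then builds one index segment per
/// fragment so every segment ends up with the same two-partition layout.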
async fn create_segmented_vector_index(
dataset: &mut Dataset,
index_name: &str,
column: &str,
dimension: i32,
) -> Vec<Uuid> {
let batch = dataset
.scan()
.project(&[column])
.unwrap()
.try_into_batch()
.await
.unwrap();
let vectors = batch
.column_by_name(column)
.expect("vector column should exist")
.as_fixed_size_list();
let values = vectors.values().as_primitive::<Float32Type>();
let centroids = train_kmeans::<Float32Type>(
values,
KMeansParams::new(None, 10, 1, DistanceType::L2),
dimension as usize,
2,
2,
)
.unwrap()
.centroids
.as_primitive::<Float32Type>()
.clone();
let centroids =
Arc::new(FixedSizeListArray::try_new_from_values(centroids, dimension).unwrap());
let params = VectorIndexParams::with_ivf_flat_params(
DistanceType::L2,
IvfBuildParams::try_with_centroids(2, centroids).unwrap(),
);
let fragment_ids = dataset
.get_fragments()
.iter()
.map(|fragment| fragment.id() as u32)
.collect::<Vec<_>>();
let columns = [column];

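// Build one uncommitted index segment per fragment, then commit them all
// under a single logical index name.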
let mut segments = Vec::with_capacity(fragment_ids.len());
for fragment_id in fragment_ids {
let mut builder = dataset.create_index_builder(&columns, IndexType::Vector, &params);
builder = builder
.name(index_name.to_string())
.fragments(vec![fragment_id]);
segments.push(builder.execute_uncommitted().await.unwrap());
}

let segment_ids = segments
.iter()
.map(|segment| segment.uuid)
.collect::<Vec<_>>();
dataset
.commit_existing_index_segments(index_name, column, segments)
.await
.unwrap();
segment_ids
}

#[tokio::test]
async fn test_open_logical_vector_index_single_segment_quality_apis() {
const DIMENSION: i32 = 8;

let test_dir = tempfile::tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let mut dataset = write_fragmented_vector_dataset(test_uri, DIMENSION).await;
let params =
VectorIndexParams::with_ivf_flat_params(DistanceType::L2, IvfBuildParams::new(2));

dataset
.create_index(
&["vector"],
IndexType::Vector,
Some("vector_idx".to_string()),
&params,
true,
)
.await
.unwrap();

let logical_index = dataset
.open_logical_vector_index("vector", "vector_idx")
.await
.unwrap();

assert_eq!(logical_index.name(), "vector_idx");
assert_eq!(logical_index.column(), "vector");
assert_eq!(logical_index.num_segments(), 1);
assert_eq!(logical_index.metadatas().len(), 1);

let rows_per_segment = logical_index.num_rows_per_segment();
assert_eq!(rows_per_segment.len(), 1);
assert_eq!(rows_per_segment[0].1, 400);

let ivf_view = logical_index.as_ivf().unwrap();
let partitions_per_segment = ivf_view.num_partitions_per_segment();
assert_eq!(partitions_per_segment, vec![(rows_per_segment[0].0, 2)]);

let partition_sizes = ivf_view.partition_sizes();
assert_eq!(partition_sizes.len(), 1);
assert_eq!(partition_sizes[0].1.len(), 2);
assert_eq!(
partition_sizes[0].1.iter().sum::<usize>(),
rows_per_segment[0].1 as usize
);
}

#[tokio::test]
async fn test_open_logical_vector_index_segmented_quality_apis() {
const DIMENSION: i32 = 8;

let test_dir = tempfile::tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let mut dataset = write_fragmented_vector_dataset(test_uri, DIMENSION).await;
let segment_ids =
create_segmented_vector_index(&mut dataset, "vector_idx", "vector", DIMENSION).await;

let logical_index = dataset
.open_logical_vector_index("vector", "vector_idx")
.await
.unwrap();

assert_eq!(logical_index.name(), "vector_idx");
assert_eq!(logical_index.column(), "vector");
assert_eq!(logical_index.num_segments(), segment_ids.len());

let metadata_ids = logical_index
.metadatas()
.map(|metadata| metadata.uuid)
.collect::<HashSet<_>>();
assert_eq!(
metadata_ids,
segment_ids.into_iter().collect::<HashSet<_>>()
);

let rows_per_segment = logical_index.num_rows_per_segment();
assert_eq!(rows_per_segment.len(), logical_index.num_segments());
assert_eq!(
rows_per_segment
.iter()
.map(|(_, num_rows)| *num_rows)
.sum::<u64>(),
400
);
assert!(
rows_per_segment.iter().all(|(_, num_rows)| *num_rows > 0),
"each segment should contain indexed rows"
);

let ivf_view = logical_index.as_ivf().unwrap();
let partitions_per_segment = ivf_view.num_partitions_per_segment();
assert!(
partitions_per_segment
.iter()
.all(|(_, num_partitions)| *num_partitions == 2)
);

let row_count_by_segment = rows_per_segment.into_iter().collect::<HashMap<_, _>>();
let partition_sizes = ivf_view.partition_sizes();
assert_eq!(partition_sizes.len(), logical_index.num_segments());
for (segment_id, sizes) in partition_sizes {
assert_eq!(sizes.len(), 2);
assert_eq!(
sizes.iter().sum::<usize>(),
row_count_by_segment[&segment_id] as usize
);
}
}

#[tokio::test]
async fn test_open_logical_vector_index_rejects_wrong_column() {
const DIMENSION: i32 = 8;

let test_dir = tempfile::tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let mut dataset = write_fragmented_vector_dataset(test_uri, DIMENSION).await;
let params =
VectorIndexParams::with_ivf_flat_params(DistanceType::L2, IvfBuildParams::new(2));

dataset
.create_index(
&["vector"],
IndexType::Vector,
Some("vector_idx".to_string()),
&params,
true,
)
.await
.unwrap();

let err = dataset
.open_logical_vector_index("id", "vector_idx")
.await
.unwrap_err();
assert!(
err.to_string().contains("does not belong to column 'id'"),
"unexpected error: {err}"
);
}

#[tokio::test]
async fn test_recreate_index() {
const DIM: i32 = 8;