diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 48ac7dbd8a..b7f5094ac3 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashSet; use std::ops::Range; use std::pin::Pin; use std::sync::{Arc, LazyLock}; @@ -74,6 +75,7 @@ use lance_linalg::distance::MetricType; use lance_table::format::{Fragment, IndexMetadata}; use roaring::RoaringBitmap; use tracing::{Span, info_span, instrument}; +use uuid::Uuid; use super::Dataset; use crate::dataset::row_offsets_to_row_addresses; @@ -761,6 +763,9 @@ pub struct Scanner { /// If set, this scanner serves only these fragments. fragments: Option>, + /// If set, this scanner will only search the specified vector index segments. + index_segments: Option>, + /// Only search the data being indexed (weak consistency search). /// /// Default value is false. @@ -994,6 +999,7 @@ impl Scanner { use_stats: true, ordered: true, fragments: None, + index_segments: None, fast_search: false, use_scalar_index: true, include_deleted_rows: false, @@ -1040,6 +1046,23 @@ impl Scanner { self } + /// Restrict vector index search to the specified index segments. + /// + /// This setting is only supported for vector search. + /// + /// If [`Self::with_fragments`] is also set then rows from those fragments that are not covered + /// by the selected index segments will still be searched with flat KNN. Otherwise, unindexed + /// fragments outside the selected index segments are not searched. + pub fn with_index_segments(&mut self, segments: Vec) -> Result<&mut Self> { + if segments.is_empty() { + return Err(Error::invalid_input( + "with_index_segments does not accept an empty segment list".to_string(), + )); + } + self.index_segments = Some(segments); + Ok(self) + } + fn get_batch_size(&self) -> usize { // Default batch size to be large enough so that a i32 column can be // read in a single range request. For the object store default of @@ -2167,6 +2190,12 @@ impl Scanner { } } + if self.index_segments.is_some() && self.nearest.is_none() { + return Err(Error::not_supported( + "with_index_segments is only supported for vector search".to_string(), + )); + } + Ok(()) } @@ -3421,55 +3450,127 @@ impl Scanner { } else { Arc::new(vec![]) }; - // Find an index for the column and check if metric is compatible - let matching_index = if let Some(index) = - indices.iter().find(|i| i.fields.contains(&column_id)) - { - // TODO: Once we do https://github.com/lance-format/lance/issues/5231, we - // should be able to get the metric type directly from the index metadata, - // at least for newer indexes. - let idx = self - .dataset - .open_vector_index( - q.column.as_str(), - &index.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await?; - let index_metric = idx.metric_type(); + let index_and_segments = if use_index { + if let Some(requested_segments) = self.index_segments.as_ref() { + let requested_segment_set = + requested_segments.iter().copied().collect::>(); + let requested_index_segments = indices + .iter() + .filter(|idx| requested_segment_set.contains(&idx.uuid)) + .cloned() + .collect::>(); + + if requested_index_segments.len() != requested_segment_set.len() { + let found_segment_set = requested_index_segments + .iter() + .map(|idx| idx.uuid) + .collect::>(); + let missing_segments = requested_segment_set + .difference(&found_segment_set) + .map(ToString::to_string) + .collect::>(); + return Err(Error::invalid_input(format!( + "with_index_segments referenced unknown index segments: {missing_segments:?}", + ))); + } - // Check if user's requested metric is compatible with index - let use_this_index = match q.metric_type { - Some(user_metric) => { - if user_metric == index_metric { - true + if requested_index_segments + .iter() + .any(|idx| !idx.fields.contains(&column_id)) + { + return Err(Error::invalid_input(format!( + "with_index_segments contained a segment that does not belong to vector column '{}'", + q.column + ))); + } + + let index_name = requested_index_segments[0].name.clone(); + if requested_index_segments + .iter() + .any(|idx| idx.name != index_name) + { + return Err(Error::invalid_input( + "with_index_segments must reference segments from a single logical index" + .to_string(), + )); + } + + let selected_index_segments = + self.retain_relevant_index_segments(requested_index_segments); + if selected_index_segments.is_empty() { + None + } else { + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &selected_index_segments[0].uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + let index_metric = idx.metric_type(); + let use_this_index = match q.metric_type { + Some(user_metric) => { + if user_metric == index_metric { + true + } else { + return Err(Error::invalid_input(format!( + "with_index_segments requested metric {:?} but the selected index segments use {:?}", + user_metric, index_metric + ))); + } + } + None => true, + }; + if use_this_index { + Some((index_name, selected_index_segments, index_metric)) } else { - log::warn!( - "Requested metric {:?} is incompatible with index metric {:?}, falling back to brute-force search", - user_metric, - index_metric - ); - false + None } } - None => true, // No preference, use index's metric - }; + } else if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { + // TODO: Once we do https://github.com/lance-format/lance/issues/5231, we + // should be able to get the metric type directly from the index metadata, + // at least for newer indexes. + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + let index_metric = idx.metric_type(); - if use_this_index { - Some((index, idx, index_metric)) - } else { - None - } - } else { - None - }; + let use_this_index = match q.metric_type { + Some(user_metric) => { + if user_metric == index_metric { + true + } else { + log::warn!( + "Requested metric {:?} is incompatible with index metric {:?}, falling back to brute-force search", + user_metric, + index_metric + ); + false + } + } + None => true, + }; - // Only return index and deltas if there is an index on the column and at least one of the target fragments are indexed - let index_and_deltas = if let Some((index, _idx, index_metric)) = matching_index { - let deltas = self.dataset.load_indices_by_name(&index.name).await?; - let index_frags = self.get_indexed_frags(&deltas); - if !index_frags.is_empty() { - Some((index, deltas, index_metric)) + if use_this_index { + let index_segments = self.retain_relevant_index_segments( + self.dataset.load_indices_by_name(&index.name).await?, + ); + let index_frags = self.get_indexed_frags(&index_segments); + if !index_segments.is_empty() && !index_frags.is_empty() { + Some((index.name.clone(), index_segments, index_metric)) + } else { + None + } + } else { + None + } } else { None } @@ -3477,7 +3578,7 @@ impl Scanner { None }; - if let Some((index, deltas, index_metric)) = index_and_deltas { + if let Some((index_name, index_segments, index_metric)) = index_and_segments { log::trace!("index found for vector search"); // Use the index's metric type q.metric_type = Some(index_metric); @@ -3489,8 +3590,8 @@ impl Scanner { )); } let ann_node = match vector_type { - DataType::FixedSizeList(_, _) => self.ann(&q, &deltas, filter_plan).await?, - DataType::List(_) => self.multivec_ann(&q, &deltas, filter_plan).await?, + DataType::FixedSizeList(_, _) => self.ann(&q, &index_segments, filter_plan).await?, + DataType::List(_) => self.multivec_ann(&q, &index_segments, filter_plan).await?, _ => unreachable!(), }; @@ -3507,7 +3608,9 @@ impl Scanner { }; // vector, _distance, _rowid if !self.fast_search { - knn_node = self.knn_combined(&q, index, knn_node, filter_plan).await?; + knn_node = self + .knn_combined(&q, &index_name, &index_segments, knn_node, filter_plan) + .await?; } Ok(knn_node) @@ -3557,27 +3660,27 @@ impl Scanner { async fn knn_combined( &self, q: &Query, - index: &IndexMetadata, + index_name: &str, + indexed_segments: &[IndexMetadata], mut knn_node: Arc, filter_plan: &ExprFilterPlan, ) -> Result> { - // Get unindexed fragments and filter to target fragments - let unindexed_fragments = - self.retain_target_fragments(self.dataset.unindexed_fragments(&index.name).await?); - - if !unindexed_fragments.is_empty() { - // need to set the metric type to be the same as the index - // to make sure the distance is comparable. - let idx = self - .dataset - .open_vector_index( - q.column.as_str(), - &index.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await?; - let mut q = q.clone(); - q.metric_type = Some(idx.metric_type()); + let fallback_fragments = if let Some(target_fragments) = &self.fragments { + let indexed_fragments = self.get_indexed_frags(indexed_segments); + target_fragments + .iter() + .filter(|fragment| !indexed_fragments.contains(fragment.id as u32)) + .cloned() + .collect::>() + } else if self.index_segments.is_some() { + Vec::new() + } else { + self.dataset.unindexed_fragments(index_name).await? + }; + + if !fallback_fragments.is_empty() { + let q = q.clone(); + debug_assert!(q.metric_type.is_some()); // If the vector column is not present, we need to take the vector column, so // that the distance value is comparable with the flat search ones. @@ -3606,7 +3709,7 @@ impl Scanner { false, false, vector_scan_projection, - Arc::new(unindexed_fragments), + Arc::new(fallback_fragments), // Can't pushdown limit/offset in an ANN search None, // We are re-ordering anyways, so no need to get data in data @@ -4176,6 +4279,25 @@ impl Scanner { } } + fn retain_relevant_index_segments( + &self, + index_segments: Vec, + ) -> Vec { + if let Some(fragments) = &self.fragments { + let target_fragments = RoaringBitmap::from_iter(fragments.iter().map(|f| f.id as u32)); + index_segments + .into_iter() + .filter(|idx| { + idx.fragment_bitmap + .as_ref() + .is_some_and(|fragmap| !(fragmap & &target_fragments).is_empty()) + }) + .collect() + } else { + index_segments + } + } + /// Retain only fragments that are in the user-specified fragment list. /// If no fragment list is specified, returns the fragments unchanged. fn retain_target_fragments(&self, mut fragments: Vec) -> Vec { @@ -4543,6 +4665,7 @@ pub mod test_dataset { use arrow_array::{ ArrayRef, FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, StringArray, + types::Float32Type, }; use arrow_schema::{ArrowError, DataType}; use lance_arrow::FixedSizeListArrayExt; @@ -4551,7 +4674,13 @@ pub mod test_dataset { use lance_index::{ IndexType, scalar::{ScalarIndexParams, inverted::tokenizer::InvertedIndexParams}, + vector::{ + ivf::IvfBuildParams, + kmeans::{KMeansParams, train_kmeans}, + }, }; + use lance_linalg::distance::DistanceType; + use uuid::Uuid; use crate::dataset::WriteParams; use crate::index::vector::VectorIndexParams; @@ -4662,6 +4791,63 @@ pub mod test_dataset { Ok(()) } + pub async fn make_segmented_vector_index(&mut self) -> Result> { + let batch = self + .dataset + .scan() + .project(&["vec"]) + .unwrap() + .try_into_batch() + .await?; + let vectors = batch + .column_by_name("vec") + .expect("vector column should exist") + .as_fixed_size_list(); + let values = vectors.values().as_primitive::(); + let centroids = train_kmeans::( + values, + KMeansParams::new(None, 10, 1, DistanceType::L2), + self.dimension as usize, + 2, + 2, + ) + .unwrap() + .centroids + .as_primitive::() + .clone(); + let centroids = Arc::new( + FixedSizeListArray::try_new_from_values(centroids, self.dimension as i32).unwrap(), + ); + let params = VectorIndexParams::with_ivf_flat_params( + DistanceType::L2, + IvfBuildParams::try_with_centroids(2, centroids).unwrap(), + ); + let fragment_ids = self + .dataset + .get_fragments() + .iter() + .map(|fragment| fragment.id() as u32) + .collect::>(); + + let mut segments = Vec::with_capacity(fragment_ids.len()); + for fragment_id in fragment_ids { + let mut builder = + self.dataset + .create_index_builder(&["vec"], IndexType::Vector, ¶ms); + builder = builder.name("idx".to_string()).fragments(vec![fragment_id]); + segments.push(builder.execute_uncommitted().await?); + } + + let segment_ids = segments + .iter() + .map(|segment| segment.uuid) + .collect::>(); + self.dataset + .commit_existing_index_segments("idx", "vec", segments) + .await?; + Ok(segment_ids) + } + pub async fn make_scalar_index(&mut self) -> Result<()> { self.dataset .create_index( @@ -10258,8 +10444,8 @@ full_filter=name LIKE Utf8(\"test%2\"), refine_filter=name LIKE Utf8(\"test%2\") .await .unwrap(); - // Create index on first 2 fragments - test_ds.make_vector_index().await.unwrap(); + // Create one segment per indexed fragment so fragment filtering must prune ANN fan-out. + test_ds.make_segmented_vector_index().await.unwrap(); let query: Float32Array = (0..32).map(|v| v as f32).collect(); @@ -10280,6 +10466,207 @@ full_filter=name LIKE Utf8(\"test%2\"), refine_filter=name LIKE Utf8(\"test%2\") .await; } + #[tokio::test] + async fn test_vector_search_fragment_filter_prunes_segment_fanout() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + test_ds.make_segmented_vector_index().await.unwrap(); + + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + test_ds.append_data_with_range(400, 410).await.unwrap(); + test_ds.append_data_with_range(410, 420).await.unwrap(); + let fragments = test_ds.dataset.fragments(); + + let mut scanner = test_ds.dataset.scan(); + scanner.nearest("vec", &query, 420).unwrap(); + let full_plan = scanner.explain_plan(true).await.unwrap(); + assert!( + full_plan.contains("ANNSubIndex: name=idx, k=420, deltas=2, metric=L2"), + "expected two ANN deltas without fragment filter, plan was:\n{full_plan}" + ); + + let mut scanner = test_ds.dataset.scan(); + scanner + .nearest("vec", &query, 420) + .unwrap() + .with_fragments(vec![fragments[0].clone()]); + let filtered_plan = scanner.explain_plan(true).await.unwrap(); + assert!( + filtered_plan.contains("ANNSubIndex: name=idx, k=420, deltas=1, metric=L2"), + "expected one ANN delta with fragment filter, plan was:\n{filtered_plan}" + ); + } + + #[tokio::test] + async fn test_vector_search_respects_index_segments() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + let segment_ids = test_ds.make_segmented_vector_index().await.unwrap(); + + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + test_ds.append_data_with_range(400, 410).await.unwrap(); + test_ds.append_data_with_range(410, 420).await.unwrap(); + + let mut scanner = test_ds.dataset.scan(); + scanner + .nearest("vec", &query, 420) + .unwrap() + .with_index_segments(vec![segment_ids[0]]) + .unwrap(); + let batch = scanner.try_into_batch().await.unwrap(); + let i_array = batch + .column_by_name("i") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(batch.num_rows(), 200); + assert_values_in_range( + i_array, + 0..200, + "Should only get results from the selected index segment", + ); + } + + #[tokio::test] + async fn test_vector_search_intersects_fragments_and_index_segments() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + let segment_ids = test_ds.make_segmented_vector_index().await.unwrap(); + + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + test_ds.append_data_with_range(400, 410).await.unwrap(); + test_ds.append_data_with_range(410, 420).await.unwrap(); + let fragments = test_ds.dataset.fragments(); + + let mut scanner = test_ds.dataset.scan(); + scanner + .nearest("vec", &query, 420) + .unwrap() + .with_fragments(vec![fragments[0].clone(), fragments[2].clone()]) + .with_index_segments(vec![segment_ids[0]]) + .unwrap(); + let batch = scanner.try_into_batch().await.unwrap(); + let i_array = batch + .column_by_name("i") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!( + i_array + .iter() + .all(|v| v.is_some_and(|val| (0..200).contains(&val) || (400..410).contains(&val))) + && i_array + .iter() + .any(|v| v.is_some_and(|val| (0..200).contains(&val))) + && i_array + .iter() + .any(|v| v.is_some_and(|val| (400..410).contains(&val))), + "Should get selected segment rows plus flat fallback for target fragments outside the selected segments" + ); + } + + #[tokio::test] + async fn test_vector_search_rejects_unknown_index_segment() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + test_ds.make_segmented_vector_index().await.unwrap(); + + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + let err = test_ds + .dataset + .scan() + .nearest("vec", &query, 10) + .unwrap() + .with_index_segments(vec![Uuid::new_v4()]) + .unwrap() + .try_into_batch() + .await + .unwrap_err(); + assert!( + err.to_string().contains("unknown index segments"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn test_vector_search_rejects_metric_mismatch_for_index_segments() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + let segment_ids = test_ds.make_segmented_vector_index().await.unwrap(); + + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + let err = test_ds + .dataset + .scan() + .nearest("vec", &query, 10) + .unwrap() + .distance_metric(DistanceType::Dot) + .with_index_segments(vec![segment_ids[0]]) + .unwrap() + .try_into_batch() + .await + .unwrap_err(); + assert!( + err.to_string() + .contains("with_index_segments requested metric"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn test_with_index_segments_rejects_empty_list() { + let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + let query: Float32Array = (0..32).map(|v| v as f32).collect(); + + let Err(err) = test_ds + .dataset + .scan() + .nearest("vec", &query, 10) + .unwrap() + .with_index_segments(vec![]) + else { + panic!("expected empty index segments to be rejected"); + }; + assert!( + err.to_string() + .contains("with_index_segments does not accept an empty segment list"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn test_with_index_segments_rejected_for_non_vector_query() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + let segment_ids = test_ds.make_segmented_vector_index().await.unwrap(); + + let err = test_ds + .dataset + .scan() + .project(&["i"]) + .unwrap() + .with_index_segments(vec![segment_ids[0]]) + .unwrap() + .try_into_batch() + .await + .unwrap_err(); + assert!( + err.to_string() + .contains("with_index_segments is only supported for vector search"), + "unexpected error: {err}" + ); + } + #[tokio::test] async fn test_fts_respects_fragment_list() { let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)