5353import java .util .concurrent .Callable ;
5454import java .util .stream .Collectors ;
5555
56+ import static org .opensearch .knn .profile .StopWatchUtils .startStopWatch ;
57+ import static org .opensearch .knn .profile .StopWatchUtils .stopStopWatchAndLog ;
58+
5659/**
5760 * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level.
5861 * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives
@@ -274,13 +277,19 @@ private List<PerLeafResult> doSearch(
274277
275278 // For memory optimized search, it should kick off 2nd search if optimistic
276279 if (knnQuery .isMemoryOptimizedSearch () && perLeafResults .size () > 1 ) {
277- run2ndOptimisticSearch (perLeafResults , knnWeight , leafReaderContexts , k , indexSearcher );
280+ log .debug (
281+ "Running second deep dive search in optimistic while memory optimized search is enabled. perLeafResults.size()={}" ,
282+ perLeafResults .size ()
283+ );
284+ final StopWatch stopWatch = startStopWatch (log );
285+ reentrantSearch (perLeafResults , knnWeight , leafReaderContexts , k , indexSearcher );
286+ stopStopWatchAndLog (log , stopWatch , "2ndOptimisticSearch" , knnQuery .getShardId (), "All Shards" , knnQuery .getField ());
278287 }
279288
280289 return perLeafResults ;
281290 }
282291
283- private void run2ndOptimisticSearch (
292+ private void reentrantSearch (
284293 final List <PerLeafResult > perLeafResults ,
285294 final KNNWeight knnWeight ,
286295 final List <LeafReaderContext > leafReaderContexts ,
@@ -299,7 +308,7 @@ private void run2ndOptimisticSearch(
299308
300309 assert (perLeafResults .size () == leafReaderContexts .size ());
301310
302- // Get collector manager first
311+ // Get memory optimized knn weight first, it's safe get it, we checked it already.
303312 final MemoryOptimizedKNNWeight memoryOptKNNWeight = (MemoryOptimizedKNNWeight ) knnWeight ;
304313
305314 // How many results have we collected?
@@ -313,25 +322,24 @@ private void run2ndOptimisticSearch(
313322 return ;
314323 }
315324
316- // Build segment to results table
317- final Map <Integer , TopDocs > segmentOrdToResults = new HashMap <>(leafReaderContexts .size ());
318- for (int i = 0 ; i < perLeafResults .size (); i ++) {
319- segmentOrdToResults .put (leafReaderContexts .get (i ).ord , perLeafResults .get (i ).getResult ());
320- }
321-
322325 // Start 2nd deep dive, and get the minimum bar.
323326 final float minTopKScore = OptimisticSearchStrategyUtils .findKthLargestScore (perLeafResults , knnQuery .getK (), totalResults );
324327
325328 // Select candidate segments for 2nd search. Pick whatever segment returned all vectors whose score values are greater than `kth`
326329 // value in the merged results.
327330 final List <Callable <TopDocs >> secondDeepDiveTasks = new ArrayList <>();
328331 final List <Integer > contextIndices = new ArrayList <>();
332+ final Map <Integer , TopDocs > segmentOrdToResults = new HashMap <>();
333+
329334 for (int i = 0 ; i < leafReaderContexts .size (); ++i ) {
330335 final LeafReaderContext leafReaderContext = leafReaderContexts .get (i );
331336 final PerLeafResult perLeafResult = perLeafResults .get (i );
332- final TopDocs perLeaf = segmentOrdToResults .get (leafReaderContext . ord );
337+ final TopDocs perLeaf = perLeafResults .get (i ). getResult ( );
333338 if (perLeaf .scoreDocs .length > 0 && perLeafResult .getSearchMode () == PerLeafResult .SearchMode .APPROXIMATE_SEARCH ) {
334339 if (FORCE_REENTER_TESTING || perLeaf .scoreDocs [perLeaf .scoreDocs .length - 1 ].score >= minTopKScore ) {
340+ // For the target segment, save top results. Which will be used as seeds.
341+ segmentOrdToResults .put (leafReaderContext .ord , perLeaf );
342+
335343 // All this leaf's hits are at or above the global topK min score; explore it further
336344 secondDeepDiveTasks .add (
337345 () -> knnWeight .approximateSearch (
@@ -348,24 +356,24 @@ private void run2ndOptimisticSearch(
348356
349357 // Kick off 2nd search tasks
350358 if (secondDeepDiveTasks .isEmpty () == false ) {
351- final ReentrantKnnCollectorManager knnCollectorManagerPhase2 = new ReentrantKnnCollectorManager (
359+ final ReentrantKnnCollectorManager reentrantCollectorManager = new ReentrantKnnCollectorManager (
352360 new TopKnnCollectorManager (k , indexSearcher ),
353361 segmentOrdToResults ,
354362 knnQuery .getQueryVector (),
355363 knnQuery .getField ()
356364 );
357365
358366 // Make weight use reentrant collector manager
359- memoryOptKNNWeight .setOptimistic2ndKnnCollectorManager ( knnCollectorManagerPhase2 );
367+ memoryOptKNNWeight .setReentrantKNNCollectorManager ( reentrantCollectorManager );
360368
361369 final List <TopDocs > deepDiveTopDocs = indexSearcher .getTaskExecutor ().invokeAll (secondDeepDiveTasks );
362370
363371 // Override results for target context
364372 for (int i = 0 ; i < deepDiveTopDocs .size (); ++i ) {
365373 // Override with the new results
366- final TopDocs resultsFrom2ncDeepDive = deepDiveTopDocs .get (i );
374+ final TopDocs resultsFromDeepDive = deepDiveTopDocs .get (i );
367375 final PerLeafResult perLeafResult = perLeafResults .get (contextIndices .get (i ));
368- perLeafResult .setResult (resultsFrom2ncDeepDive );
376+ perLeafResult .setResult (resultsFromDeepDive );
369377 }
370378 }
371379 }
0 commit comments