56
56
import org .opensearch .index .IndexSettings ;
57
57
import org .opensearch .index .mapper .MappedFieldType ;
58
58
import org .opensearch .index .mapper .TextFieldMapper ;
59
+ import org .opensearch .index .query .BoolQueryBuilder ;
59
60
import org .opensearch .index .query .MatchAllQueryBuilder ;
60
61
import org .opensearch .index .query .QueryBuilder ;
61
62
import org .opensearch .index .query .QueryBuilders ;
@@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
82
83
static final String TEXT_FIELD_NAME = "field" ;
83
84
static final String QUERY_TEXT = "Hello world!" ;
84
85
static final String TERM_QUERY_TEXT = "keyword" ;
86
+ static final String FILTER_TERM_QUERY_TEXT = "filterKeyword" ;
85
87
static final String MODEL_ID = "mfgfgdsfgfdgsde" ;
86
88
static final int K = 10 ;
87
89
static final float BOOST = 1.8f ;
@@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
436
438
assertEquals (TERM_QUERY_TEXT , termQueryBuilder .value ());
437
439
}
438
440
441
+ /**
442
+ * Tests basic query:
443
+ * {
444
+ * "query": {
445
+ * "hybrid": {
446
+ * "queries": [
447
+ * {
448
+ * "neural": {
449
+ * "text_knn": {
450
+ * "query_text": "Hello world",
451
+ * "model_id": "dcsdcasd",
452
+ * "k": 1
453
+ * }
454
+ * }
455
+ * },
456
+ * {
457
+ * "term": {
458
+ * "text": "keyword"
459
+ * }
460
+ * }
461
+ * ]
462
+ * "filter": {
463
+ * "term": {
464
+ * "text": "filterKeyword"
465
+ * }
466
+ * }
467
+ * }
468
+ * }
469
+ * }
470
+ */
471
+ @ SneakyThrows
472
+ public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully () {
473
+ setUpClusterService ();
474
+ XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ()
475
+ .startObject ()
476
+ .startArray ("queries" )
477
+ .startObject ()
478
+ .startObject (NeuralQueryBuilder .NAME )
479
+ .startObject (VECTOR_FIELD_NAME )
480
+ .field (QUERY_TEXT_FIELD .getPreferredName (), QUERY_TEXT )
481
+ .field (MODEL_ID_FIELD .getPreferredName (), MODEL_ID )
482
+ .field (K_FIELD .getPreferredName (), K )
483
+ .field (BOOST_FIELD .getPreferredName (), BOOST )
484
+ .endObject ()
485
+ .endObject ()
486
+ .endObject ()
487
+ .startObject ()
488
+ .startObject (TermQueryBuilder .NAME )
489
+ .field (TEXT_FIELD_NAME , TERM_QUERY_TEXT )
490
+ .endObject ()
491
+ .endObject ()
492
+ .endArray ()
493
+
494
+ .field ("pagination_depth" , 10 )
495
+ .startObject ("filter" )
496
+ .startObject (TermQueryBuilder .NAME )
497
+ .field (TEXT_FIELD_NAME , FILTER_TERM_QUERY_TEXT )
498
+ .endObject ()
499
+ .endObject ()
500
+ .endObject ();
501
+
502
+ NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry (
503
+ List .of (
504
+ new NamedXContentRegistry .Entry (QueryBuilder .class , new ParseField (TermQueryBuilder .NAME ), TermQueryBuilder ::fromXContent ),
505
+ new NamedXContentRegistry .Entry (
506
+ QueryBuilder .class ,
507
+ new ParseField (NeuralQueryBuilder .NAME ),
508
+ NeuralQueryBuilder ::fromXContent
509
+ ),
510
+ new NamedXContentRegistry .Entry (
511
+ QueryBuilder .class ,
512
+ new ParseField (HybridQueryBuilder .NAME ),
513
+ HybridQueryBuilder ::fromXContent
514
+ )
515
+ )
516
+ );
517
+ XContentParser contentParser = createParser (
518
+ namedXContentRegistry ,
519
+ xContentBuilder .contentType ().xContent (),
520
+ BytesReference .bytes (xContentBuilder )
521
+ );
522
+ contentParser .nextToken ();
523
+
524
+ HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder .fromXContent (contentParser );
525
+ assertEquals (2 , queryTwoSubQueries .queries ().size ());
526
+ assertTrue (queryTwoSubQueries .queries ().get (0 ) instanceof NeuralQueryBuilder );
527
+
528
+ assertTrue (queryTwoSubQueries .queries ().get (1 ) instanceof BoolQueryBuilder );
529
+ assertEquals (1 , ((BoolQueryBuilder ) queryTwoSubQueries .queries ().get (1 )).must ().size ());
530
+ assertTrue (((BoolQueryBuilder ) queryTwoSubQueries .queries ().get (1 )).must ().get (0 ) instanceof TermQueryBuilder );
531
+ assertEquals (1 , ((BoolQueryBuilder ) queryTwoSubQueries .queries ().get (1 )).filter ().size ());
532
+
533
+ assertEquals (10 , queryTwoSubQueries .paginationDepth ().intValue ());
534
+ // verify knn vector query
535
+ NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder ) queryTwoSubQueries .queries ().get (0 );
536
+ assertEquals (VECTOR_FIELD_NAME , neuralQueryBuilder .fieldName ());
537
+ assertEquals (QUERY_TEXT , neuralQueryBuilder .queryText ());
538
+ assertEquals (K , (int ) neuralQueryBuilder .k ());
539
+ assertEquals (MODEL_ID , neuralQueryBuilder .modelId ());
540
+ assertEquals (BOOST , neuralQueryBuilder .boost (), 0f );
541
+ assertEquals (
542
+ new TermQueryBuilder (TEXT_FIELD_NAME , FILTER_TERM_QUERY_TEXT ),
543
+ ((NeuralQueryBuilder ) queryTwoSubQueries .queries ().get (0 )).filter ()
544
+ );
545
+ // verify term query
546
+ assertEquals (
547
+ new TermQueryBuilder (TEXT_FIELD_NAME , TERM_QUERY_TEXT ),
548
+ ((BoolQueryBuilder ) queryTwoSubQueries .queries ().get (1 )).must ().get (0 )
549
+ );
550
+ assertEquals (
551
+ new TermQueryBuilder (TEXT_FIELD_NAME , FILTER_TERM_QUERY_TEXT ),
552
+ ((BoolQueryBuilder ) queryTwoSubQueries .queries ().get (1 )).filter ().get (0 )
553
+ );
554
+ }
555
+
439
556
@ SneakyThrows
440
557
public void testFromXContent_whenIncorrectFormat_thenFail () {
441
558
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory .jsonBuilder ()
@@ -960,6 +1077,29 @@ public void testVisit() {
960
1077
assertEquals (3 , visitedQueries .size ());
961
1078
}
962
1079
1080
+ public void testFilter () {
1081
+ HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder ().add (
1082
+ NeuralQueryBuilder .builder ().fieldName ("test" ).queryText ("test" ).build ()
1083
+ ).add (new NeuralSparseQueryBuilder ());
1084
+ // Test for Null filter Case
1085
+ QueryBuilder queryBuilder = hybridQueryBuilder .filter (null );
1086
+ assertEquals (queryBuilder , hybridQueryBuilder );
1087
+
1088
+ // Test for Non-Null filter case and assert every field as expected
1089
+ HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder ) hybridQueryBuilder .filter (new MatchAllQueryBuilder ());
1090
+ assertEquals (updatedHybridQueryBuilder .queryName (), hybridQueryBuilder .queryName ());
1091
+ assertEquals (updatedHybridQueryBuilder .paginationDepth (), hybridQueryBuilder .paginationDepth ());
1092
+ NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder ) updatedHybridQueryBuilder .queries ().get (0 );
1093
+ assertEquals (new MatchAllQueryBuilder (), updatedNeuralQueryBuilder .filter ());
1094
+ BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder ) updatedHybridQueryBuilder .queries ().get (1 );
1095
+ assertEquals (new NeuralSparseQueryBuilder (), updatedNeuralSparseQueryBuilder .must ().get (0 ));
1096
+ assertEquals (new MatchAllQueryBuilder (), updatedNeuralSparseQueryBuilder .filter ().get (0 ));
1097
+
1098
+ // Test for Non-Null filter case but encountered Nested HybridQueryBuilder to throw Unsupported Exception
1099
+ updatedHybridQueryBuilder .add (new HybridQueryBuilder ());
1100
+ assertThrows (UnsupportedOperationException .class , () -> updatedHybridQueryBuilder .filter (new MatchAllQueryBuilder ()));
1101
+ }
1102
+
963
1103
private Map <String , Object > getInnerMap (Object innerObject , String queryName , String fieldName ) {
964
1104
if (!(innerObject instanceof Map )) {
965
1105
fail ("field name does not map to nested object" );
0 commit comments