@@ -66,12 +66,12 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
66
66
int totalDocsCountMixed ;
67
67
if (isFirstMixedRound ()) {
68
68
totalDocsCountMixed = NUM_DOCS_PER_ROUND ;
69
- HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null );
69
+ HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null , null );
70
70
validateTestIndexOnUpgrade (totalDocsCountMixed , modelId , hybridQueryBuilder , null );
71
71
addDocument (getIndexNameForTest (), "1" , TEST_FIELD , TEXT_MIXED , null , null );
72
72
} else {
73
73
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND ;
74
- HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null );
74
+ HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null , null );
75
75
validateTestIndexOnUpgrade (totalDocsCountMixed , modelId , hybridQueryBuilder , null );
76
76
}
77
77
break ;
@@ -81,9 +81,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
81
81
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND ;
82
82
loadModel (modelId );
83
83
addDocument (getIndexNameForTest (), "2" , TEST_FIELD , TEXT_UPGRADED , null , null );
84
- HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null );
84
+ HybridQueryBuilder hybridQueryBuilder = getQueryBuilder (modelId , null , null , null , null );
85
85
validateTestIndexOnUpgrade (totalDocsCountUpgraded , modelId , hybridQueryBuilder , null );
86
- hybridQueryBuilder = getQueryBuilder (modelId , Boolean .FALSE , Map .of ("ef_search" , 100 ), RescoreContext .getDefault ());
86
+ hybridQueryBuilder = getQueryBuilder (
87
+ modelId ,
88
+ Boolean .FALSE ,
89
+ Map .of ("ef_search" , 100 ),
90
+ RescoreContext .getDefault (),
91
+ new MatchQueryBuilder ("_id" , "2" )
92
+ );
87
93
validateTestIndexOnUpgrade (totalDocsCountUpgraded , modelId , hybridQueryBuilder , null );
88
94
} finally {
89
95
wipeOfTestResources (getIndexNameForTest (), PIPELINE_NAME , modelId , SEARCH_PIPELINE_NAME );
@@ -123,7 +129,8 @@ private HybridQueryBuilder getQueryBuilder(
123
129
final String modelId ,
124
130
final Boolean expandNestedDocs ,
125
131
final Map <String , ?> methodParameters ,
126
- final RescoreContext rescoreContextForNeuralQuery
132
+ final RescoreContext rescoreContextForNeuralQuery ,
133
+ final QueryBuilder filter
127
134
) {
128
135
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder .builder ()
129
136
.fieldName (VECTOR_EMBEDDING_FIELD )
@@ -147,6 +154,10 @@ private HybridQueryBuilder getQueryBuilder(
147
154
hybridQueryBuilder .add (matchQueryBuilder );
148
155
hybridQueryBuilder .add (neuralQueryBuilder );
149
156
157
+ if (filter != null ) {
158
+ hybridQueryBuilder .filter (filter );
159
+ }
160
+
150
161
return hybridQueryBuilder ;
151
162
}
152
163
}
0 commit comments