|
32 | 32 |
|
33 | 33 | package org.opensearch.client.opensearch.model;
|
34 | 34 |
|
| 35 | +import java.util.Arrays; |
35 | 36 | import org.junit.Test;
|
36 | 37 | import org.opensearch.client.json.JsonData;
|
| 38 | +import org.opensearch.client.opensearch._types.FieldValue; |
37 | 39 | import org.opensearch.client.opensearch._types.mapping.Property;
|
38 | 40 | import org.opensearch.client.opensearch._types.mapping.TypeMapping;
|
| 41 | +import org.opensearch.client.opensearch._types.query_dsl.KnnQuery; |
| 42 | +import org.opensearch.client.opensearch._types.query_dsl.NeuralQuery; |
39 | 43 | import org.opensearch.client.opensearch._types.query_dsl.Query;
|
40 | 44 | import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders;
|
| 45 | +import org.opensearch.client.opensearch._types.query_dsl.TermQuery; |
41 | 46 | import org.opensearch.client.opensearch.core.SearchRequest;
|
42 | 47 | import org.opensearch.client.opensearch.indices.GetMappingResponse;
|
43 | 48 |
|
@@ -243,4 +248,57 @@ public void testNeuralQueryFromJson() {
|
243 | 248 | assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
|
244 | 249 | assertEquals(100, searchRequest.query().neural().k());
|
245 | 250 | }
|
| 251 | + |
| 252 | + @Test |
| 253 | + public void testHybridQuery() { |
| 254 | + |
| 255 | + Query query = Query.of( |
| 256 | + h -> h.hybrid( |
| 257 | + q -> q.queries( |
| 258 | + Arrays.asList( |
| 259 | + new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(), |
| 260 | + new NeuralQuery.Builder().field("passage_embedding") |
| 261 | + .queryText("Hi world") |
| 262 | + .modelId("bQ1J8ooBpBj3wT4HVUsb") |
| 263 | + .k(100) |
| 264 | + .build() |
| 265 | + .toQuery(), |
| 266 | + new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery() |
| 267 | + ) |
| 268 | + ) |
| 269 | + ) |
| 270 | + ); |
| 271 | + SearchRequest searchRequest = SearchRequest.of(s -> s.query(query)); |
| 272 | + assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field()); |
| 273 | + assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue()); |
| 274 | + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field()); |
| 275 | + assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText()); |
| 276 | + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId()); |
| 277 | + assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); |
| 278 | + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); |
| 279 | + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); |
| 280 | + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); |
| 281 | + } |
| 282 | + |
| 283 | + @Test |
| 284 | + public void testHybridQueryFromJson() { |
| 285 | + |
| 286 | + String json = "{\"query\"" |
| 287 | + + ":{\"hybrid\":{\"queries\":[{\"term\":{\"passage_text\":\"Foo bar\"}}," |
| 288 | + + "{\"neural\":{\"passage_embedding\":{\"query_text\":\"Hi world\",\"model_id\":\"bQ1J8ooBpBj3wT4HVUsb\",\"k\":100}}}," |
| 289 | + + "{\"knn\":{\"passage_embedding\":{\"vector\":[0.01,0.02],\"k\":2}}}]}},\"size\":10" |
| 290 | + + "}"; |
| 291 | + |
| 292 | + SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper); |
| 293 | + |
| 294 | + assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field()); |
| 295 | + assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue()); |
| 296 | + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field()); |
| 297 | + assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText()); |
| 298 | + assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId()); |
| 299 | + assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); |
| 300 | + assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); |
| 301 | + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); |
| 302 | + assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); |
| 303 | + } |
246 | 304 | }
|
0 commit comments