Skip to content

Commit a3a3e54

Browse files
authored
Add support for radial search on Neural query (#1235)
* Add support for radial search on k-NN and Neural query types Signed-off-by: Thomas Farr <[email protected]> * spotless Signed-off-by: Thomas Farr <[email protected]> * Fix compile errors Signed-off-by: Thomas Farr <[email protected]> * Fix test Signed-off-by: Thomas Farr <[email protected]> --------- Signed-off-by: Thomas Farr <[email protected]>
1 parent 96f1688 commit a3a3e54

File tree

5 files changed

+101
-20
lines changed

5 files changed

+101
-20
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
3939

4040
### Added
4141
- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166))
42+
- Added `minScore` and `maxDistance` to `NeuralQuery` ([#1235](https://github.com/opensearch-project/opensearch-java/pull/1235))
4243

4344
### Dependencies
4445

@@ -568,4 +569,4 @@ This section is for maintaining a changelog for all breaking changes for the cli
568569
[2.5.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.4.0...v2.5.0
569570
[2.4.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.3.0...v2.4.0
570571
[2.3.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.2.0...v2.3.0
571-
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
572+
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0

java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ public final float[] vector() {
7676
* Optional - The number of neighbors the search of each graph will return.
7777
* @return The number of neighbors to return.
7878
*/
79+
@Nullable
7980
public final Integer k() {
8081
return this.k;
8182
}
@@ -84,6 +85,7 @@ public final Integer k() {
8485
* Optional - The minimum score allowed for the returned search results.
8586
* @return The minimum score allowed for the returned search results.
8687
*/
88+
@Nullable
8789
private final Float minScore() {
8890
return this.minScore;
8991
}
@@ -92,6 +94,7 @@ private final Float minScore() {
9294
* Optional - The maximum distance allowed between the vector and each of the returned search results.
9395
* @return The maximum distance allowed between the vector and each ofthe returned search results.
9496
*/
97+
@Nullable
9598
private final Float maxDistance() {
9699
return this.maxDistance;
97100
}
@@ -111,8 +114,6 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
111114

112115
super.serializeInternal(generator, mapper);
113116

114-
// TODO: Implement the rest of the serialization.
115-
116117
generator.writeKey("vector");
117118
generator.writeStartArray();
118119
for (float value : this.vector) {
@@ -183,7 +184,7 @@ public Builder vector(@Nullable float[] vector) {
183184
}
184185

185186
/**
186-
* Required - The number of neighbors the search of each graph will return.
187+
* Optional - The number of neighbors to return.
187188
*
188189
* @param k The number of neighbors to return.
189190
* @return This builder.

java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java

+78-8
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
2626
private final String field;
2727
private final String queryText;
2828
private final String queryImage;
29-
private final int k;
29+
@Nullable
30+
private final Integer k;
31+
@Nullable
32+
private final Float minScore;
33+
@Nullable
34+
private final Float maxDistance;
3035
@Nullable
3136
private final String modelId;
3237
@Nullable
@@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) {
4146
}
4247
this.queryText = builder.queryText;
4348
this.queryImage = builder.queryImage;
44-
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
49+
this.k = builder.k;
50+
this.minScore = builder.minScore;
51+
this.maxDistance = builder.maxDistance;
4552
this.modelId = builder.modelId;
4653
this.filter = builder.filter;
4754
}
@@ -90,17 +97,34 @@ public final String queryImage() {
9097
}
9198

9299
/**
93-
* Required - The number of neighbors to return.
100+
* Optional - The number of neighbors to return.
94101
*
95102
* @return The number of neighbors to return.
96103
*/
97-
public final int k() {
104+
@Nullable
105+
public final Integer k() {
98106
return this.k;
99107
}
100108

101109
/**
102-
* Builder for {@link NeuralQuery}.
110+
* Optional - The minimum score threshold for the search results
111+
*
112+
* @return The minimum score threshold for the search results
113+
*/
114+
@Nullable
115+
public final Float minScore() {
116+
return this.minScore;
117+
}
118+
119+
/**
120+
* Optional - The maximum distance threshold for the search results
121+
*
122+
* @return The maximum distance threshold for the search results
103123
*/
124+
@Nullable
125+
public final Float maxDistance() {
126+
return this.maxDistance;
127+
}
104128

105129
/**
106130
* Optional - The model_id field if the default model for the index or field is set.
@@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
141165
generator.write("model_id", this.modelId);
142166
}
143167

144-
generator.write("k", this.k);
168+
if (this.k != null) {
169+
generator.write("k", this.k);
170+
}
171+
172+
if (this.minScore != null) {
173+
generator.write("min_score", this.minScore);
174+
}
175+
176+
if (this.maxDistance != null) {
177+
generator.write("max_distance", this.maxDistance);
178+
}
145179

146180
if (this.filter != null) {
147181
generator.writeKey("filter");
@@ -152,7 +186,14 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
152186
}
153187

154188
public Builder toBuilder() {
155-
return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
189+
return toBuilder(new Builder()).field(field)
190+
.queryText(queryText)
191+
.queryImage(queryImage)
192+
.k(k)
193+
.minScore(minScore)
194+
.maxDistance(maxDistance)
195+
.modelId(modelId)
196+
.filter(filter);
156197
}
157198

158199
/**
@@ -162,8 +203,13 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
162203
private String field;
163204
private String queryText;
164205
private String queryImage;
206+
@Nullable
165207
private Integer k;
166208
@Nullable
209+
private Float minScore;
210+
@Nullable
211+
private Float maxDistance;
212+
@Nullable
167213
private String modelId;
168214
@Nullable
169215
private Query filter;
@@ -216,7 +262,7 @@ public NeuralQuery.Builder modelId(@Nullable String modelId) {
216262
}
217263

218264
/**
219-
* Required - The number of neighbors to return.
265+
* Optional - The number of neighbors to return.
220266
*
221267
* @param k The number of neighbors to return.
222268
* @return This builder.
@@ -226,6 +272,28 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
226272
return this;
227273
}
228274

275+
/**
276+
* Optional - The minimum score threshold for the search results
277+
*
278+
* @param minScore The minimum score threshold for the search results
279+
* @return This builder.
280+
*/
281+
public NeuralQuery.Builder minScore(@Nullable Float minScore) {
282+
this.minScore = minScore;
283+
return this;
284+
}
285+
286+
/**
287+
* Optional - The maximum distance threshold for the search results
288+
*
289+
* @param maxDistance The maximum distance threshold for the search results
290+
* @return This builder.
291+
*/
292+
public NeuralQuery.Builder maxDistance(@Nullable Float maxDistance) {
293+
this.maxDistance = maxDistance;
294+
return this;
295+
}
296+
229297
/**
230298
* Optional - A query to filter the results of the knn query.
231299
*
@@ -267,6 +335,8 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
267335
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
268336
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
269337
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
338+
op.add(NeuralQuery.Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
339+
op.add(NeuralQuery.Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
270340
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");
271341

272342
op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());

java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public void testNeuralQuery() {
221221
assertEquals("passage_embedding", searchRequest.query().neural().field());
222222
assertEquals("Hi world", searchRequest.query().neural().queryText());
223223
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
224-
assertEquals(100, searchRequest.query().neural().k());
224+
assertEquals((Integer) 100, searchRequest.query().neural().k());
225225
}
226226

227227
@Test
@@ -251,7 +251,7 @@ public void testNeuralQueryFromJson() {
251251
searchRequest.query().neural().queryImage()
252252
);
253253
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
254-
assertEquals(100, searchRequest.query().neural().k());
254+
assertEquals((Integer) 100, searchRequest.query().neural().k());
255255
}
256256

257257
@Test
@@ -279,10 +279,10 @@ public void testHybridQuery() {
279279
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
280280
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
281281
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
282-
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
282+
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
283283
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
284284
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
285-
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
285+
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
286286
}
287287

288288
@Test
@@ -301,9 +301,9 @@ public void testHybridQueryFromJson() {
301301
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
302302
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
303303
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
304-
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
304+
assertEquals((Integer) 100, searchRequest.query().hybrid().queries().get(1).neural().k());
305305
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
306306
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
307-
assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k());
307+
assertEquals((Integer) 2, searchRequest.query().hybrid().queries().get(2).knn().k());
308308
}
309309
}

java-client/src/test/java11/org/opensearch/client/opensearch/integTest/AbstractSearchTemplateRequestIT.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.List;
1313
import java.util.Map;
1414
import org.junit.Test;
15+
import org.opensearch.Version;
1516
import org.opensearch.client.json.JsonData;
1617
import org.opensearch.client.opensearch._types.Refresh;
1718
import org.opensearch.client.opensearch._types.mapping.Property;
@@ -89,6 +90,14 @@ public void testTemplateSearchAggregations() throws Exception {
8990

9091
@Test
9192
public void testMultiSearchTemplate() throws Exception {
93+
Integer expectedSuccessStatus = null;
94+
Integer expectedFailureStatus = null;
95+
96+
if (getServerVersion().onOrAfter(Version.V_2_18_0)) {
97+
expectedSuccessStatus = 200;
98+
expectedFailureStatus = 404;
99+
}
100+
92101
var index = "test-msearch-template";
93102
createDocuments(index);
94103

@@ -120,11 +129,11 @@ public void testMultiSearchTemplate() throws Exception {
120129
assertEquals(2, searchResponse.responses().size());
121130
var response = searchResponse.responses().get(0);
122131
assertTrue(response.isResult());
123-
assertNull(response.result().status());
132+
assertEquals(expectedSuccessStatus, response.result().status());
124133
assertEquals(4, response.result().hits().hits().size());
125134
var failureResponse = searchResponse.responses().get(1);
126135
assertTrue(failureResponse.isFailure());
127-
assertNull(failureResponse.failure().status());
136+
assertEquals(expectedFailureStatus, failureResponse.failure().status());
128137
}
129138

130139
private SearchTemplateResponse<SimpleDoc> sendTemplateRequest(String index, String title, boolean suggs, boolean aggs)

0 commit comments

Comments
 (0)