Skip to content

Commit 4430ec1

Browse files
[Feature branch] Lower bounds for min-max normalization in hybrid query (#1195) (#1213)
* Working draft with unit tests * Added integ test, adjust some calculations --------- Signed-off-by: Martin Gaievski <[email protected]>
1 parent da5eebb commit 4430ec1

19 files changed

+1514
-90
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55

66
## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
77
### Features
8+
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
89
### Enhancements
910
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
1011
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))

src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java

+75
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
1414

1515
import java.util.ArrayList;
16+
import java.util.Arrays;
1617
import java.util.List;
1718
import java.util.Objects;
1819

@@ -150,4 +151,78 @@ private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortE
150151
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
151152
return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex);
152153
}
154+
155+
@Override
156+
public boolean equals(Object other) {
157+
if (this == other) return true;
158+
if (other == null || getClass() != other.getClass()) return false;
159+
CompoundTopDocs that = (CompoundTopDocs) other;
160+
161+
if (this.topDocs.size() != that.topDocs.size()) {
162+
return false;
163+
}
164+
for (int i = 0; i < topDocs.size(); i++) {
165+
TopDocs thisTopDoc = this.topDocs.get(i);
166+
TopDocs thatTopDoc = that.topDocs.get(i);
167+
if ((thisTopDoc == null) != (thatTopDoc == null)) {
168+
return false;
169+
}
170+
if (thisTopDoc == null) {
171+
continue;
172+
}
173+
if (Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits) == false) {
174+
return false;
175+
}
176+
if (compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs) == false) {
177+
return false;
178+
}
179+
}
180+
return Objects.equals(totalHits, that.totalHits) && Objects.equals(searchShard, that.searchShard);
181+
}
182+
183+
private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) {
184+
if (first.length != second.length) {
185+
return false;
186+
}
187+
188+
for (int i = 0; i < first.length; i++) {
189+
ScoreDoc firstDoc = first[i];
190+
ScoreDoc secondDoc = second[i];
191+
if ((firstDoc == null) != (secondDoc == null)) {
192+
return false;
193+
}
194+
if (firstDoc == null) {
195+
continue;
196+
}
197+
if (firstDoc.doc != secondDoc.doc || Float.compare(firstDoc.score, secondDoc.score) != 0) {
198+
return false;
199+
}
200+
if (firstDoc instanceof FieldDoc != secondDoc instanceof FieldDoc) {
201+
return false;
202+
}
203+
if (firstDoc instanceof FieldDoc firstFieldDoc) {
204+
FieldDoc secondFieldDoc = (FieldDoc) secondDoc;
205+
if (Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields) == false) {
206+
return false;
207+
}
208+
}
209+
}
210+
return true;
211+
}
212+
213+
@Override
214+
public int hashCode() {
215+
int result = Objects.hash(totalHits, searchShard);
216+
for (TopDocs topDoc : topDocs) {
217+
result = 31 * result + topDoc.totalHits.hashCode();
218+
for (ScoreDoc scoreDoc : topDoc.scoreDocs) {
219+
result = 31 * result + Float.floatToIntBits(scoreDoc.score);
220+
result = 31 * result + scoreDoc.doc;
221+
if (scoreDoc instanceof FieldDoc fieldDoc && fieldDoc.fields != null) {
222+
result = 31 * result + Arrays.deepHashCode(fieldDoc.fields);
223+
}
224+
}
225+
}
226+
return result;
227+
}
153228
}

src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ public SearchPhaseResultsProcessor create(
5858
TECHNIQUE,
5959
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME
6060
);
61-
normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName);
61+
Map<String, Object> normalizationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, normalizationClause, PARAMETERS);
62+
normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName, normalizationParams);
6263
}
6364

6465
Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);

src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java

+9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Locale;
1111
import java.util.Map;
1212
import java.util.Objects;
13+
import java.util.Set;
1314

1415
import org.apache.lucene.search.ScoreDoc;
1516
import org.apache.lucene.search.TopDocs;
@@ -32,6 +33,14 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu
3233
public static final String TECHNIQUE_NAME = "l2";
3334
private static final float MIN_SCORE = 0.0f;
3435

36+
/**
 * Creates the technique without any user-supplied parameters by delegating to the
 * two-argument constructor with an empty parameter map and a new {@code ScoreNormalizationUtil}.
 */
public L2ScoreNormalizationTechnique() {
37+
this(Map.of(), new ScoreNormalizationUtil());
38+
}
39+
40+
/**
 * Creates the technique from a parameter map, passing it to
 * {@code scoreNormalizationUtil.validateParameters} together with an empty set and an empty map.
 * The constructor stores neither argument.
 * NOTE(review): the empty set/map presumably mean that L2 supports no parameters, so any
 * supplied entry would be rejected; confirm against ScoreNormalizationUtil.validateParameters.
 *
 * @param params                 normalization parameters read from the processor definition
 * @param scoreNormalizationUtil utility used to validate {@code params}
 */
public L2ScoreNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
41+
scoreNormalizationUtil.validateParameters(params, Set.of(), Map.of());
42+
}
43+
3544
/**
3645
* L2 normalization method.
3746
* n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2)

0 commit comments

Comments
 (0)