Skip to content

Commit 5ae7034

Browse files
committed
Used DescriptiveStatistics and added validations for combination technique
Signed-off-by: Owais <[email protected]>
1 parent bc43a35 commit 5ae7034

14 files changed

+214
-106
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ dependencies {
260260
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
261261
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
262262
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
263+
implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
263264
// ml-common excluded reflection for runtime so we need to add it by ourselves.
264265
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
265266
// TODO: Remove following three lines of dependencies if ml-common include them in their jar

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float combine(final float[] scores) {
5858
return combinedScore / sumOfWeights;
5959
}
6060

61+
@Override
62+
public String techniqueName() {
63+
return TECHNIQUE_NAME;
64+
}
65+
6166
@Override
6267
public String describe() {
6368
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float combine(final float[] scores) {
5858
return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
5959
}
6060

61+
@Override
62+
public String techniqueName() {
63+
return TECHNIQUE_NAME;
64+
}
65+
6166
@Override
6267
public String describe() {
6368
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public float combine(final float[] scores) {
5555
return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE;
5656
}
5757

58+
@Override
59+
public String techniqueName() {
60+
return TECHNIQUE_NAME;
61+
}
62+
5863
@Override
5964
public String describe() {
6065
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ public float combine(final float[] scores) {
3737
return sumScores;
3838
}
3939

40+
@Override
41+
public String techniqueName() {
42+
return TECHNIQUE_NAME;
43+
}
44+
4045
@Override
4146
public String describe() {
4247
return describeCombinationTechnique(TECHNIQUE_NAME, List.of());

src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ public interface ScoreCombinationTechnique {
1212
* @return combined score
1313
*/
1414
float combine(final float[] scores);
15+
16+
String techniqueName();
1517
}

src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
1919
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
2020
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
21-
import org.opensearch.neuralsearch.processor.normalization.ZScoreNormalizationTechnique;
2221
import org.opensearch.search.pipeline.Processor;
2322
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
2423

@@ -51,9 +50,8 @@ public SearchPhaseResultsProcessor create(
5150
) throws Exception {
5251
Map<String, Object> normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE);
5352
ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD;
54-
String normalizationTechniqueName = MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME;
5553
if (Objects.nonNull(normalizationClause)) {
56-
normalizationTechniqueName = readStringProperty(
54+
String normalizationTechniqueName = readStringProperty(
5755
NormalizationProcessor.TYPE,
5856
tag,
5957
normalizationClause,
@@ -75,15 +73,13 @@ public SearchPhaseResultsProcessor create(
7573
TECHNIQUE,
7674
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME
7775
);
78-
// case when technique is z score and combination is not arithmetic mean
79-
if (normalizationTechniqueName.equals(ZScoreNormalizationTechnique.TECHNIQUE_NAME)
80-
&& !combinationTechnique.equals(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)) {
81-
throw new IllegalArgumentException("Z Score supports only arithmetic_mean combination technique");
82-
}
8376
// check for optional combination params
8477
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
8578
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
8679
}
80+
81+
normalizationTechnique.validateCombinationTechnique(scoreCombinationTechnique);
82+
8783
log.info(
8884
"Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]",
8985
NormalizationProcessor.TYPE,

src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
1919

2020
import lombok.ToString;
21+
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
22+
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
23+
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
24+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
2125
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
2226
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
2327
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
@@ -69,6 +73,21 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
6973
}
7074
}
7175

76+
@Override
77+
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
78+
switch (combinationTechnique.techniqueName()) {
79+
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
80+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
81+
// These are the supported technique, so we do nothing
82+
break;
83+
default:
84+
throw new IllegalArgumentException(
85+
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
86+
+ combinationTechnique.techniqueName()
87+
);
88+
}
89+
}
90+
7291
@Override
7392
public String describe() {
7493
return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);

src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727

2828
import lombok.ToString;
2929
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
30+
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
31+
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
32+
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
33+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
3034
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
3135
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
3236
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
@@ -126,6 +130,21 @@ private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopD
126130
return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery);
127131
}
128132

133+
@Override
134+
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
135+
switch (combinationTechnique.techniqueName()) {
136+
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
137+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
138+
// These are the supported technique, so we do nothing
139+
break;
140+
default:
141+
throw new IllegalArgumentException(
142+
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
143+
+ combinationTechnique.techniqueName()
144+
);
145+
}
146+
}
147+
129148
@Override
130149
public String describe() {
131150
return lowerBoundsOptional.map(lb -> {

src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import lombok.ToString;
2525
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
2626
import org.opensearch.neuralsearch.processor.SearchShard;
27+
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
28+
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
29+
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
30+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
2731
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
2832
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
2933
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
@@ -76,6 +80,21 @@ public String describe() {
7680
return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
7781
}
7882

83+
@Override
84+
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
85+
switch (combinationTechnique.techniqueName()) {
86+
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
87+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
88+
// These are the supported technique, so we do nothing
89+
break;
90+
default:
91+
throw new IllegalArgumentException(
92+
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
93+
+ combinationTechnique.techniqueName()
94+
);
95+
}
96+
}
97+
7998
@Override
8099
public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
81100
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.opensearch.neuralsearch.processor.normalization;
66

77
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
8+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
89

910
/**
1011
* Abstracts normalization of scores in query search results.
@@ -20,4 +21,5 @@ public interface ScoreNormalizationTechnique {
2021
*/
2122
void normalize(final NormalizeScoresDTO normalizeScoresDTO);
2223

24+
void validateCombinationTechnique(final ScoreCombinationTechnique combinationTechnique) throws IllegalArgumentException;
2325
}

0 commit comments

Comments
 (0)