Skip to content

Commit 31a6a07

Browse files
committed
Created DTO for validation
Signed-off-by: Owais <[email protected]>
1 parent 5ae7034 commit 31a6a07

10 files changed

+138
-121
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import lombok.AllArgsConstructor;
8+
import lombok.Builder;
9+
import lombok.Getter;
10+
import lombok.NonNull;
11+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
12+
13+
/**
14+
* DTO object to hold data required for validation.
15+
*/
16+
@AllArgsConstructor
17+
@Builder
18+
@Getter
19+
public class ValidateNormalizationDTO {
20+
@NonNull
21+
private ScoreCombinationTechnique scoreCombinationTechnique;
22+
}

src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
1414
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
15+
import org.opensearch.neuralsearch.processor.ValidateNormalizationDTO;
1516
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
1617
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
1718
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -78,7 +79,10 @@ public SearchPhaseResultsProcessor create(
7879
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
7980
}
8081

81-
normalizationTechnique.validateCombinationTechnique(scoreCombinationTechnique);
82+
ValidateNormalizationDTO validateDTO = ValidateNormalizationDTO.builder()
83+
.scoreCombinationTechnique(scoreCombinationTechnique)
84+
.build();
85+
scoreNormalizationFactory.validateNormalizationTechnique(normalizationTechnique, validateDTO);
8286

8387
log.info(
8488
"Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]",

src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java

+4-23
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,12 @@
1818
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
1919

2020
import lombok.ToString;
21-
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
22-
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
23-
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
24-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
2521
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
2622
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
2723
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
2824

2925
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
26+
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getNumOfSubqueries;
3027

3128
/**
3229
* Abstracts normalization of scores based on L2 method
@@ -74,18 +71,8 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
7471
}
7572

7673
@Override
77-
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
78-
switch (combinationTechnique.techniqueName()) {
79-
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
80-
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
81-
// These are the supported technique, so we do nothing
82-
break;
83-
default:
84-
throw new IllegalArgumentException(
85-
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
86-
+ combinationTechnique.techniqueName()
87-
);
88-
}
74+
public String techniqueName() {
75+
return TECHNIQUE_NAME;
8976
}
9077

9178
@Override
@@ -127,13 +114,7 @@ private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
127114
// find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries,
128115
// or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for
129116
// rest of sub-queries with zero total hits
130-
int numOfSubqueries = queryTopDocs.stream()
131-
.filter(Objects::nonNull)
132-
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
133-
.findAny()
134-
.get()
135-
.getTopDocs()
136-
.size();
117+
int numOfSubqueries = getNumOfSubqueries(queryTopDocs);
137118
float[] l2Norms = new float[numOfSubqueries];
138119
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
139120
if (Objects.isNull(compoundQueryTopDocs)) {

src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java

+3-26
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,12 @@
2727

2828
import lombok.ToString;
2929
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
30-
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
31-
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
32-
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
33-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
3430
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
3531
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
3632
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
3733

3834
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
35+
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getNumOfSubqueries;
3936
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;
4037

4138
/**
@@ -131,18 +128,8 @@ private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopD
131128
}
132129

133130
@Override
134-
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
135-
switch (combinationTechnique.techniqueName()) {
136-
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
137-
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
138-
// These are the supported technique, so we do nothing
139-
break;
140-
default:
141-
throw new IllegalArgumentException(
142-
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
143-
+ combinationTechnique.techniqueName()
144-
);
145-
}
131+
public String techniqueName() {
132+
return TECHNIQUE_NAME;
146133
}
147134

148135
@Override
@@ -191,16 +178,6 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
191178
return getDocIdAtQueryForNormalization(normalizedScores, this);
192179
}
193180

194-
private int getNumOfSubqueries(final List<CompoundTopDocs> queryTopDocs) {
195-
return queryTopDocs.stream()
196-
.filter(Objects::nonNull)
197-
.filter(topDocs -> topDocs.getTopDocs().isEmpty() == false)
198-
.findAny()
199-
.get()
200-
.getTopDocs()
201-
.size();
202-
}
203-
204181
private float[] getMaxScores(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
205182
float[] maxScores = new float[numOfSubqueries];
206183
Arrays.fill(maxScores, Float.MIN_VALUE);

src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java

+2-16
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@
2424
import lombok.ToString;
2525
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
2626
import org.opensearch.neuralsearch.processor.SearchShard;
27-
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
28-
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
29-
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
30-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
3127
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
3228
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
3329
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
@@ -81,18 +77,8 @@ public String describe() {
8177
}
8278

8379
@Override
84-
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) {
85-
switch (combinationTechnique.techniqueName()) {
86-
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
87-
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
88-
// These are the supported technique, so we do nothing
89-
break;
90-
default:
91-
throw new IllegalArgumentException(
92-
"Z Score does not support the provided combination technique {}: Supported techniques are arithmetic_mean, geometric_mean and harmonic_mean"
93-
+ combinationTechnique.techniqueName()
94-
);
95-
}
80+
public String techniqueName() {
81+
return TECHNIQUE_NAME;
9682
}
9783

9884
@Override

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java

+55
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
*/
55
package org.opensearch.neuralsearch.processor.normalization;
66

7+
import org.opensearch.neuralsearch.processor.ValidateNormalizationDTO;
8+
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
9+
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
10+
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
11+
12+
import java.util.List;
13+
import java.util.Locale;
714
import java.util.Map;
815
import java.util.Optional;
916
import java.util.function.Function;
@@ -28,6 +35,29 @@ public class ScoreNormalizationFactory {
2835
params -> new ZScoreNormalizationTechnique()
2936
);
3037

38+
private final Map<String, List<String>> combinationTechniqueForNormalizationTechniqueMap = Map.of(
39+
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
40+
List.of(
41+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
42+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
43+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
44+
),
45+
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
46+
List.of(
47+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
48+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
49+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
50+
),
51+
RRFNormalizationTechnique.TECHNIQUE_NAME,
52+
List.of(
53+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
54+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
55+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
56+
),
57+
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
58+
List.of(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)
59+
);
60+
3161
/**
3262
* Get score normalization method by technique name
3363
* @param technique name of technique
@@ -42,4 +72,29 @@ public ScoreNormalizationTechnique createNormalization(final String technique, f
4272
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"))
4373
.apply(params);
4474
}
75+
76+
/**
77+
* Validate normalization technique based on combination technique and other params that needs to be validated
78+
* @param normalizationTechnique normalization technique to be validated
79+
* @param validateNormalizationDTO data transfer object that contains combination technique and other params that needs to be validated
80+
*/
81+
public void validateNormalizationTechnique(
82+
ScoreNormalizationTechnique normalizationTechnique,
83+
ValidateNormalizationDTO validateNormalizationDTO
84+
) {
85+
List<String> supportedTechniques = combinationTechniqueForNormalizationTechniqueMap.get(normalizationTechnique.techniqueName());
86+
87+
if (!supportedTechniques.contains(validateNormalizationDTO.getScoreCombinationTechnique().techniqueName())) {
88+
throw new IllegalArgumentException(
89+
String.format(
90+
Locale.ROOT,
91+
"provided combination technique %s is not supported for normalization technique %s. Supported techniques are: %s",
92+
validateNormalizationDTO.getScoreCombinationTechnique().techniqueName(),
93+
normalizationTechnique.techniqueName(),
94+
String.join(", ", supportedTechniques)
95+
)
96+
);
97+
}
98+
}
99+
45100
}

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package org.opensearch.neuralsearch.processor.normalization;
66

77
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
8-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
98

109
/**
1110
* Abstracts normalization of scores in query search results.
@@ -21,5 +20,5 @@ public interface ScoreNormalizationTechnique {
2120
*/
2221
void normalize(final NormalizeScoresDTO normalizeScoresDTO);
2322

24-
void validateCombinationTechnique(final ScoreCombinationTechnique combinationTechnique) throws IllegalArgumentException;
23+
String techniqueName();
2524
}

src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java

+22-51
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,12 @@
1919
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
2020

2121
import com.google.common.primitives.Floats;
22-
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
23-
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
24-
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
25-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
2622
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
2723
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
2824
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
2925

30-
import java.util.List;
31-
import java.util.Objects;
32-
3326
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
27+
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getNumOfSubqueries;
3428

3529
/**
3630
* Abstracts normalization of scores based on z score method
@@ -85,21 +79,8 @@ public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
8579
}
8680

8781
@Override
88-
public void validateCombinationTechnique(ScoreCombinationTechnique combinationTechnique) throws IllegalArgumentException {
89-
switch (combinationTechnique.techniqueName()) {
90-
case ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME:
91-
// This is the supported technique, so we do nothing
92-
break;
93-
case GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME:
94-
throw new IllegalArgumentException("Z Score does not support geometric mean combination technique");
95-
case HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME:
96-
throw new IllegalArgumentException("Z Score does not support harmonic mean combination technique");
97-
default:
98-
throw new IllegalArgumentException(
99-
"Z Score does not support the provided combination technique {}: Supported technique is arithmetic_mean"
100-
+ combinationTechnique.techniqueName()
101-
);
102-
}
82+
public String techniqueName() {
83+
return TECHNIQUE_NAME;
10384
}
10485

10586
@Override
@@ -143,41 +124,31 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
143124
return getDocIdAtQueryForNormalization(normalizedScores, this);
144125
}
145126

146-
private int getNumOfSubqueries(final List<CompoundTopDocs> queryTopDocs) {
147-
return queryTopDocs.stream()
148-
.filter(Objects::nonNull)
149-
.filter(topDocs -> !topDocs.getTopDocs().isEmpty())
150-
.findAny()
151-
.get()
152-
.getTopDocs()
153-
.size();
154-
}
155-
156-
private static float[] calculateMaxScorePerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
157-
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfScores);
127+
private static float[] calculateMaxScorePerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
128+
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfSubqueries);
158129

159-
float[] maxPerSubQuery = new float[numOfScores];
160-
for (int i = 0; i < numOfScores; i++) {
130+
float[] maxPerSubQuery = new float[numOfSubqueries];
131+
for (int i = 0; i < numOfSubqueries; i++) {
161132
maxPerSubQuery[i] = (float) statsPerSubquery[i].getMax();
162133
}
163134

164135
return maxPerSubQuery;
165136
}
166137

167-
private static float[] calculateMinScorePerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
168-
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfScores);
138+
private static float[] calculateMinScorePerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
139+
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfSubqueries);
169140

170-
float[] minPerSubQuery = new float[numOfScores];
171-
for (int i = 0; i < numOfScores; i++) {
141+
float[] minPerSubQuery = new float[numOfSubqueries];
142+
for (int i = 0; i < numOfSubqueries; i++) {
172143
minPerSubQuery[i] = (float) statsPerSubquery[i].getMin();
173144
}
174145

175146
return minPerSubQuery;
176147
}
177148

178-
private static DescriptiveStatistics[] calculateStatsPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
179-
DescriptiveStatistics[] statsPerSubquery = new DescriptiveStatistics[numOfScores];
180-
for (int i = 0; i < numOfScores; i++) {
149+
private static DescriptiveStatistics[] calculateStatsPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
150+
DescriptiveStatistics[] statsPerSubquery = new DescriptiveStatistics[numOfSubqueries];
151+
for (int i = 0; i < numOfSubqueries; i++) {
181152
statsPerSubquery[i] = new DescriptiveStatistics();
182153
}
183154

@@ -197,22 +168,22 @@ private static DescriptiveStatistics[] calculateStatsPerSubquery(final List<Comp
197168
return statsPerSubquery;
198169
}
199170

200-
private static float[] calculateMeanPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
201-
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfScores);
171+
private static float[] calculateMeanPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
172+
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfSubqueries);
202173

203-
float[] meanPerSubQuery = new float[numOfScores];
204-
for (int i = 0; i < numOfScores; i++) {
174+
float[] meanPerSubQuery = new float[numOfSubqueries];
175+
for (int i = 0; i < numOfSubqueries; i++) {
205176
meanPerSubQuery[i] = (float) statsPerSubquery[i].getMean();
206177
}
207178

208179
return meanPerSubQuery;
209180
}
210181

211-
private static float[] calculateStandardDeviationPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
212-
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfScores);
182+
private static float[] calculateStandardDeviationPerSubquery(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
183+
DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfSubqueries);
213184

214-
float[] stdPerSubQuery = new float[numOfScores];
215-
for (int i = 0; i < numOfScores; i++) {
185+
float[] stdPerSubQuery = new float[numOfSubqueries];
186+
for (int i = 0; i < numOfSubqueries; i++) {
216187
stdPerSubQuery[i] = (float) statsPerSubquery[i].getStandardDeviation();
217188
}
218189

0 commit comments

Comments
 (0)