Skip to content

Commit f5377c0

Browse files
authored
Adds ZScore Normalization Technique (#1224)
* Initial implementation of zscore Signed-off-by: Owais <[email protected]> * Added negative technique for GP and HP Signed-off-by: Owais <[email protected]> * Minor refactoring and removed negative HP and GP Signed-off-by: Owais <[email protected]> * Added UTs and ITs and few edge cases Signed-off-by: Owais <[email protected]> * Used DescriptiveStatistics and added validations for combination technique Signed-off-by: Owais <[email protected]> * Created DTO for validation Signed-off-by: Owais <[email protected]> * Addressed PR comments Signed-off-by: Owais <[email protected]> --------- Signed-off-by: Owais <[email protected]>
1 parent 8abb418 commit f5377c0

21 files changed

+898
-33
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@ All notable changes to this project are documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on how to add changelog entries.
55

66
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
7+
78
### Features
89
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
910
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
11+
- Add Z Score normalization technique ([#1224](https://github.com/opensearch-project/neural-search/pull/1224))
12+
1013
### Enhancements
14+
1115
### Bug Fixes
16+
1217
### Infrastructure
18+
1319
### Documentation
20+
1421
### Maintenance
22+
1523
### Refactoring

DEVELOPER_GUIDE.md

+12-12
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ through the same build issue.
351351

352352
### Class and package names
353353

354-
Class names should use `CamelCase`.
354+
Class names should use `CamelCase`.
355355

356-
Try to put new classes into existing packages if package name abstracts the purpose of the class.
356+
Try to put new classes into existing packages if package name abstracts the purpose of the class.
357357

358358
Example of good class file name and package utilization:
359359

@@ -371,7 +371,7 @@ methods rather than a long single one and does everything.
371371
### Documentation
372372

373373
Document you code. That includes purpose of new classes, every public method and code sections that have critical or non-trivial
374-
logic (check this example https://github.com/opensearch-project/neural-search/blob/main/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java#L238).
374+
logic (check this example https://github.com/opensearch-project/neural-search/blob/main/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java#L238).
375375

376376
When you submit a feature PR, please submit a new
377377
[documentation issue](https://github.com/opensearch-project/documentation-website/issues/new/choose). This is a path for the documentation to be published as part of https://opensearch.org/docs/latest/ documentation site.
@@ -384,17 +384,17 @@ For the most part, we're using common conventions for Java projects. Here are a
384384

385385
1. Use descriptive names for classes, methods, fields, and variables.
386386
2. Avoid abbreviations unless they are widely accepted
387-
3. Use `final` on all method arguments unless it's absolutely necessary
387+
3. Use `final` on all method arguments unless it's absolutely necessary
388388
4. Wildcard imports are not allowed.
389389
5. Static imports are preferred over qualified imports when using static methods
390390
6. Prefer creating non-static public methods whenever possible. Avoid static methods in general, as they can often serve as shortcuts.
391391
Static methods are acceptable if they are private and do not access class state.
392-
7. Use functional programming style inside methods unless it's a performance critical section.
392+
7. Use functional programming style inside methods unless it's a performance critical section.
393393
8. For parameters of lambda expression please use meaningful names instead of shorten cryptic ones.
394394
9. Use Optional for return values if the value may not be present. This should be preferred to returning null.
395395
10. Do not create checked exceptions, and do not throw checked exceptions from public methods whenever possible. In general, if you call a method with a checked exception, you should wrap that exception into an unchecked exception.
396396
11. Throwing checked exceptions from private methods is acceptable.
397-
12. Use String.format when a string includes parameters, and prefer this over direct string concatenation. Always specify a Locale with String.format;
397+
12. Use String.format when a string includes parameters, and prefer this over direct string concatenation. Always specify a Locale with String.format;
398398
as a rule of thumb, use Locale.ROOT.
399399
13. Prefer Lombok annotations to the manually written boilerplate code
400400
14. When throwing an exception, avoid including user-provided content in the exception message. For secure coding practices,
@@ -440,17 +440,17 @@ Fix any new warnings before submitting your PR to ensure proper code documentati
440440

441441
### Tests
442442

443-
Write unit and integration tests for your new functionality.
443+
Write unit and integration tests for your new functionality.
444444

445445
Unit tests are preferred as they are cheap and fast, try to use them to cover all possible
446-
combinations of parameters. Utilize mocks to mimic dependencies.
446+
combinations of parameters. Utilize mocks to mimic dependencies.
447447

448-
Integration tests should be used sparingly, focusing primarily on the main (happy path) scenario or cases where extensive
449-
mocking is impractical. Include one or two unhappy paths to confirm that correct response codes are returned to the user.
450-
Whenever possible, favor scenarios that do not require model deployment. If model deployment is necessary, use an existing
448+
Integration tests should be used sparingly, focusing primarily on the main (happy path) scenario or cases where extensive
449+
mocking is impractical. Include one or two unhappy paths to confirm that correct response codes are returned to the user.
450+
Whenever possible, favor scenarios that do not require model deployment. If model deployment is necessary, use an existing
451451
model, as tests involving new model deployments are the most resource-intensive.
452452

453-
If your changes could affect backward compatibility, please include relevant backward compatibility tests along with your
453+
If your changes could affect backward compatibility, please include relevant backward compatibility tests along with your
454454
PR. For guidance on adding these tests, refer to the [Backwards Compatibility Testing](#backwards-compatibility-testing) section in this guide.
455455

456456
### Outdated or irrelevant code

build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ dependencies {
260260
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
261261
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
262262
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
263+
implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
263264
// ml-common excluded reflection for runtime so we need to add it by ourselves.
264265
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
265266
// TODO: Remove following three lines of dependencies if ml-common include them in their jar
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import lombok.AllArgsConstructor;
8+
import lombok.Builder;
9+
import lombok.Getter;
10+
import lombok.NonNull;
11+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
12+
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
13+
14+
/**
15+
* DTO object to hold data required for validation.
16+
*/
17+
@AllArgsConstructor
18+
@Builder
19+
@Getter
20+
public class TechniqueCompatibilityCheckDTO {
21+
@NonNull
22+
private ScoreCombinationTechnique scoreCombinationTechnique;
23+
@NonNull
24+
private ScoreNormalizationTechnique scoreNormalizationTechnique;
25+
}

src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float combine(final float[] scores) {
5858
return combinedScore / sumOfWeights;
5959
}
6060

61+
@Override
62+
public String techniqueName() {
63+
return TECHNIQUE_NAME;
64+
}
65+
6166
@Override
6267
public String describe() {
6368
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float combine(final float[] scores) {
5858
return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights);
5959
}
6060

61+
@Override
62+
public String techniqueName() {
63+
return TECHNIQUE_NAME;
64+
}
65+
6166
@Override
6267
public String describe() {
6368
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public float combine(final float[] scores) {
5555
return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE;
5656
}
5757

58+
@Override
59+
public String techniqueName() {
60+
return TECHNIQUE_NAME;
61+
}
62+
5863
@Override
5964
public String describe() {
6065
return describeCombinationTechnique(TECHNIQUE_NAME, weights);

src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ public float combine(final float[] scores) {
3737
return sumScores;
3838
}
3939

40+
@Override
41+
public String techniqueName() {
42+
return TECHNIQUE_NAME;
43+
}
44+
4045
@Override
4146
public String describe() {
4247
return describeCombinationTechnique(TECHNIQUE_NAME, List.of());

src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,9 @@ public interface ScoreCombinationTechnique {
1212
* @return combined score
1313
*/
1414
float combine(final float[] scores);
15+
16+
/**
17+
* Returns the name of the combination technique.
18+
*/
19+
String techniqueName();
1520
}

src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
1414
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
15+
import org.opensearch.neuralsearch.processor.TechniqueCompatibilityCheckDTO;
1516
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
1617
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
1718
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -77,6 +78,13 @@ public SearchPhaseResultsProcessor create(
7778
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
7879
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
7980
}
81+
82+
TechniqueCompatibilityCheckDTO techniqueCompatibilityCheckDTO = TechniqueCompatibilityCheckDTO.builder()
83+
.scoreNormalizationTechnique(normalizationTechnique)
84+
.scoreCombinationTechnique(scoreCombinationTechnique)
85+
.build();
86+
scoreNormalizationFactory.isTechniquesCompatible(techniqueCompatibilityCheckDTO);
87+
8088
log.info(
8189
"Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]",
8290
NormalizationProcessor.TYPE,

src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
2424

2525
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
26+
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getNumOfSubqueries;
2627

2728
/**
2829
* Abstracts normalization of scores based on L2 method
@@ -69,6 +70,11 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
6970
}
7071
}
7172

73+
@Override
74+
public String techniqueName() {
75+
return TECHNIQUE_NAME;
76+
}
77+
7278
@Override
7379
public String describe() {
7480
return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
@@ -108,13 +114,7 @@ private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
108114
// find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries,
109115
// or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for
110116
// rest of sub-queries with zero total hits
111-
int numOfSubqueries = queryTopDocs.stream()
112-
.filter(Objects::nonNull)
113-
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
114-
.findAny()
115-
.get()
116-
.getTopDocs()
117-
.size();
117+
int numOfSubqueries = getNumOfSubqueries(queryTopDocs);
118118
float[] l2Norms = new float[numOfSubqueries];
119119
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
120120
if (Objects.isNull(compoundQueryTopDocs)) {

src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java

+6-10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
3333

3434
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
35+
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getNumOfSubqueries;
3536
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;
3637

3738
/**
@@ -126,6 +127,11 @@ private MinMaxScores getMinMaxScoresResult(final List<CompoundTopDocs> queryTopD
126127
return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery);
127128
}
128129

130+
@Override
131+
public String techniqueName() {
132+
return TECHNIQUE_NAME;
133+
}
134+
129135
@Override
130136
public String describe() {
131137
return lowerBoundsOptional.map(lb -> {
@@ -172,16 +178,6 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
172178
return getDocIdAtQueryForNormalization(normalizedScores, this);
173179
}
174180

175-
private int getNumOfSubqueries(final List<CompoundTopDocs> queryTopDocs) {
176-
return queryTopDocs.stream()
177-
.filter(Objects::nonNull)
178-
.filter(topDocs -> topDocs.getTopDocs().isEmpty() == false)
179-
.findAny()
180-
.get()
181-
.getTopDocs()
182-
.size();
183-
}
184-
185181
private float[] getMaxScores(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
186182
float[] maxScores = new float[numOfSubqueries];
187183
Arrays.fill(maxScores, Float.MIN_VALUE);

src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java

+5
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ public String describe() {
7676
return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
7777
}
7878

79+
@Override
80+
public String techniqueName() {
81+
return TECHNIQUE_NAME;
82+
}
83+
7984
@Override
8085
public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
8186
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java

+57-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
*/
55
package org.opensearch.neuralsearch.processor.normalization;
66

7+
import org.opensearch.neuralsearch.processor.TechniqueCompatibilityCheckDTO;
8+
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
9+
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
10+
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
11+
12+
import java.util.Locale;
713
import java.util.Map;
14+
import java.util.Set;
815
import java.util.Optional;
916
import java.util.function.Function;
1017

@@ -17,13 +24,38 @@ public class ScoreNormalizationFactory {
1724

1825
public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique();
1926

20-
private final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
27+
private static final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> SCORE_NORMALIZATION_METHODS = Map.of(
2128
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
2229
params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil),
2330
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
2431
params -> new L2ScoreNormalizationTechnique(params, scoreNormalizationUtil),
2532
RRFNormalizationTechnique.TECHNIQUE_NAME,
26-
params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil)
33+
params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil),
34+
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
35+
params -> new ZScoreNormalizationTechnique()
36+
);
37+
38+
private static final Map<String, Set<String>> COMBINATION_TECHNIQUE_FOR_NORMALIZATION_METHODS = Map.of(
39+
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
40+
Set.of(
41+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
42+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
43+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
44+
),
45+
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
46+
Set.of(
47+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
48+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
49+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
50+
),
51+
RRFNormalizationTechnique.TECHNIQUE_NAME,
52+
Set.of(
53+
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
54+
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
55+
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
56+
),
57+
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
58+
Set.of(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)
2759
);
2860

2961
/**
@@ -36,8 +68,30 @@ public ScoreNormalizationTechnique createNormalization(final String technique) {
3668
}
3769

3870
public ScoreNormalizationTechnique createNormalization(final String technique, final Map<String, Object> params) {
39-
return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique))
71+
return Optional.ofNullable(SCORE_NORMALIZATION_METHODS.get(technique))
4072
.orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported"))
4173
.apply(params);
4274
}
75+
76+
/**
77+
* Validate normalization technique based on combination technique and other params that needs to be validated
78+
* @param techniqueCompatibilityCheckDTO data transfer object that contains combination technique and other params that needs to be validated
79+
*/
80+
public void isTechniquesCompatible(TechniqueCompatibilityCheckDTO techniqueCompatibilityCheckDTO) {
81+
ScoreNormalizationTechnique normalizationTechnique = techniqueCompatibilityCheckDTO.getScoreNormalizationTechnique();
82+
Set<String> supportedTechniques = COMBINATION_TECHNIQUE_FOR_NORMALIZATION_METHODS.get(normalizationTechnique.techniqueName());
83+
84+
if (supportedTechniques.contains(techniqueCompatibilityCheckDTO.getScoreCombinationTechnique().techniqueName()) == false) {
85+
throw new IllegalArgumentException(
86+
String.format(
87+
Locale.ROOT,
88+
"provided combination technique %s is not supported for normalization technique %s. Supported techniques are: %s",
89+
techniqueCompatibilityCheckDTO.getScoreCombinationTechnique().techniqueName(),
90+
normalizationTechnique.techniqueName(),
91+
String.join(", ", supportedTechniques)
92+
)
93+
);
94+
}
95+
}
96+
4397
}

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java

+4
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,8 @@ public interface ScoreNormalizationTechnique {
2020
*/
2121
void normalize(final NormalizeScoresDTO normalizeScoresDTO);
2222

23+
/**
24+
* Returns the name of the normalization technique.
25+
*/
26+
String techniqueName();
2327
}

0 commit comments

Comments
 (0)