Skip to content

Commit 1f42ae8

Browse files
committed
Addressed PR comments
Signed-off-by: Owais <[email protected]>
1 parent 4570059 commit 1f42ae8

File tree

5 files changed

+28
-21
lines changed

5 files changed

+28
-21
lines changed

src/main/java/org/opensearch/neuralsearch/processor/ValidateNormalizationDTO.java src/main/java/org/opensearch/neuralsearch/processor/TechniqueCompatibilityCheckDTO.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99
import lombok.Getter;
1010
import lombok.NonNull;
1111
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
12+
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
1213

1314
/**
1415
* DTO object to hold data required for validation.
1516
*/
1617
@AllArgsConstructor
1718
@Builder
1819
@Getter
19-
public class ValidateNormalizationDTO {
20+
public class TechniqueCompatibilityCheckDTO {
2021
@NonNull
2122
private ScoreCombinationTechnique scoreCombinationTechnique;
23+
@NonNull
24+
private ScoreNormalizationTechnique scoreNormalizationTechnique;
2225
}

src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java

+3
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ public interface ScoreCombinationTechnique {
1313
*/
1414
float combine(final float[] scores);
1515

16+
/**
17+
* Returns the name of the combination technique.
18+
*/
1619
String techniqueName();
1720
}

src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
1414
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
15-
import org.opensearch.neuralsearch.processor.ValidateNormalizationDTO;
15+
import org.opensearch.neuralsearch.processor.TechniqueCompatibilityCheckDTO;
1616
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
1717
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
1818
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -79,10 +79,11 @@ public SearchPhaseResultsProcessor create(
7979
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
8080
}
8181

82-
ValidateNormalizationDTO validateDTO = ValidateNormalizationDTO.builder()
82+
TechniqueCompatibilityCheckDTO techniqueCompatibilityCheckDTO = TechniqueCompatibilityCheckDTO.builder()
83+
.scoreNormalizationTechnique(normalizationTechnique)
8384
.scoreCombinationTechnique(scoreCombinationTechnique)
8485
.build();
85-
scoreNormalizationFactory.validateNormalizationTechnique(normalizationTechnique, validateDTO);
86+
scoreNormalizationFactory.isTechniquesCompatible(techniqueCompatibilityCheckDTO);
8687

8788
log.info(
8889
"Creating search phase results processor of type [{}] with normalization [{}] and combination [{}]",

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java

+14-17
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
*/
55
package org.opensearch.neuralsearch.processor.normalization;
66

7-
import org.opensearch.neuralsearch.processor.ValidateNormalizationDTO;
7+
import org.opensearch.neuralsearch.processor.TechniqueCompatibilityCheckDTO;
88
import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique;
99
import org.opensearch.neuralsearch.processor.combination.GeometricMeanScoreCombinationTechnique;
1010
import org.opensearch.neuralsearch.processor.combination.HarmonicMeanScoreCombinationTechnique;
1111

12-
import java.util.List;
1312
import java.util.Locale;
1413
import java.util.Map;
14+
import java.util.Set;
1515
import java.util.Optional;
1616
import java.util.function.Function;
1717

@@ -24,7 +24,7 @@ public class ScoreNormalizationFactory {
2424

2525
public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique();
2626

27-
private final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
27+
private static final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
2828
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
2929
params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil),
3030
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
@@ -35,27 +35,27 @@ public class ScoreNormalizationFactory {
3535
params -> new ZScoreNormalizationTechnique()
3636
);
3737

38-
private final Map<String, List<String>> combinationTechniqueForNormalizationTechniqueMap = Map.of(
38+
private final Map<String, Set<String>> combinationTechniqueForNormalizationTechniqueMap = Map.of(
3939
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
40-
List.of(
40+
Set.of(
4141
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
4242
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
4343
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
4444
),
4545
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
46-
List.of(
46+
Set.of(
4747
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
4848
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
4949
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
5050
),
5151
RRFNormalizationTechnique.TECHNIQUE_NAME,
52-
List.of(
52+
Set.of(
5353
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
5454
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
5555
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME
5656
),
5757
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
58-
List.of(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)
58+
Set.of(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)
5959
);
6060

6161
/**
@@ -75,21 +75,18 @@ public ScoreNormalizationTechnique createNormalization(final String technique, f
7575

7676
/**
7777
* Validate normalization technique based on combination technique and other params that needs to be validated
78-
* @param normalizationTechnique normalization technique to be validated
79-
* @param validateNormalizationDTO data transfer object that contains combination technique and other params that needs to be validated
78+
* @param techniqueCompatibilityCheckDTO data transfer object that contains combination technique and other params that needs to be validated
8079
*/
81-
public void validateNormalizationTechnique(
82-
ScoreNormalizationTechnique normalizationTechnique,
83-
ValidateNormalizationDTO validateNormalizationDTO
84-
) {
85-
List<String> supportedTechniques = combinationTechniqueForNormalizationTechniqueMap.get(normalizationTechnique.techniqueName());
80+
public void isTechniquesCompatible(TechniqueCompatibilityCheckDTO techniqueCompatibilityCheckDTO) {
81+
ScoreNormalizationTechnique normalizationTechnique = techniqueCompatibilityCheckDTO.getScoreNormalizationTechnique();
82+
Set<String> supportedTechniques = combinationTechniqueForNormalizationTechniqueMap.get(normalizationTechnique.techniqueName());
8683

87-
if (!supportedTechniques.contains(validateNormalizationDTO.getScoreCombinationTechnique().techniqueName())) {
84+
if (supportedTechniques.contains(techniqueCompatibilityCheckDTO.getScoreCombinationTechnique().techniqueName()) == false) {
8885
throw new IllegalArgumentException(
8986
String.format(
9087
Locale.ROOT,
9188
"provided combination technique %s is not supported for normalization technique %s. Supported techniques are: %s",
92-
validateNormalizationDTO.getScoreCombinationTechnique().techniqueName(),
89+
techniqueCompatibilityCheckDTO.getScoreCombinationTechnique().techniqueName(),
9390
normalizationTechnique.techniqueName(),
9491
String.join(", ", supportedTechniques)
9592
)

src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java

+3
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@ public interface ScoreNormalizationTechnique {
2020
*/
2121
void normalize(final NormalizeScoresDTO normalizeScoresDTO);
2222

23+
/**
24+
* Returns the name of the normalization technique.
25+
*/
2326
String techniqueName();
2427
}

0 commit comments

Comments
 (0)