4
4
*/
5
5
package org .opensearch .neuralsearch .processor .normalization ;
6
6
7
- import org .opensearch .neuralsearch .processor .ValidateNormalizationDTO ;
7
+ import org .opensearch .neuralsearch .processor .TechniqueCompatibilityCheckDTO ;
8
8
import org .opensearch .neuralsearch .processor .combination .ArithmeticMeanScoreCombinationTechnique ;
9
9
import org .opensearch .neuralsearch .processor .combination .GeometricMeanScoreCombinationTechnique ;
10
10
import org .opensearch .neuralsearch .processor .combination .HarmonicMeanScoreCombinationTechnique ;
11
11
12
- import java .util .List ;
13
12
import java .util .Locale ;
14
13
import java .util .Map ;
14
+ import java .util .Set ;
15
15
import java .util .Optional ;
16
16
import java .util .function .Function ;
17
17
@@ -24,7 +24,7 @@ public class ScoreNormalizationFactory {
24
24
25
25
public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique ();
26
26
27
- private final Map <String , Function <Map <String , Object >, ScoreNormalizationTechnique >> scoreNormalizationMethodsMap = Map .of (
27
+ private static final Map <String , Function <Map <String , Object >, ScoreNormalizationTechnique >> scoreNormalizationMethodsMap = Map .of (
28
28
MinMaxScoreNormalizationTechnique .TECHNIQUE_NAME ,
29
29
params -> new MinMaxScoreNormalizationTechnique (params , scoreNormalizationUtil ),
30
30
L2ScoreNormalizationTechnique .TECHNIQUE_NAME ,
@@ -35,27 +35,27 @@ public class ScoreNormalizationFactory {
35
35
params -> new ZScoreNormalizationTechnique ()
36
36
);
37
37
38
- private final Map <String , List <String >> combinationTechniqueForNormalizationTechniqueMap = Map .of (
38
+ private final Map <String , Set <String >> combinationTechniqueForNormalizationTechniqueMap = Map .of (
39
39
MinMaxScoreNormalizationTechnique .TECHNIQUE_NAME ,
40
- List .of (
40
+ Set .of (
41
41
ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
42
42
GeometricMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
43
43
HarmonicMeanScoreCombinationTechnique .TECHNIQUE_NAME
44
44
),
45
45
L2ScoreNormalizationTechnique .TECHNIQUE_NAME ,
46
- List .of (
46
+ Set .of (
47
47
ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
48
48
GeometricMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
49
49
HarmonicMeanScoreCombinationTechnique .TECHNIQUE_NAME
50
50
),
51
51
RRFNormalizationTechnique .TECHNIQUE_NAME ,
52
- List .of (
52
+ Set .of (
53
53
ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
54
54
GeometricMeanScoreCombinationTechnique .TECHNIQUE_NAME ,
55
55
HarmonicMeanScoreCombinationTechnique .TECHNIQUE_NAME
56
56
),
57
57
ZScoreNormalizationTechnique .TECHNIQUE_NAME ,
58
- List .of (ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME )
58
+ Set .of (ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME )
59
59
);
60
60
61
61
/**
@@ -75,21 +75,18 @@ public ScoreNormalizationTechnique createNormalization(final String technique, f
75
75
76
76
/**
77
77
* Validate normalization technique based on combination technique and other params that needs to be validated
78
- * @param normalizationTechnique normalization technique to be validated
79
- * @param validateNormalizationDTO data transfer object that contains combination technique and other params that needs to be validated
78
+ * @param techniqueCompatibilityCheckDTO data transfer object that contains combination technique and other params that needs to be validated
80
79
*/
81
- public void validateNormalizationTechnique (
82
- ScoreNormalizationTechnique normalizationTechnique ,
83
- ValidateNormalizationDTO validateNormalizationDTO
84
- ) {
85
- List <String > supportedTechniques = combinationTechniqueForNormalizationTechniqueMap .get (normalizationTechnique .techniqueName ());
80
+ public void isTechniquesCompatible (TechniqueCompatibilityCheckDTO techniqueCompatibilityCheckDTO ) {
81
+ ScoreNormalizationTechnique normalizationTechnique = techniqueCompatibilityCheckDTO .getScoreNormalizationTechnique ();
82
+ Set <String > supportedTechniques = combinationTechniqueForNormalizationTechniqueMap .get (normalizationTechnique .techniqueName ());
86
83
87
- if (! supportedTechniques .contains (validateNormalizationDTO .getScoreCombinationTechnique ().techniqueName ())) {
84
+ if (supportedTechniques .contains (techniqueCompatibilityCheckDTO .getScoreCombinationTechnique ().techniqueName ()) == false ) {
88
85
throw new IllegalArgumentException (
89
86
String .format (
90
87
Locale .ROOT ,
91
88
"provided combination technique %s is not supported for normalization technique %s. Supported techniques are: %s" ,
92
- validateNormalizationDTO .getScoreCombinationTechnique ().techniqueName (),
89
+ techniqueCompatibilityCheckDTO .getScoreCombinationTechnique ().techniqueName (),
93
90
normalizationTechnique .techniqueName (),
94
91
String .join (", " , supportedTechniques )
95
92
)
0 commit comments