19
19
import org .apache .commons .math3 .stat .descriptive .DescriptiveStatistics ;
20
20
21
21
import com .google .common .primitives .Floats ;
22
- import org .opensearch .neuralsearch .processor .combination .ArithmeticMeanScoreCombinationTechnique ;
23
- import org .opensearch .neuralsearch .processor .combination .GeometricMeanScoreCombinationTechnique ;
24
- import org .opensearch .neuralsearch .processor .combination .HarmonicMeanScoreCombinationTechnique ;
25
- import org .opensearch .neuralsearch .processor .combination .ScoreCombinationTechnique ;
26
22
import org .opensearch .neuralsearch .processor .explain .DocIdAtSearchShard ;
27
23
import org .opensearch .neuralsearch .processor .explain .ExplainableTechnique ;
28
24
import org .opensearch .neuralsearch .processor .explain .ExplanationDetails ;
29
25
30
- import java .util .List ;
31
- import java .util .Objects ;
32
-
33
26
import static org .opensearch .neuralsearch .processor .explain .ExplanationUtils .getDocIdAtQueryForNormalization ;
27
+ import static org .opensearch .neuralsearch .processor .util .ProcessorUtils .getNumOfSubqueries ;
34
28
35
29
/**
36
30
* Abstracts normalization of scores based on z score method
@@ -85,21 +79,8 @@ public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
85
79
}
86
80
87
81
@ Override
88
- public void validateCombinationTechnique (ScoreCombinationTechnique combinationTechnique ) throws IllegalArgumentException {
89
- switch (combinationTechnique .techniqueName ()) {
90
- case ArithmeticMeanScoreCombinationTechnique .TECHNIQUE_NAME :
91
- // This is the supported technique, so we do nothing
92
- break ;
93
- case GeometricMeanScoreCombinationTechnique .TECHNIQUE_NAME :
94
- throw new IllegalArgumentException ("Z Score does not support geometric mean combination technique" );
95
- case HarmonicMeanScoreCombinationTechnique .TECHNIQUE_NAME :
96
- throw new IllegalArgumentException ("Z Score does not support harmonic mean combination technique" );
97
- default :
98
- throw new IllegalArgumentException (
99
- "Z Score does not support the provided combination technique {}: Supported technique is arithmetic_mean"
100
- + combinationTechnique .techniqueName ()
101
- );
102
- }
82
+ public String techniqueName () {
83
+ return TECHNIQUE_NAME ;
103
84
}
104
85
105
86
@ Override
@@ -143,41 +124,31 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
143
124
return getDocIdAtQueryForNormalization (normalizedScores , this );
144
125
}
145
126
146
- private int getNumOfSubqueries (final List <CompoundTopDocs > queryTopDocs ) {
147
- return queryTopDocs .stream ()
148
- .filter (Objects ::nonNull )
149
- .filter (topDocs -> !topDocs .getTopDocs ().isEmpty ())
150
- .findAny ()
151
- .get ()
152
- .getTopDocs ()
153
- .size ();
154
- }
155
-
156
- private static float [] calculateMaxScorePerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfScores ) {
157
- DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfScores );
127
+ private static float [] calculateMaxScorePerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfSubqueries ) {
128
+ DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfSubqueries );
158
129
159
- float [] maxPerSubQuery = new float [numOfScores ];
160
- for (int i = 0 ; i < numOfScores ; i ++) {
130
+ float [] maxPerSubQuery = new float [numOfSubqueries ];
131
+ for (int i = 0 ; i < numOfSubqueries ; i ++) {
161
132
maxPerSubQuery [i ] = (float ) statsPerSubquery [i ].getMax ();
162
133
}
163
134
164
135
return maxPerSubQuery ;
165
136
}
166
137
167
- private static float [] calculateMinScorePerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfScores ) {
168
- DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfScores );
138
+ private static float [] calculateMinScorePerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfSubqueries ) {
139
+ DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfSubqueries );
169
140
170
- float [] minPerSubQuery = new float [numOfScores ];
171
- for (int i = 0 ; i < numOfScores ; i ++) {
141
+ float [] minPerSubQuery = new float [numOfSubqueries ];
142
+ for (int i = 0 ; i < numOfSubqueries ; i ++) {
172
143
minPerSubQuery [i ] = (float ) statsPerSubquery [i ].getMin ();
173
144
}
174
145
175
146
return minPerSubQuery ;
176
147
}
177
148
178
- private static DescriptiveStatistics [] calculateStatsPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfScores ) {
179
- DescriptiveStatistics [] statsPerSubquery = new DescriptiveStatistics [numOfScores ];
180
- for (int i = 0 ; i < numOfScores ; i ++) {
149
+ private static DescriptiveStatistics [] calculateStatsPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfSubqueries ) {
150
+ DescriptiveStatistics [] statsPerSubquery = new DescriptiveStatistics [numOfSubqueries ];
151
+ for (int i = 0 ; i < numOfSubqueries ; i ++) {
181
152
statsPerSubquery [i ] = new DescriptiveStatistics ();
182
153
}
183
154
@@ -197,22 +168,22 @@ private static DescriptiveStatistics[] calculateStatsPerSubquery(final List<Comp
197
168
return statsPerSubquery ;
198
169
}
199
170
200
- private static float [] calculateMeanPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfScores ) {
201
- DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfScores );
171
+ private static float [] calculateMeanPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfSubqueries ) {
172
+ DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfSubqueries );
202
173
203
- float [] meanPerSubQuery = new float [numOfScores ];
204
- for (int i = 0 ; i < numOfScores ; i ++) {
174
+ float [] meanPerSubQuery = new float [numOfSubqueries ];
175
+ for (int i = 0 ; i < numOfSubqueries ; i ++) {
205
176
meanPerSubQuery [i ] = (float ) statsPerSubquery [i ].getMean ();
206
177
}
207
178
208
179
return meanPerSubQuery ;
209
180
}
210
181
211
- private static float [] calculateStandardDeviationPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfScores ) {
212
- DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfScores );
182
+ private static float [] calculateStandardDeviationPerSubquery (final List <CompoundTopDocs > queryTopDocs , final int numOfSubqueries ) {
183
+ DescriptiveStatistics [] statsPerSubquery = calculateStatsPerSubquery (queryTopDocs , numOfSubqueries );
213
184
214
- float [] stdPerSubQuery = new float [numOfScores ];
215
- for (int i = 0 ; i < numOfScores ; i ++) {
185
+ float [] stdPerSubQuery = new float [numOfSubqueries ];
186
+ for (int i = 0 ; i < numOfSubqueries ; i ++) {
216
187
stdPerSubQuery [i ] = (float ) statsPerSubquery [i ].getStandardDeviation ();
217
188
}
218
189
0 commit comments