5
5
6
6
package org .opensearch .knn .index .quantizationservice ;
7
7
8
+ import org .apache .lucene .index .FieldInfo ;
8
9
import org .opensearch .knn .KNNTestCase ;
9
10
import org .junit .Before ;
10
11
11
12
import org .opensearch .knn .index .VectorDataType ;
13
+ import org .opensearch .knn .index .codec .KNNCodecTestUtil ;
14
+ import org .opensearch .knn .index .mapper .KNNVectorFieldMapper ;
12
15
import org .opensearch .knn .index .vectorvalues .KNNVectorValues ;
13
16
import org .opensearch .knn .index .vectorvalues .KNNVectorValuesFactory ;
14
17
import org .opensearch .knn .index .vectorvalues .TestVectorValues ;
24
27
public class QuantizationServiceTests extends KNNTestCase {
25
28
private QuantizationService <float [], byte []> quantizationService ;
26
29
private KNNVectorValues <float []> knnVectorValues ;
30
+ private FieldInfo fieldInfo ;
27
31
28
32
@ Before
29
33
public void setUp () throws Exception {
@@ -42,11 +46,19 @@ public void setUp() throws Exception {
42
46
VectorDataType .FLOAT ,
43
47
new TestVectorValues .PreDefinedFloatVectorValues (floatVectors )
44
48
);
49
+
50
+ fieldInfo = KNNCodecTestUtil .FieldInfoBuilder .builder ("test-field" ).addAttribute (KNNVectorFieldMapper .KNN_FIELD , "true" ).build ();
45
51
}
46
52
47
53
public void testTrain_oneBitQuantizer_success () throws IOException {
48
54
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams (ScalarQuantizationType .ONE_BIT );
49
- QuantizationState quantizationState = quantizationService .train (oneBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
55
+
56
+ QuantizationState quantizationState = quantizationService .train (
57
+ oneBitParams ,
58
+ knnVectorValues ,
59
+ knnVectorValues .totalLiveDocs (),
60
+ fieldInfo
61
+ );
50
62
51
63
assertTrue (quantizationState instanceof OneBitScalarQuantizationState );
52
64
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState ) quantizationState ;
@@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {
62
74
63
75
public void testTrain_twoBitQuantizer_success () throws IOException {
64
76
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams (ScalarQuantizationType .TWO_BIT );
65
- QuantizationState quantizationState = quantizationService .train (twoBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
77
+ QuantizationState quantizationState = quantizationService .train (
78
+ twoBitParams ,
79
+ knnVectorValues ,
80
+ knnVectorValues .totalLiveDocs (),
81
+ fieldInfo
82
+ );
66
83
67
84
assertTrue (quantizationState instanceof MultiBitScalarQuantizationState );
68
85
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState ) quantizationState ;
@@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {
85
102
86
103
public void testTrain_fourBitQuantizer_success () throws IOException {
87
104
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams (ScalarQuantizationType .FOUR_BIT );
88
- QuantizationState quantizationState = quantizationService .train (fourBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
105
+ QuantizationState quantizationState = quantizationService .train (
106
+ fourBitParams ,
107
+ knnVectorValues ,
108
+ knnVectorValues .totalLiveDocs (),
109
+ fieldInfo
110
+ );
89
111
90
112
assertTrue (quantizationState instanceof MultiBitScalarQuantizationState );
91
113
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState ) quantizationState ;
@@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {
110
132
111
133
public void testQuantize_oneBitQuantizer_success () throws IOException {
112
134
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams (ScalarQuantizationType .ONE_BIT );
113
- QuantizationState quantizationState = quantizationService .train (oneBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
135
+ QuantizationState quantizationState = quantizationService .train (
136
+ oneBitParams ,
137
+ knnVectorValues ,
138
+ knnVectorValues .totalLiveDocs (),
139
+ fieldInfo
140
+ );
114
141
115
142
QuantizationOutput quantizationOutput = quantizationService .createQuantizationOutput (oneBitParams );
116
143
@@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {
125
152
126
153
public void testQuantize_twoBitQuantizer_success () throws IOException {
127
154
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams (ScalarQuantizationType .TWO_BIT );
128
- QuantizationState quantizationState = quantizationService .train (twoBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
155
+ QuantizationState quantizationState = quantizationService .train (
156
+ twoBitParams ,
157
+ knnVectorValues ,
158
+ knnVectorValues .totalLiveDocs (),
159
+ fieldInfo
160
+ );
129
161
QuantizationOutput quantizationOutput = quantizationService .createQuantizationOutput (twoBitParams );
130
162
byte [] quantizedVector = quantizationService .quantize (quantizationState , new float [] { 4.0f , 5.0f , 6.0f }, quantizationOutput );
131
163
@@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {
138
170
139
171
public void testQuantize_fourBitQuantizer_success () throws IOException {
140
172
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams (ScalarQuantizationType .FOUR_BIT );
141
- QuantizationState quantizationState = quantizationService .train (fourBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
173
+ QuantizationState quantizationState = quantizationService .train (
174
+ fourBitParams ,
175
+ knnVectorValues ,
176
+ knnVectorValues .totalLiveDocs (),
177
+ fieldInfo
178
+ );
142
179
QuantizationOutput quantizationOutput = quantizationService .createQuantizationOutput (fourBitParams );
143
180
144
181
byte [] quantizedVector = quantizationService .quantize (quantizationState , new float [] { 7.0f , 8.0f , 9.0f }, quantizationOutput );
@@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {
152
189
153
190
public void testQuantize_whenInvalidInput_thenThrows () throws IOException {
154
191
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams (ScalarQuantizationType .ONE_BIT );
155
- QuantizationState quantizationState = quantizationService .train (oneBitParams , knnVectorValues , knnVectorValues .totalLiveDocs ());
192
+ QuantizationState quantizationState = quantizationService .train (
193
+ oneBitParams ,
194
+ knnVectorValues ,
195
+ knnVectorValues .totalLiveDocs (),
196
+ fieldInfo
197
+ );
156
198
QuantizationOutput quantizationOutput = quantizationService .createQuantizationOutput (oneBitParams );
157
199
assertThrows (IllegalArgumentException .class , () -> quantizationService .quantize (quantizationState , null , quantizationOutput ));
158
200
}
0 commit comments