Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 58a66e4

Browse files
committedJan 23, 2025·
Fix failing tests
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent ebe809c commit 58a66e4

File tree

4 files changed

+62
-14
lines changed

4 files changed

+62
-14
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2121
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
2222
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
2323
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
24+
- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425]
2425
### Enhancements
2526
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
2627
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]

‎src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java

+9-6
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,9 @@ public void testFlush_WithQuantization() {
257257

258258
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
259259
try {
260-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
261-
.thenReturn(quantizationState);
260+
when(
261+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
262+
).thenReturn(quantizationState);
262263
} catch (Exception e) {
263264
throw new RuntimeException(e);
264265
}
@@ -690,8 +691,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
690691

691692
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
692693
try {
693-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
694-
.thenReturn(quantizationState);
694+
when(
695+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
696+
).thenReturn(quantizationState);
695697
} catch (Exception e) {
696698
throw new RuntimeException(e);
697699
}
@@ -793,8 +795,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
793795

794796
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
795797
try {
796-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
797-
.thenReturn(quantizationState);
798+
when(
799+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
800+
).thenReturn(quantizationState);
798801
} catch (Exception e) {
799802
throw new RuntimeException(e);
800803
}

‎src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,9 @@ public void testMerge_WithQuantization() {
325325

326326
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
327327
try {
328-
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState);
328+
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn(
329+
quantizationState
330+
);
329331
} catch (Exception e) {
330332
throw new RuntimeException(e);
331333
}

‎src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java

+49-7
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
package org.opensearch.knn.index.quantizationservice;
77

8+
import org.apache.lucene.index.FieldInfo;
89
import org.opensearch.knn.KNNTestCase;
910
import org.junit.Before;
1011

1112
import org.opensearch.knn.index.VectorDataType;
13+
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
14+
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
1215
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1316
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
1417
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
@@ -24,6 +27,7 @@
2427
public class QuantizationServiceTests extends KNNTestCase {
2528
private QuantizationService<float[], byte[]> quantizationService;
2629
private KNNVectorValues<float[]> knnVectorValues;
30+
private FieldInfo fieldInfo;
2731

2832
@Before
2933
public void setUp() throws Exception {
@@ -42,11 +46,19 @@ public void setUp() throws Exception {
4246
VectorDataType.FLOAT,
4347
new TestVectorValues.PreDefinedFloatVectorValues(floatVectors)
4448
);
49+
50+
fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build();
4551
}
4652

4753
public void testTrain_oneBitQuantizer_success() throws IOException {
4854
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
49-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
55+
56+
QuantizationState quantizationState = quantizationService.train(
57+
oneBitParams,
58+
knnVectorValues,
59+
knnVectorValues.totalLiveDocs(),
60+
fieldInfo
61+
);
5062

5163
assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
5264
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
@@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {
6274

6375
public void testTrain_twoBitQuantizer_success() throws IOException {
6476
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
65-
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
77+
QuantizationState quantizationState = quantizationService.train(
78+
twoBitParams,
79+
knnVectorValues,
80+
knnVectorValues.totalLiveDocs(),
81+
fieldInfo
82+
);
6683

6784
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
6885
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {
85102

86103
public void testTrain_fourBitQuantizer_success() throws IOException {
87104
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
88-
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
105+
QuantizationState quantizationState = quantizationService.train(
106+
fourBitParams,
107+
knnVectorValues,
108+
knnVectorValues.totalLiveDocs(),
109+
fieldInfo
110+
);
89111

90112
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
91113
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {
110132

111133
public void testQuantize_oneBitQuantizer_success() throws IOException {
112134
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
113-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
135+
QuantizationState quantizationState = quantizationService.train(
136+
oneBitParams,
137+
knnVectorValues,
138+
knnVectorValues.totalLiveDocs(),
139+
fieldInfo
140+
);
114141

115142
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
116143

@@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {
125152

126153
public void testQuantize_twoBitQuantizer_success() throws IOException {
127154
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
128-
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
155+
QuantizationState quantizationState = quantizationService.train(
156+
twoBitParams,
157+
knnVectorValues,
158+
knnVectorValues.totalLiveDocs(),
159+
fieldInfo
160+
);
129161
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
130162
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);
131163

@@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {
138170

139171
public void testQuantize_fourBitQuantizer_success() throws IOException {
140172
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
141-
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
173+
QuantizationState quantizationState = quantizationService.train(
174+
fourBitParams,
175+
knnVectorValues,
176+
knnVectorValues.totalLiveDocs(),
177+
fieldInfo
178+
);
142179
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);
143180

144181
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
@@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {
152189

153190
public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
154191
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
155-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
192+
QuantizationState quantizationState = quantizationService.train(
193+
oneBitParams,
194+
knnVectorValues,
195+
knnVectorValues.totalLiveDocs(),
196+
fieldInfo
197+
);
156198
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
157199
assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
158200
}

0 commit comments

Comments
 (0)
Please sign in to comment.