Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e09cdfd

Browse files
committedJan 23, 2025·
Fix failing tests
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent 7087adb commit e09cdfd

11 files changed

+117
-44
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
2020
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
2121
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
22+
- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425]
2223
### Enhancements
2324
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
2425
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]

‎src/main/java/org/opensearch/knn/index/engine/EngineResolver.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public KNNEngine resolveEngine(
5454
}
5555

5656
// 4x is supported by Lucene engine before version 2.19.0
57-
if (knnMethodConfigContext.getVersionCreated().before(Version.V_2_19_0) && compressionLevel == CompressionLevel.x4) {
57+
if (compressionLevel == CompressionLevel.x4 && knnMethodConfigContext.getVersionCreated().before(Version.V_2_19_0)) {
5858
return KNNEngine.LUCENE;
5959
}
6060

‎src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java

+20-6
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
2828
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
2929
import org.opensearch.knn.plugin.stats.KNNGraphValue;
30+
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
3031
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
32+
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
3133
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
3234
import org.opensearch.test.OpenSearchTestCase;
3335

@@ -256,9 +258,13 @@ public void testFlush_WithQuantization() {
256258
).thenReturn(expectedVectorValues.get(i));
257259

258260
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
261+
when(quantizationParams.getTypeIdentifier()).thenReturn(
262+
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
263+
);
259264
try {
260-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
261-
.thenReturn(quantizationState);
265+
when(
266+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
267+
).thenReturn(quantizationState);
262268
} catch (Exception e) {
263269
throw new RuntimeException(e);
264270
}
@@ -689,9 +695,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
689695
).thenReturn(expectedVectorValues.get(i));
690696

691697
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
698+
when(quantizationParams.getTypeIdentifier()).thenReturn(
699+
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
700+
);
692701
try {
693-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
694-
.thenReturn(quantizationState);
702+
when(
703+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
704+
).thenReturn(quantizationState);
695705
} catch (Exception e) {
696706
throw new RuntimeException(e);
697707
}
@@ -792,9 +802,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
792802
).thenReturn(expectedVectorValues.get(i));
793803

794804
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
805+
when(quantizationParams.getTypeIdentifier()).thenReturn(
806+
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
807+
);
795808
try {
796-
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
797-
.thenReturn(quantizationState);
809+
when(
810+
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
811+
).thenReturn(quantizationState);
798812
} catch (Exception e) {
799813
throw new RuntimeException(e);
800814
}

‎src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
3131
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
3232
import org.opensearch.knn.plugin.stats.KNNGraphValue;
33+
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
3334
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
35+
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
3436
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
3537
import org.opensearch.test.OpenSearchTestCase;
3638

@@ -324,8 +326,13 @@ public void testMerge_WithQuantization() {
324326
.thenReturn(knnVectorValues);
325327

326328
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
329+
when(quantizationParams.getTypeIdentifier()).thenReturn(
330+
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
331+
);
327332
try {
328-
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState);
333+
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn(
334+
quantizationState
335+
);
329336
} catch (Exception e) {
330337
throw new RuntimeException(e);
331338
}

‎src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.index.engine;
77

8+
import org.opensearch.Version;
89
import org.opensearch.knn.KNNTestCase;
910
import org.opensearch.knn.index.SpaceType;
1011
import org.opensearch.knn.index.mapper.CompressionLevel;
@@ -68,18 +69,26 @@ public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() {
6869
);
6970
}
7071

71-
public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() {
72+
public void testResolveEngine_whenCompressionIs4x_VersionIsBefore2_19_thenEngineIsLucene() {
7273
assertEquals(
7374
KNNEngine.LUCENE,
7475
ENGINE_RESOLVER.resolveEngine(
75-
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x4).build(),
76+
KNNMethodConfigContext.builder()
77+
.mode(Mode.ON_DISK)
78+
.compressionLevel(CompressionLevel.x4)
79+
.versionCreated(Version.V_2_18_0)
80+
.build(),
7681
null,
7782
false
7883
)
7984
);
8085
assertEquals(
8186
KNNEngine.LUCENE,
82-
ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).build(), null, false)
87+
ENGINE_RESOLVER.resolveEngine(
88+
KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).versionCreated(Version.V_2_17_0).build(),
89+
null,
90+
false
91+
)
8392
);
8493
}
8594

‎src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java

-15
Original file line numberDiff line numberDiff line change
@@ -195,21 +195,6 @@ private void validateResolveMethodContext(
195195
}
196196

197197
public void testResolveMethod_whenInvalid_thenThrow() {
198-
// Invalid compression
199-
expectThrows(
200-
ValidationException.class,
201-
() -> TEST_RESOLVER.resolveMethod(
202-
null,
203-
KNNMethodConfigContext.builder()
204-
.vectorDataType(VectorDataType.FLOAT)
205-
.compressionLevel(CompressionLevel.x4)
206-
.versionCreated(Version.CURRENT)
207-
.build(),
208-
false,
209-
SpaceType.L2
210-
)
211-
);
212-
213198
expectThrows(
214199
ValidationException.class,
215200
() -> TEST_RESOLVER.resolveMethod(

‎src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java

+14-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,24 @@
66
package org.opensearch.knn.index.engine.faiss;
77

88
import org.opensearch.knn.KNNTestCase;
9+
import org.opensearch.knn.index.engine.MethodComponentContext;
910
import org.opensearch.knn.index.mapper.CompressionLevel;
1011

12+
import java.util.Map;
13+
14+
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
15+
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
16+
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8;
17+
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
18+
1119
public class FaissSQEncoderTests extends KNNTestCase {
1220
public void testCalculateCompressionLevel() {
1321
FaissSQEncoder encoder = new FaissSQEncoder();
14-
assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(null, null));
22+
assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_FP16), null));
23+
assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_INT8), null));
24+
}
25+
26+
private MethodComponentContext generateMethodComponentContext(String sqType) {
27+
return new MethodComponentContext(ENCODER_SQ, Map.of(FAISS_SQ_TYPE, sqType));
1528
}
1629
}

‎src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ public void testGetDefaultRescoreContext() {
7878
assertNotNull(rescoreContext);
7979
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);
8080

81-
// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext)
82-
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT);
81+
// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext before V2.19.0)
82+
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.V_2_18_1);
8383
assertNull(rescoreContext);
84-
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4)
85-
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT);
84+
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4 before V2.19.0)
85+
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.V_2_18_1);
8686
assertNull(rescoreContext);
8787
// Other compression levels should behave similarly with respect to dimension
8888
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT);

‎src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -1831,7 +1831,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
18311831
true
18321832
);
18331833

1834-
// For 4x compression on disk, use Lucene
1834+
// For 4x compression on disk, use Faiss
18351835
xContentBuilder = XContentFactory.jsonBuilder()
18361836
.startObject()
18371837
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
@@ -1847,7 +1847,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
18471847
);
18481848
validateBuilderAfterParsing(
18491849
builder,
1850-
KNNEngine.LUCENE,
1850+
KNNEngine.FAISS,
18511851
SpaceType.L2,
18521852
VectorDataType.FLOAT,
18531853
CompressionLevel.x4,
@@ -1856,7 +1856,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
18561856
false
18571857
);
18581858

1859-
// For 4x compression in memory, use Lucene
1859+
// For 4x compression in memory, use Faiss
18601860
xContentBuilder = XContentFactory.jsonBuilder()
18611861
.startObject()
18621862
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
@@ -1872,7 +1872,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
18721872
);
18731873
validateBuilderAfterParsing(
18741874
builder,
1875-
KNNEngine.LUCENE,
1875+
KNNEngine.FAISS,
18761876
SpaceType.L2,
18771877
VectorDataType.FLOAT,
18781878
CompressionLevel.x4,
@@ -1971,7 +1971,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
19711971
.field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName())
19721972
.startObject(KNN_METHOD)
19731973
.field(NAME, METHOD_HNSW)
1974-
.field(KNN_ENGINE, KNNEngine.FAISS)
1974+
.field(KNN_ENGINE, KNNEngine.NMSLIB)
19751975
.endObject()
19761976
.endObject();
19771977

‎src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java

+49-7
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
package org.opensearch.knn.index.quantizationservice;
77

8+
import org.apache.lucene.index.FieldInfo;
89
import org.opensearch.knn.KNNTestCase;
910
import org.junit.Before;
1011

1112
import org.opensearch.knn.index.VectorDataType;
13+
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
14+
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
1215
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1316
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
1417
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
@@ -24,6 +27,7 @@
2427
public class QuantizationServiceTests extends KNNTestCase {
2528
private QuantizationService<float[], byte[]> quantizationService;
2629
private KNNVectorValues<float[]> knnVectorValues;
30+
private FieldInfo fieldInfo;
2731

2832
@Before
2933
public void setUp() throws Exception {
@@ -42,11 +46,19 @@ public void setUp() throws Exception {
4246
VectorDataType.FLOAT,
4347
new TestVectorValues.PreDefinedFloatVectorValues(floatVectors)
4448
);
49+
50+
fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build();
4551
}
4652

4753
public void testTrain_oneBitQuantizer_success() throws IOException {
4854
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
49-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
55+
56+
QuantizationState quantizationState = quantizationService.train(
57+
oneBitParams,
58+
knnVectorValues,
59+
knnVectorValues.totalLiveDocs(),
60+
fieldInfo
61+
);
5062

5163
assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
5264
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
@@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {
6274

6375
public void testTrain_twoBitQuantizer_success() throws IOException {
6476
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
65-
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
77+
QuantizationState quantizationState = quantizationService.train(
78+
twoBitParams,
79+
knnVectorValues,
80+
knnVectorValues.totalLiveDocs(),
81+
fieldInfo
82+
);
6683

6784
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
6885
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {
85102

86103
public void testTrain_fourBitQuantizer_success() throws IOException {
87104
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
88-
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
105+
QuantizationState quantizationState = quantizationService.train(
106+
fourBitParams,
107+
knnVectorValues,
108+
knnVectorValues.totalLiveDocs(),
109+
fieldInfo
110+
);
89111

90112
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
91113
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {
110132

111133
public void testQuantize_oneBitQuantizer_success() throws IOException {
112134
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
113-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
135+
QuantizationState quantizationState = quantizationService.train(
136+
oneBitParams,
137+
knnVectorValues,
138+
knnVectorValues.totalLiveDocs(),
139+
fieldInfo
140+
);
114141

115142
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
116143

@@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {
125152

126153
public void testQuantize_twoBitQuantizer_success() throws IOException {
127154
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
128-
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
155+
QuantizationState quantizationState = quantizationService.train(
156+
twoBitParams,
157+
knnVectorValues,
158+
knnVectorValues.totalLiveDocs(),
159+
fieldInfo
160+
);
129161
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
130162
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);
131163

@@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {
138170

139171
public void testQuantize_fourBitQuantizer_success() throws IOException {
140172
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
141-
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
173+
QuantizationState quantizationState = quantizationService.train(
174+
fourBitParams,
175+
knnVectorValues,
176+
knnVectorValues.totalLiveDocs(),
177+
fieldInfo
178+
);
142179
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);
143180

144181
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
@@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {
152189

153190
public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
154191
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
155-
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
192+
QuantizationState quantizationState = quantizationService.train(
193+
oneBitParams,
194+
knnVectorValues,
195+
knnVectorValues.totalLiveDocs(),
196+
fieldInfo
197+
);
156198
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
157199
assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
158200
}

‎src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@ public void testSQTypesValues() {
1515
ScalarQuantizationType[] expectedValues = {
1616
ScalarQuantizationType.ONE_BIT,
1717
ScalarQuantizationType.TWO_BIT,
18-
ScalarQuantizationType.FOUR_BIT };
18+
ScalarQuantizationType.FOUR_BIT,
19+
ScalarQuantizationType.EIGHT_BIT };
1920
assertArrayEquals(expectedValues, ScalarQuantizationType.values());
2021
}
2122

2223
public void testSQTypesValueOf() {
2324
assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT"));
2425
assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT"));
2526
assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT"));
27+
assertEquals(ScalarQuantizationType.EIGHT_BIT, ScalarQuantizationType.valueOf("EIGHT_BIT"));
2628
}
2729

2830
public void testUniqueSQTypeValues() {

0 commit comments

Comments
 (0)
Please sign in to comment.