Skip to content

Commit f5cf255

Browse files
authored
Add cosine similarity support for faiss engine (opensearch-project#2376)
* Add cosine similarity support for faiss engine. The FAISS engine doesn't support cosine similarity natively. However, we can use inner product to achieve the same result, because when vectors are normalized, the inner product is the same as cosine similarity. Hence, before ingestion and before search, we normalize the input vector and add it to a faiss index whose space type is inner product. Since we store the normalized vectors in segments, the original vectors can be retrieved from the source. By saving the normalized vectors, we don't have to normalize again whenever segments are merged. This keeps force merge and search times competitive, at the cost of additional one-time latency during indexing (when we normalize). We also support radial search for cosine similarity. Signed-off-by: Vijayan Balasubramanian <[email protected]>
1 parent 8374b8f commit f5cf255

22 files changed

+605
-50
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
2020
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
2121
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
22+
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
2223
### Enhancements
2324
- Introduced a writing layer in native engines that relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
2425
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]

src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java

+25-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
1414
import org.opensearch.knn.index.mapper.PerDimensionValidator;
1515
import org.opensearch.knn.index.mapper.SpaceVectorValidator;
16+
import org.opensearch.knn.index.mapper.VectorTransformer;
17+
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
1618
import org.opensearch.knn.index.mapper.VectorValidator;
1719

1820
import java.util.ArrayList;
@@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
106108
return PerDimensionProcessor.NOOP_PROCESSOR;
107109
}
108110

111+
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
112+
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
113+
}
114+
109115
@Override
110116
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
111117
KNNMethodContext knnMethodContext,
@@ -116,19 +122,37 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
116122
knnMethodConfigContext
117123
);
118124
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
119-
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
125+
parameterMap.put(KNNConstants.SPACE_TYPE, convertUserToMethodSpaceType(knnMethodContext.getSpaceType()).getValue());
120126
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
121127
return KNNLibraryIndexingContextImpl.builder()
122128
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
123129
.parameters(parameterMap)
124130
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
125131
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
126132
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
133+
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
127134
.build();
128135
}
129136

130137
@Override
131138
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
132139
return knnLibrarySearchContext;
133140
}
141+
142+
/**
143+
* Converts user defined space type to method space type that is supported by library.
144+
* A subclass can override this method and return the appropriate space type that
145+
* is supported by the library. This is required because, some libraries may not
146+
* support all the space types supported by OpenSearch; however, the same behavior can be achieved by using a compatible space type that the library does support.
147+
* For example, faiss does not support cosine similarity. However, we can use inner product space type for cosine similarity after normalization.
148+
* In this case, we can return the inner product space type for cosine similarity.
149+
*
150+
* @param spaceType The space type to check for compatibility
151+
* @return The compatible space type for the given input, returns the same
152+
* space type if it's already compatible
153+
* @see SpaceType
154+
*/
155+
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
156+
return spaceType;
157+
}
134158
}

src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java

+9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
99
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
1010
import org.opensearch.knn.index.mapper.PerDimensionValidator;
11+
import org.opensearch.knn.index.mapper.VectorTransformer;
1112
import org.opensearch.knn.index.mapper.VectorValidator;
1213

1314
import java.util.Map;
@@ -47,4 +48,12 @@ public interface KNNLibraryIndexingContext {
4748
* @return Get the per dimension processor
4849
*/
4950
PerDimensionProcessor getPerDimensionProcessor();
51+
52+
/**
53+
* Get the vector transformer that will be used to transform the vector before indexing.
54+
* This will be applied at the vector level once the entire vector has been parsed and validated.
55+
*
56+
* @return VectorTransformer
57+
*/
58+
VectorTransformer getVectorTransformer();
5059
}

src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java

+7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
1010
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
1111
import org.opensearch.knn.index.mapper.PerDimensionValidator;
12+
import org.opensearch.knn.index.mapper.VectorTransformer;
1213
import org.opensearch.knn.index.mapper.VectorValidator;
1314

1415
import java.util.Collections;
@@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
2324
private VectorValidator vectorValidator;
2425
private PerDimensionValidator perDimensionValidator;
2526
private PerDimensionProcessor perDimensionProcessor;
27+
private VectorTransformer vectorTransformer;
2628
@Builder.Default
2729
private Map<String, Object> parameters = Collections.emptyMap();
2830
@Builder.Default
@@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() {
4345
return vectorValidator;
4446
}
4547

48+
@Override
49+
public VectorTransformer getVectorTransformer() {
50+
return vectorTransformer;
51+
}
52+
4653
@Override
4754
public PerDimensionValidator getPerDimensionValidator() {
4855
return perDimensionValidator;

src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java

+19-7
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88
import org.apache.commons.lang.StringUtils;
99
import org.opensearch.knn.index.SpaceType;
1010
import org.opensearch.knn.index.VectorDataType;
11-
import org.opensearch.knn.index.engine.AbstractKNNMethod;
12-
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
13-
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
14-
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
15-
import org.opensearch.knn.index.engine.KNNMethodContext;
16-
import org.opensearch.knn.index.engine.MethodComponent;
17-
import org.opensearch.knn.index.engine.MethodComponentContext;
11+
import org.opensearch.knn.index.engine.*;
1812
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
1913
import org.opensearch.knn.index.mapper.PerDimensionValidator;
14+
import org.opensearch.knn.index.mapper.VectorTransformer;
15+
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
2016

2117
import java.util.Objects;
2218
import java.util.Set;
@@ -132,4 +128,20 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
132128
}
133129
return (MethodComponentContext) object;
134130
}
131+
132+
@Override
133+
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
134+
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical
135+
// relationship between cosine similarity and inner product for normalized vectors to add support.
136+
// When ||a|| = ||b|| = 1, cos(θ) = a · b
137+
if (spaceType == SpaceType.COSINESIMIL) {
138+
return SpaceType.INNER_PRODUCT;
139+
}
140+
return super.convertUserToMethodSpaceType(spaceType);
141+
}
142+
143+
@Override
144+
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
145+
return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType);
146+
}
135147
}

src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java

+21-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
*/
2727
public class Faiss extends NativeLibrary {
2828
public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B";
29+
Map<SpaceType, Function<Float, Float>> distanceTransform;
2930
Map<SpaceType, Function<Float, Float>> scoreTransform;
3031

3132
// TODO: Current version is not really current version. Instead, it encodes information in the file name
@@ -36,14 +37,24 @@ public class Faiss extends NativeLibrary {
3637
// Map that overrides OpenSearch score translation by space type of scores returned by faiss
3738
private final static Map<SpaceType, Function<Float, Float>> SCORE_TRANSLATIONS = ImmutableMap.of(
3839
SpaceType.INNER_PRODUCT,
39-
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
40+
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore),
41+
// COSINESIMIL score translation expects the raw score as a cosine distance, i.e. 1 - cosine(x, y)
42+
SpaceType.COSINESIMIL,
43+
rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore)
4044
);
4145

4246
// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
4347
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
4448
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
4549
SpaceType,
46-
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
50+
Function<Float, Float>>builder()
51+
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
52+
.put(SpaceType.COSINESIMIL, score -> 2 * score - 1)
53+
.build();
54+
55+
private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
56+
SpaceType,
57+
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();
4758

4859
// Package private so that the method resolving logic can access the methods
4960
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
@@ -53,7 +64,8 @@ public class Faiss extends NativeLibrary {
5364
SCORE_TRANSLATIONS,
5465
CURRENT_VERSION,
5566
KNNConstants.FAISS_EXTENSION,
56-
SCORE_TO_DISTANCE_TRANSFORMATIONS
67+
SCORE_TO_DISTANCE_TRANSFORMATIONS,
68+
DISTANCE_TRANSLATIONS
5769
);
5870

5971
private final MethodResolver methodResolver;
@@ -71,16 +83,20 @@ private Faiss(
7183
Map<SpaceType, Function<Float, Float>> scoreTranslation,
7284
String currentVersion,
7385
String extension,
74-
Map<SpaceType, Function<Float, Float>> scoreTransform
86+
Map<SpaceType, Function<Float, Float>> scoreTransform,
87+
Map<SpaceType, Function<Float, Float>> distanceTransform
7588
) {
7689
super(methods, scoreTranslation, currentVersion, extension);
7790
this.scoreTransform = scoreTransform;
91+
this.distanceTransform = distanceTransform;
7892
this.methodResolver = new FaissMethodResolver();
7993
}
8094

8195
@Override
8296
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
83-
// Faiss engine uses distance as is and does not need transformation
97+
if (this.distanceTransform.containsKey(spaceType)) {
98+
return this.distanceTransform.get(spaceType).apply(distance);
99+
}
84100
return distance;
85101
}
86102

src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod {
4646
SpaceType.UNDEFINED,
4747
SpaceType.HAMMING,
4848
SpaceType.L2,
49-
SpaceType.INNER_PRODUCT
49+
SpaceType.INNER_PRODUCT,
50+
SpaceType.COSINESIMIL
5051
);
5152

5253
private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(

src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod {
4949
SpaceType.UNDEFINED,
5050
SpaceType.L2,
5151
SpaceType.INNER_PRODUCT,
52-
SpaceType.HAMMING
52+
SpaceType.HAMMING,
53+
SpaceType.COSINESIMIL
5354
);
5455

5556
private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(

src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java

+1
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,5 @@ protected PerDimensionValidator getPerDimensionValidator() {
109109
protected PerDimensionProcessor getPerDimensionProcessor() {
110110
return PerDimensionProcessor.NOOP_PROCESSOR;
111111
}
112+
112113
}

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ protected void validatePreparse() {
675675
protected abstract VectorValidator getVectorValidator();
676676

677677
/**
678-
* Getter for per dimension validator during vector parsing
678+
* Getter for per dimension validator during vector parsing, and before any transformation
679679
*
680680
* @return PerDimensionValidator
681681
*/
@@ -688,6 +688,23 @@ protected void validatePreparse() {
688688
*/
689689
protected abstract PerDimensionProcessor getPerDimensionProcessor();
690690

691+
/**
692+
* Retrieves the vector transformer for the KNN vector field.
693+
* This method provides access to the vector transformer instance that will be used
694+
* for processing vectors in the KNN field. The transformer is responsible for any
695+
* necessary vector transformations before indexing or searching.
696+
* This default implementation returns the no-op transformer defined by
697+
* VectorTransformerFactory; subclasses may override it to supply a different transformer.
698+
* The returned transformer is typically stateless and thread-safe.
699+
*
700+
* @return VectorTransformer An instance of VectorTransformer that will be used
701+
* for vector transformations in this field
702+
*
703+
*/
704+
protected VectorTransformer getVectorTransformer() {
705+
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
706+
}
707+
691708
protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
692709
validatePreparse();
693710

@@ -698,6 +715,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
698715
}
699716
final byte[] array = bytesArrayOptional.get();
700717
getVectorValidator().validateVector(array);
718+
getVectorTransformer().transform(array);
701719
context.doc().addAll(getFieldsForByteVector(array));
702720
} else if (VectorDataType.FLOAT == vectorDataType) {
703721
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
@@ -707,6 +725,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
707725
}
708726
final float[] array = floatsArrayOptional.get();
709727
getVectorValidator().validateVector(array);
728+
getVectorTransformer().transform(array);
710729
context.doc().addAll(getFieldsForFloatVector(array));
711730
} else {
712731
throw new IllegalArgumentException(

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java

+38
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020
import org.opensearch.index.query.QueryShardException;
2121
import org.opensearch.knn.index.KNNVectorIndexFieldData;
2222
import org.opensearch.knn.index.VectorDataType;
23+
import org.opensearch.knn.index.engine.KNNMethodContext;
2324
import org.opensearch.knn.index.query.rescore.RescoreContext;
25+
import org.opensearch.knn.indices.ModelDao;
26+
import org.opensearch.knn.indices.ModelMetadata;
2427
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
2528
import org.opensearch.search.lookup.SearchLookup;
2629

2730
import java.util.ArrayList;
2831
import java.util.Collections;
2932
import java.util.Locale;
3033
import java.util.Map;
34+
import java.util.Optional;
3135
import java.util.function.Supplier;
3236

3337
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
@@ -115,4 +119,38 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
115119
Mode mode = knnMappingConfig.getMode();
116120
return compressionLevel.getDefaultRescoreContext(mode, dimension);
117121
}
122+
123+
/**
124+
* Transforms a query vector based on the field's configuration. The transformation is performed
125+
* in-place on the input vector according to either the KNN method context or the model ID.
126+
*
127+
* @param vector The float array to be transformed in-place. Must not be null.
128+
* @throws IllegalStateException if neither KNN method context nor Model ID is configured
129+
*
130+
* The transformation process follows this order:
131+
* 1. If vector is not FLOAT type, no transformation is performed
132+
* 2. Attempts to use KNN method context if present
133+
* 3. Falls back to model ID if KNN method context is not available
134+
* 4. Throws exception if neither configuration is present
135+
*/
136+
public void transformQueryVector(float[] vector) {
137+
if (VectorDataType.FLOAT != vectorDataType) {
138+
return;
139+
}
140+
final Optional<KNNMethodContext> knnMethodContext = knnMappingConfig.getKnnMethodContext();
141+
if (knnMethodContext.isPresent()) {
142+
KNNMethodContext context = knnMethodContext.get();
143+
VectorTransformerFactory.getVectorTransformer(context.getKnnEngine(), context.getSpaceType()).transform(vector);
144+
return;
145+
}
146+
final Optional<String> modelId = knnMappingConfig.getModelId();
147+
if (modelId.isPresent()) {
148+
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
149+
final ModelMetadata metadata = modelDao.getMetadata(modelId.get());
150+
VectorTransformerFactory.getVectorTransformer(metadata.getKnnEngine(), metadata.getSpaceType()).transform(vector);
151+
return;
152+
}
153+
throw new IllegalStateException("Either KNN method context or Model Id should be configured");
154+
155+
}
118156
}

src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
3939
private final PerDimensionProcessor perDimensionProcessor;
4040
private final PerDimensionValidator perDimensionValidator;
4141
private final VectorValidator vectorValidator;
42+
private final VectorTransformer vectorTransformer;
4243

4344
public static MethodFieldMapper createFieldMapper(
4445
String fullname,
@@ -180,6 +181,7 @@ private MethodFieldMapper(
180181
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
181182
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
182183
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
184+
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
183185
}
184186

185187
@Override
@@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
196198
protected PerDimensionProcessor getPerDimensionProcessor() {
197199
return perDimensionProcessor;
198200
}
201+
202+
@Override
203+
protected VectorTransformer getVectorTransformer() {
204+
return vectorTransformer;
205+
}
199206
}

0 commit comments

Comments
 (0)