diff --git a/src/main/java/org/opensearch/knn/index/MultiVectorField.java b/src/main/java/org/opensearch/knn/index/MultiVectorField.java
new file mode 100644
index 0000000000..dba509fea5
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/MultiVectorField.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index;
+
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexableFieldType;
+import org.apache.lucene.util.BytesRef;
+import org.opensearch.knn.index.codec.util.KNNVectorAsCollectionOfFloatsSerializer;
+import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+/**
+ * Lucene field whose binary value holds several float vectors serialized back to back.
+ */
+public class MultiVectorField extends Field {
+
+    /**
+     * Serializes {@code value} with {@link KNNVectorAsCollectionOfFloatsSerializer} and stores the resulting bytes.
+     *
+     * @param name field name
+     * @param value vectors to serialize; must contain at least one vector
+     * @param type field type handed to {@link Field}
+     * @throws IllegalArgumentException if {@code value} is null or empty
+     */
+    public MultiVectorField(String name, List<float[]> value, IndexableFieldType type) {
+        super(name, new BytesRef(), type);
+        if (value == null || value.isEmpty()) {
+            // An empty list has nothing to serialize; reject it up front with a clear message.
+            throw new IllegalArgumentException("Multi vector field [" + name + "] requires at least one vector");
+        }
+        final KNNVectorSerializer vectorSerializer = KNNVectorAsCollectionOfFloatsSerializer.INSTANCE;
+        final ByteBuffer serializedVectors = vectorSerializer.floatsToByteArray(value);
+        this.setBytesValue(serializedVectors.array());
+    }
+
+    /**
+     * Stores already serialized vectors as the field value.
+     *
+     * @param name FieldType name
+     * @param value multi arrays of byte vector values
+     * @param type FieldType to build DocValues
+     */
+    public MultiVectorField(String name, byte[] value, IndexableFieldType type) {
+        super(name, new BytesRef(), type);
+        this.setBytesValue(value);
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsCollectionOfFloatsSerializer.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsCollectionOfFloatsSerializer.java
index 636433d740..d0e96fc99a 100644
--- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsCollectionOfFloatsSerializer.java
+++ 
b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsCollectionOfFloatsSerializer.java
@@ -9,6 +9,8 @@
 
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.stream.IntStream;
 
 /**
@@ -38,4 +40,72 @@ public float[] byteToFloatArray(BytesRef bytesRef) {
         ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length).asFloatBuffer().get(vector);
         return vector;
     }
+
+    /**
+     * Serializes equally sized vectors into one big-endian buffer, one vector after another.
+     *
+     * @param input vectors to serialize; all vectors must have the same length
+     * @return buffer holding every float of every vector in order; an empty buffer for an empty list
+     * @throws IllegalArgumentException if the list is null, holds a null vector, or the vector lengths differ
+     */
+    @Override
+    public ByteBuffer floatsToByteArray(List<float[]> input) {
+        if (input == null) {
+            throw new IllegalArgumentException("List of vectors cannot be null");
+        }
+        if (input.isEmpty()) {
+            // Hand back an empty buffer instead of null so callers can use the result directly.
+            return ByteBuffer.allocate(0).order(ByteOrder.BIG_ENDIAN);
+        }
+        final float[] firstVector = input.get(0);
+        final int dimension = firstVector == null ? 0 : firstVector.length;
+        for (float[] vector : input) {
+            // A shorter vector would leave zero padding in the buffer and a longer one would overflow it.
+            if (vector == null || vector.length != dimension) {
+                throw new IllegalArgumentException("All vectors must be non-null and have the same dimension");
+            }
+        }
+        final int bufferSize = input.size() * dimension * BYTES_IN_FLOAT;
+        final ByteBuffer bb = ByteBuffer.allocate(bufferSize).order(ByteOrder.BIG_ENDIAN);
+
+        for (float[] vector : input) {
+            for (float f : vector) {
+                bb.putFloat(f);
+            }
+        }
+        return bb;
+    }
+
+    /**
+     * Splits a serialized byte stream into vectors of {@code dims} floats each.
+     *
+     * @param bytesRef bytes to deserialize, read from {@code bytesRef.offset} for {@code bytesRef.length} bytes
+     * @param dims number of floats per vector; must be greater than 0
+     * @return vectors in the order they appear in the stream
+     * @throws IllegalArgumentException if {@code bytesRef} is null, {@code dims} is not positive, or the length is
+     *         not a whole number of vectors
+     */
+    @Override
+    public List<float[]> byteToFloatsArray(BytesRef bytesRef, int dims) {
+        if (bytesRef == null || bytesRef.length % BYTES_IN_FLOAT != 0) {
+            throw new IllegalArgumentException("Byte stream cannot be deserialized to arrays of floats");
+        }
+        if (dims <= 0) {
+            throw new IllegalArgumentException("Number of dimensions must be greater than 0");
+        }
+        final int sizeOfFloatArray = bytesRef.length / BYTES_IN_FLOAT;
+        if (sizeOfFloatArray % dims != 0) {
+            throw new IllegalArgumentException("Byte stream cannot be deserialized to arrays of floats");
+        }
+        final int numberOfVectors = sizeOfFloatArray / dims;
+        final int bytesPerVector = dims * BYTES_IN_FLOAT;
+        final List<float[]> vectors = new ArrayList<>(numberOfVectors);
+        // Start at bytesRef.offset: a BytesRef may be a slice of a larger backing array.
+        for (int i = 0, offset = bytesRef.offset; i < numberOfVectors; i++, offset += bytesPerVector) {
+            final float[] vector = new float[dims];
+            ByteBuffer.wrap(bytesRef.bytes, offset, bytesPerVector).asFloatBuffer().get(vector);
+            vectors.add(vector);
+        }
+        return vectors;
+    }
 }
diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java
index f7e7a6743e..62b8fe40f7 100644
--- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java
+++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java
@@ -7,6 +7,9 @@
 
 import org.apache.lucene.util.BytesRef;
 
+import java.nio.ByteBuffer;
+import java.util.List;
+
 /**
  * Interface abstracts the 
vector serializer object that is responsible for serialization and de-serialization of k-NN vector
  */
@@ -25,4 +28,21 @@ public interface KNNVectorSerializer {
      * @return array of floats deserialized from the stream
      */
     float[] byteToFloatArray(BytesRef bytesRef);
+
+    /**
+     * Serializes multi array of floats to array of bytes
+     * @param input multi array that will be converted
+     * @return buffer of bytes that contains the serialized input arrays
+     */
+    ByteBuffer floatsToByteArray(List<float[]> input);
+
+    /**
+     * Deserializes all bytes from the stream to multi array of floats
+     *
+     * @param bytesRef bytes that will be used for deserialization to arrays of floats
+     * @param dims number of floats in each deserialized array
+     * @return arrays of floats deserialized from the stream
+     */
+    List<float[]> byteToFloatsArray(BytesRef bytesRef, int dims);
+
 }
diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldMapper.java
new file mode 100644
index 0000000000..00983a7096
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldMapper.java
@@ -0,0 +1,374 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.mapper;
+
+import org.apache.lucene.document.FieldType;
+import org.apache.lucene.index.IndexOptions;
+import org.opensearch.Version;
+import org.opensearch.common.Explicit;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.xcontent.support.XContentMapValues;
+import org.opensearch.core.common.Strings;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.mapper.FieldMapper;
+import org.opensearch.index.mapper.Mapper;
+import org.opensearch.index.mapper.MapperParsingException;
+import org.opensearch.index.mapper.ParametrizedFieldMapper;
+import org.opensearch.index.mapper.ParseContext;
+import org.opensearch.knn.common.KNNConstants;
+import org.opensearch.knn.index.KNNSettings;
+import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.engine.KNNMethodConfigContext;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
+import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
+import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
+import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.useFullFieldNameValidation;
+import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered;
+import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER;
+
+/**
+ * Field mapper for the {@code multi_knn_vector} field type. Example mapping and document:
+ *
+ * "mappings": {
+ *   "properties": {
+ *     "my_multi_knn_vector": {
+ *       "type": "multi_knn_vector",
+ *       "dimension": 3
+ *     }
+ *   }
+ * }
+ *
+ * PUT xx_index/_doc/1
+ * { "my_multi_knn_vector": [ {1,2,3}, {4,5,6}, ...]}
+ */
+public class KNNMultiVectorFieldMapper extends ParametrizedFieldMapper {
+
+    public static final String CONTENT_TYPE = "multi_knn_vector";
+    public static final String MULTI_KNN_FIELD = "multi_knn_field";
+
+    // Builder#build returns a KNNVectorFieldMapper, so the parameter initializers read from that type.
+    private static KNNVectorFieldMapper toType(FieldMapper in) {
+        return (KNNVectorFieldMapper) in;
+    }
+
+    /**
+     * Builder that collects the mapping parameters of a {@code multi_knn_vector} field.
+     */
+    public static class Builder extends ParametrizedFieldMapper.Builder {
+        protected Boolean ignoreMalformed;
+
+        protected final Parameter<Boolean> stored = Parameter.storeParam(m -> toType(m).stored, false);
+
+        protected Parameter<Boolean> hasDocValues;
+        protected Version indexCreatedVersion;
+
+        protected final Parameter<Integer> dimension = new Parameter<>(
+            KNNConstants.DIMENSION,
+            false,
+            () -> UNSET_MODEL_DIMENSION_IDENTIFIER,
+            (n, c, o) -> {
+                if (o == null) {
+                    throw new IllegalArgumentException("Dimension cannot be null");
+                }
+                int value;
+                try {
+                    value = XContentMapValues.nodeIntegerValue(o);
+                } catch (Exception exception) {
+                    throw new IllegalArgumentException(
+                        String.format(Locale.ROOT, "Unable to parse [dimension] from provided value [%s] for vector [%s]", o, name),
+                        exception
+                    );
+                }
+                if (value <= 0) {
+                    throw new IllegalArgumentException(
+                        String.format(Locale.ROOT, "Dimension value must be greater than 0 for vector: %s", name)
+                    );
+                }
+                return value;
+            },
+            m -> toType(m).originalMappingParameters.getDimension()
+        );
+
+        protected final Parameter<VectorDataType> vectorDataType = new Parameter<>(
+            VECTOR_DATA_TYPE_FIELD,
+            false,
+            () -> DEFAULT_VECTOR_DATA_TYPE_FIELD,
+            (n, c, o) -> VectorDataType.get((String) o),
+            m -> toType(m).originalMappingParameters.getVectorDataType()
+        );
+
+        /**
+         * @param name name of the field
+         * @param indexCreatedVersion version the index was created with
+         */
+        public Builder(String name, Version indexCreatedVersion) {
+            super(name);
+            this.indexCreatedVersion = indexCreatedVersion;
+
+            // Initial doc values default: enabled before 3.0.0, disabled from 3.0.0 on.
+            hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, indexCreatedVersion.before(Version.V_3_0_0));
+        }
+
+        @Override
+        protected List<Parameter<?>> getParameters() {
+            return Arrays.asList(stored, hasDocValues, dimension, vectorDataType);
+        }
+
+        protected Explicit<Boolean> ignoreMalformed(BuilderContext context) {
+            if (ignoreMalformed != null) {
+                return new Explicit<>(ignoreMalformed, true);
+            }
+            if (context.indexSettings() != null) {
+                return new Explicit<>(IGNORE_MALFORMED_SETTING.get(context.indexSettings()), false);
+            }
+            return KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED;
+        }
+
+        @Override
+        public KNNVectorFieldMapper build(BuilderContext context) {
+            if (useFullFieldNameValidation(indexCreatedVersion)) {
+                validateFullFieldName(context);
+            }
+
+            final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context);
+            final CopyTo copyToBuilder = copyTo.build();
+            final Explicit<Boolean> ignoreMalformed = ignoreMalformed(context);
+
+            if (indexCreatedVersion.onOrAfter(Version.V_3_0_0) && hasDocValues.isConfigured() == false) {
+                hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true);
+            }
+            // TODO: this still creates the flat vector field mapper.
+            return FlatVectorFieldMapper.createFieldMapper(
+                buildFullName(context),
+                name,
+                KNNMethodConfigContext.builder()
+                    .vectorDataType(vectorDataType.getValue())
+                    .versionCreated(indexCreatedVersion)
+                    .dimension(dimension.getValue())
+                    .build(),
+                multiFieldsBuilder,
+                copyToBuilder,
+                ignoreMalformed,
+                stored.get(),
+                hasDocValues.get()
+            );
+        }
+
+        private void validateFullFieldName(final BuilderContext context) {
+            final String fullFieldName = buildFullName(context);
+            for (char ch : fullFieldName.toCharArray()) {
+                if (Strings.INVALID_FILENAME_CHARS.contains(ch)) {
+                    throw new IllegalArgumentException(
+                        String.format(
+                            Locale.ROOT,
+                            "Vector field name must not include invalid characters of %s. "
+                                + "Provided field name=[%s] had a disallowed character [%c]",
+                            Strings.INVALID_FILENAME_CHARS.stream().map(c -> "'" + c + "'").collect(Collectors.toList()),
+                            fullFieldName,
+                            ch
+                        )
+                    );
+                }
+            }
+        }
+    }
+
+    /**
+     * Parses the mapping definition of a {@code multi_knn_vector} field into a {@link Builder}.
+     */
+    public static class TypeParser implements Mapper.TypeParser {
+        @Override
+        public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
+            Builder builder = new KNNMultiVectorFieldMapper.Builder(name, parserContext.indexVersionCreated());
+            builder.parse(name, parserContext, node);
+
+            // Check for flat configuration and validate only if index is created after 2.17
+            if (isKNNDisabled(parserContext.getSettings()) && parserContext.indexVersionCreated().onOrAfter(Version.V_2_17_0)) {
+                validateFromFlat(builder);
+            }
+
+            return builder;
+        }
+
+        private boolean isKNNDisabled(Settings settings) {
+            boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(settings);
+            return !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(settings);
+        }
+
+        private void validateFromFlat(KNNMultiVectorFieldMapper.Builder builder) {
+            validateDimensionSet(builder);
+        }
+
+        private void validateDimensionSet(KNNMultiVectorFieldMapper.Builder builder) {
+            if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER) {
+                throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name()));
+            }
+        }
+    }
+
+    protected Version indexCreatedVersion;
+    protected Explicit<Boolean> ignoreMalformed;
+    protected boolean stored;
+    protected boolean hasDocValues;
+    protected VectorDataType vectorDataType;
+
+    public KNNMultiVectorFieldMapper(
+        String simpleName,
+        KNNVectorFieldType mappedFieldType,
+        MultiFields multiFields,
+        CopyTo copyTo,
+        Explicit<Boolean> ignoreMalformed,
+        boolean stored,
+        boolean hasDocValues,
+        Version indexCreatedVersion
+    ) {
+        super(simpleName, mappedFieldType, multiFields, copyTo);
+        this.ignoreMalformed = ignoreMalformed;
+        this.stored = stored;
+        this.hasDocValues = hasDocValues;
+        this.vectorDataType = mappedFieldType.getVectorDataType();
+        this.indexCreatedVersion = indexCreatedVersion;
+    }
+
+    @Override
+    public Builder getMergeBuilder() {
+        // TODO: return a Builder initialized from this mapper; null is only a placeholder.
+        return null;
+    }
+
+    @Override
+    public KNNMultiVectorFieldMapper clone() {
+        return (KNNMultiVectorFieldMapper) super.clone();
+    }
+
+    @Override
+    protected String contentType() {
+        return CONTENT_TYPE;
+    }
+
+    @Override
+    protected void parseCreateField(ParseContext context) throws IOException {
+        parseCreateField(context, fieldType().getVectorDimensions(), fieldType().getVectorDataType());
+    }
+
+    protected void validatePreparse() {
+        validateIfCircuitBreakerIsNotTriggered();
+    }
+
+    /**
+     * Placeholder for turning the vectors of one document into index fields.
+     *
+     * @param context parse context of the document being indexed
+     * @param dimension expected number of dimensions per vector
+     * @param vectorDataType data type of the vector values
+     * @throws UnsupportedOperationException always, because parsing is not implemented yet
+     */
+    protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
+        validatePreparse();
+
+        // TODO: parse the vectors and add them to the document; fail loudly until then instead of indexing nothing.
+        throw new UnsupportedOperationException("Indexing of [" + CONTENT_TYPE + "] fields is not implemented yet");
+    }
+
+    @Override
+    public void forEach(Consumer<? super Mapper> action) {
+        super.forEach(action);
+    }
+
+    @Override
+    public final boolean parsesArrayValue() {
+        return true;
+    }
+
+    // NOTE(review): the constructor takes a KNNVectorFieldType while this cast expects a KNNMultiVectorFieldType;
+    // confirm the two are compatible, otherwise this throws a ClassCastException.
+    @Override
+    public KNNMultiVectorFieldType fieldType() {
+        return (KNNMultiVectorFieldType) super.fieldType();
+    }
+
+    /**
+     * Reads one vector of byte values from the parser, starting at the parser's current token.
+     * The field name is pushed onto the content path and only popped again for a null value, so the caller has to
+     * pop it after a successful parse.
+     *
+     * @param context parse context whose parser supplies the values
+     * @param dimension expected number of values in the vector
+     * @param dataType data type passed on to the dimension validation
+     * @return the parsed values, or an empty optional when the value is null
+     * @throws IOException declared by the parser calls
+     */
+    Optional<byte[]> getBytesFromContext(ParseContext context, int dimension, VectorDataType dataType) throws IOException {
+        context.path().add(simpleName());
+
+        PerDimensionValidator perDimensionValidator = getPerDimensionValidator();
+        PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor();
+
+        ArrayList<Byte> vector = new ArrayList<>();
+        XContentParser.Token token = context.parser().currentToken();
+
+        if (token == XContentParser.Token.START_ARRAY) {
+            token = context.parser().nextToken();
+            while (token != XContentParser.Token.END_ARRAY) {
+                float value = perDimensionProcessor.processByte(context.parser().floatValue());
+                perDimensionValidator.validateByte(value);
+                vector.add((byte) value);
+                token = context.parser().nextToken();
+            }
+        } else if (token == XContentParser.Token.VALUE_NUMBER) {
+            float value = perDimensionProcessor.processByte(context.parser().floatValue());
+            perDimensionValidator.validateByte(value);
+            vector.add((byte) value);
+            context.parser().nextToken();
+        } else if (token == XContentParser.Token.VALUE_NULL) {
+            context.path().remove();
+            return Optional.empty();
+        }
+        validateVectorDimension(dimension, vector.size(), dataType);
+        byte[] array = new byte[vector.size()];
+        int i = 0;
+        for (Byte f : vector) {
+            array[i++] = f;
+        }
+        return Optional.of(array);
+    }
+
+    @Override
+    protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException {
+        super.doXContentBody(builder, includeDefaults, params);
+        if (includeDefaults || ignoreMalformed.explicit()) {
+            builder.field(KNNVectorFieldMapper.Names.IGNORE_MALFORMED, ignoreMalformed.value());
+        }
+    }
+
+    public static class Names {
+        public static final String IGNORE_MALFORMED = "ignore_malformed";
+    }
+
+    public static class Defaults {
+        public static final Explicit<Boolean> IGNORE_MALFORMED = new Explicit<>(false, false);
+        public static final FieldType FIELD_TYPE = new FieldType();
+
+        static {
+            FIELD_TYPE.setTokenized(false);
+            FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
+            FIELD_TYPE.putAttribute(MULTI_KNN_FIELD, "true"); // This attribute helps to determine knn field type
+            FIELD_TYPE.freeze();
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldType.java
new file mode 100644
index 0000000000..79b8b44dd3
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/mapper/KNNMultiVectorFieldType.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.mapper;
+
+import lombok.extern.log4j.Log4j2;
+import org.apache.lucene.search.FieldExistsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.BytesRef;
+import org.opensearch.index.fielddata.IndexFieldData;
+import org.opensearch.index.mapper.ArraySourceValueFetcher;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.index.mapper.TextSearchInfo;
+import org.opensearch.index.mapper.ValueFetcher;
+import org.opensearch.index.query.QueryShardContext;
+import org.opensearch.index.query.QueryShardException;
+import org.opensearch.knn.index.KNNVectorIndexFieldData;
+import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.search.aggregations.support.CoreValuesSourceType;
+import org.opensearch.search.lookup.SearchLookup;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.function.Supplier;
+
+import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
+
+/**
+ * Field type of {@code multi_knn_vector} fields; holds the vector data type and the mapping configuration.
+ */
+@Log4j2
+public class KNNMultiVectorFieldType extends MappedFieldType {
+
+    VectorDataType vectorDataType;
+    KNNMappingConfig knnMappingConfig;
+
+    /**
+     * Constructor for KNNMultiVectorFieldType.
+     *
+     * @param name name of the field
+     * @param metadata metadata of the field
+     * @param vectorDataType data type of the vector
+     * @param annConfig configuration context for the ANN index
+     */
+    public KNNMultiVectorFieldType(String name, Map<String, String> metadata, VectorDataType vectorDataType, KNNMappingConfig annConfig) {
+        super(name, false, false, true, TextSearchInfo.NONE, metadata);
+        this.vectorDataType = vectorDataType;
+        this.knnMappingConfig = annConfig;
+    }
+
+    @Override
+    public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) {
+        return new ArraySourceValueFetcher(name(), context) {
+            @Override
+            protected Object parseSourceValue(Object value) {
+                if (value instanceof ArrayList) {
+                    return value;
+                } else {
+                    log.warn("Expected type ArrayList for value, but got {} ", value.getClass());
+                    return Collections.emptyList();
+                }
+            }
+        };
+    }
+
+    @Override
+    public String typeName() {
+        return KNNMultiVectorFieldMapper.CONTENT_TYPE;
+    }
+
+    @Override
+    public Query existsQuery(QueryShardContext context) {
+        return new FieldExistsQuery(name());
+    }
+
+    @Override
+    public Query termQuery(Object o, QueryShardContext context) {
+        throw new QueryShardException(
+            context,
+            String.format(Locale.ROOT, "KNN vector do not support exact searching, use KNN queries instead: [%s]", name())
+        );
+    }
+
+    @Override
+    public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
+        failIfNoDocValues();
+        return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
+    }
+
+    @Override
+    public Object valueForDisplay(Object value) {
+        // TODO
+        return deserializeStoredVector((BytesRef) value, vectorDataType);
+    }
+
+    public int getVectorDimensions() {
+        return knnMappingConfig.getDimension();
+    }
+
+    public VectorDataType getVectorDataType() {
+        return vectorDataType;
+    }
+}