Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/main/java/org/opensearch/knn/index/MultiVectorField.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorAsCollectionOfFloatsSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;

import java.nio.ByteBuffer;
import java.util.List;

/**
 * Lucene {@link Field} whose binary value holds one or more float vectors in serialized form.
 */
public class MultiVectorField extends Field {

    /**
     * Creates the field from a list of float vectors, serializing them with
     * {@link KNNVectorAsCollectionOfFloatsSerializer}.
     *
     * @param name FieldType name
     * @param value vectors to serialize; must be non-null and contain at least one vector
     * @param type FieldType to build DocValues
     * @throws IllegalArgumentException if {@code value} is null or empty
     */
    public MultiVectorField(String name, List<float[]> value, IndexableFieldType type) {
        super(name, new BytesRef(), type);
        // Reject unusable input up front: the serializer yields no buffer for an empty list,
        // which previously surfaced as a NullPointerException wrapped in a bare RuntimeException.
        if (value == null || value.isEmpty()) {
            throw new IllegalArgumentException("Multi vector field [" + name + "] requires at least one vector");
        }
        final KNNVectorSerializer vectorSerializer = KNNVectorAsCollectionOfFloatsSerializer.INSTANCE;
        final ByteBuffer serialized = vectorSerializer.floatsToByteArray(value);
        this.setBytesValue(serialized.array());
    }

    /**
     * Creates the field from vectors that are already serialized.
     *
     * @param name FieldType name
     * @param value multi arrays of byte vector values; must be non-null
     * @param type FieldType to build DocValues
     * @throws IllegalArgumentException if {@code value} is null
     */
    public MultiVectorField(String name, byte[] value, IndexableFieldType type) {
        super(name, new BytesRef(), type);
        if (value == null) {
            throw new IllegalArgumentException("Multi vector field [" + name + "] requires a non-null byte value");
        }
        this.setBytesValue(value);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

/**
Expand Down Expand Up @@ -38,4 +40,38 @@ public float[] byteToFloatArray(BytesRef bytesRef) {
ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length).asFloatBuffer().get(vector);
return vector;
}

/**
 * Serializes a list of equally sized float vectors into a single big-endian byte buffer,
 * writing the vectors back to back in list order.
 *
 * @param input vectors to serialize; every vector must be non-null and of the same length
 * @return buffer holding the serialized floats, or {@code null} when {@code input} is empty
 * @throws IllegalArgumentException if a vector is null or the vectors differ in length
 * @throws ArithmeticException if the total byte size does not fit in an {@code int}
 */
@Override
public ByteBuffer floatsToByteArray(List<float[]> input) {
    // TODO: decide how an empty input should be represented; null is kept for existing callers.
    if (input.isEmpty()) {
        return null;
    }

    // The buffer is sized as vectorCount * dims, so every vector has to share one length.
    // A longer vector would overflow the buffer; a shorter one would leave zero padding
    // that cannot be told apart from real data when reading back.
    int dims = -1;
    for (float[] vector : input) {
        if (vector == null) {
            throw new IllegalArgumentException("Vectors to serialize must not be null");
        }
        if (dims < 0) {
            dims = vector.length;
        } else if (vector.length != dims) {
            throw new IllegalArgumentException(
                "All vectors must have the same dimension, expected [" + dims + "] but found [" + vector.length + "]"
            );
        }
    }

    // Exact arithmetic so an oversized input fails loudly instead of wrapping around.
    final int bufferSize = Math.multiplyExact(Math.multiplyExact(input.size(), dims), BYTES_IN_FLOAT);
    final ByteBuffer bb = ByteBuffer.allocate(bufferSize).order(ByteOrder.BIG_ENDIAN);
    for (float[] vector : input) {
        for (float f : vector) {
            bb.putFloat(f);
        }
    }
    return bb;
}

/**
 * Deserializes the bytes referenced by {@code bytesRef} into consecutive float vectors of
 * {@code dims} elements each.
 *
 * @param bytesRef bytes to deserialize; only the range {@code [offset, offset + length)} is read
 * @param dims number of floats in every vector; must be positive
 * @return the deserialized vectors in stored order
 * @throws IllegalArgumentException if {@code bytesRef} is null, its length is not a whole number
 *         of floats, {@code dims} is not positive, or the floats do not divide evenly into vectors
 */
@Override
public List<float[]> byteToFloatsArray(BytesRef bytesRef, int dims) {
    if (bytesRef == null || bytesRef.length % BYTES_IN_FLOAT != 0) {
        throw new IllegalArgumentException("Byte stream cannot be deserialized to arrays of floats");
    }
    if (dims <= 0) {
        throw new IllegalArgumentException("Vector dimension must be positive, got [" + dims + "]");
    }
    final int totalFloats = bytesRef.length / BYTES_IN_FLOAT;
    if (totalFloats % dims != 0) {
        // Trailing floats that do not fill a whole vector indicate a wrong dims or corrupt data.
        throw new IllegalArgumentException(
            "Byte stream holds [" + totalFloats + "] floats, which is not a multiple of dimension [" + dims + "]"
        );
    }

    final int numberOfVectors = totalFloats / dims;
    final int bytesPerVector = dims * BYTES_IN_FLOAT;
    final List<float[]> vectors = new ArrayList<>(numberOfVectors);
    // Start at bytesRef.offset: the referenced range may begin anywhere in the backing array.
    int offset = bytesRef.offset;
    for (int i = 0; i < numberOfVectors; i++) {
        final float[] vector = new float[dims];
        ByteBuffer.wrap(bytesRef.bytes, offset, bytesPerVector).asFloatBuffer().get(vector);
        vectors.add(vector);
        offset += bytesPerVector;
    }
    return vectors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import org.apache.lucene.util.BytesRef;

import java.nio.ByteBuffer;
import java.util.List;

/**
* Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector
*/
Expand All @@ -25,4 +28,20 @@ public interface KNNVectorSerializer {
* @return array of floats deserialized from the stream
*/
float[] byteToFloatArray(BytesRef bytesRef);

/**
 * Serializes a list of float arrays into bytes.
 *
 * @param input list of float arrays that will be converted
 * @return byte buffer that contains the serialized input arrays
 */
ByteBuffer floatsToByteArray(List<float[]> input);

/**
 * Deserializes all bytes from the stream into a list of float arrays.
 *
 * @param bytesRef bytes that will be used for deserialization to arrays of floats
 * @param dims number of floats in each returned array (presumably the vector dimension; confirm against callers)
 * @return list of float arrays deserialized from the stream
 */
List<float[]> byteToFloatsArray(BytesRef bytesRef, int dims);

}
Loading
Loading