diff --git a/CHANGELOG.md b/CHANGELOG.md index a595d6dc66..a4f3bf85fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x) ### Features * [Vector Profiler] Adding basic generic vector profiler implementation and tests. [#2624](https://github.com/opensearch-project/k-NN/pull/2624) +* [Vector Profiler] Adding serializer and api implementation for segment profiler state. [#2687](https://github.com/opensearch-project/k-NN/pull/2687) ### Enhancements ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index ce21527e74..8b7818aef7 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -84,6 +84,7 @@ public class KNNConstants { public static final String RADIAL_SEARCH_KEY = "radial_search"; public static final String MODEL_VERSION = "model_version"; public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate"; + public static final String SEGMENT_PROFILE_STATE_FILE_SUFFIX = "segpfstate"; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index 0188f488f0..a02dfa3314 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -28,6 +28,9 @@ import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.SegmentProfilerUtil; +import org.opensearch.knn.plugin.transport.KNNIndexShardProfileResult; +import 
org.opensearch.knn.profiler.SegmentProfilerState; import java.io.IOException; import java.util.ArrayList; @@ -82,6 +85,44 @@ public String getIndexName() { return indexShard.shardId().getIndexName(); } + /** + * Collects the segment profiler state of the given field from every leaf of a newly acquired searcher. + * + * @param field name of the vector field to profile + * @return the per-segment profiler states together with this shard's id + * @throws RuntimeException if profiling fails for any reason; the original exception is kept as the cause + */ + public KNNIndexShardProfileResult profile(final String field) { + try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-profile")) { + log.info("[KNN] Profiling field [{}] in index [{}]", field, getIndexName()); + + List segmentLevelProfilerStates = new ArrayList<>(); + + // A failure on any leaf propagates to the catch below, so it is logged and wrapped exactly once + for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) { + SegmentProfilerState state = SegmentProfilerUtil.getSegmentProfileState(leaf.reader(), field); + segmentLevelProfilerStates.add(state); + log.debug("[KNN] Successfully profiled segment for field [{}]", field); + } + + if (segmentLevelProfilerStates.isEmpty()) { + log.warn("[KNN] No segments found with field [{}] in index [{}]", field, getIndexName()); + } else { + log.info( + "[KNN] Successfully profiled [{}] segments for field [{}] in index [{}]", + segmentLevelProfilerStates.size(), + field, + getIndexName() + ); + } + + return new KNNIndexShardProfileResult(segmentLevelProfilerStates, indexShard.shardId().toString()); + } catch (Exception e) { + log.error("[KNN] Error profiling field [{}] in index [{}]: {}", field, getIndexName(), e.getMessage(), e); + throw new RuntimeException("Error during profiling: " + e.getMessage(), e); + } + } + /** * Load all of the k-NN segments for this shard into the cache.
* diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java index 49b1819c10..0301a31073 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.profiler.SegmentProfilerState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -46,14 +47,18 @@ public final class KNN990QuantizationStateWriter { * @param segmentWriteState segment write state containing segment information * @throws IOException exception could be thrown while creating the output */ - public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { - String quantizationStateFileName = IndexFileNames.segmentFileName( + public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState, String fileSuffix) throws IOException { + String stateFileName = IndexFileNames.segmentFileName( segmentWriteState.segmentInfo.name, segmentWriteState.segmentSuffix, - KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX + fileSuffix ); - output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); + output = segmentWriteState.directory.createOutput(stateFileName, segmentWriteState.context); + } + + public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { + this(segmentWriteState, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); } /** @@ -85,6 +90,20 @@ public void writeState(int fieldNumber, QuantizationState quantizationState) thr fieldQuantizationStates.add(new 
FieldQuantizationState(fieldNumber, stateBytes, position)); } + /** + * Writes a segment profile state as bytes + * + * @param fieldNumber field number + * @param segmentProfilerState segment profiler state + * @throws IOException could be thrown while writing + */ + public void writeState(int fieldNumber, SegmentProfilerState segmentProfilerState) throws IOException { + byte[] stateBytes = segmentProfilerState.toByteArray(); + long position = output.getFilePointer(); + output.writeBytes(stateBytes, stateBytes.length); + fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); + } + /** * Writes index footer and other index information for parsing later * @throws IOException could be thrown while writing diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 1b0e2a8397..57bb8f8e28 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -35,6 +35,10 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.memoryoptsearch.VectorSearcher; import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; +import org.opensearch.knn.profiler.KNN990ProfileStateReader; +import org.opensearch.knn.profiler.SegmentProfileKNNCollector; +import org.opensearch.knn.profiler.SegmentProfileStateReadConfig; +import org.opensearch.knn.profiler.SegmentProfilerState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; @@ -163,6 +167,14 @@ public void search(String field, float[] target, 
KnnCollector knnCollector, Bits return; } + if (knnCollector instanceof SegmentProfileKNNCollector) { + SegmentProfilerState segmentProfileState = KNN990ProfileStateReader.read( + new SegmentProfileStateReadConfig(segmentReadState, field) + ); + ((SegmentProfileKNNCollector) knnCollector).setSegmentProfilerState(segmentProfileState); + return; + } + if (trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, true)) { return; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index c54a1a61b7..3e954006a7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -15,14 +15,12 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.MergeState; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; @@ -52,6 +50,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; private KNN990QuantizationStateWriter quantizationStateWriter; + private 
KNN990QuantizationStateWriter segmentStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; private final Integer approximateThreshold; @@ -108,7 +107,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); - SegmentProfilerState.profileVectors(knnVectorValuesSupplier); + profile(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); // should skip graph building only for non quantization use case and if threshold is met if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.debug( @@ -152,9 +151,13 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + + // Write the segment profile state to the directory + profile(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + // should skip graph building only for non quantization use case and if threshold is met if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.debug( "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", fieldInfo.name, totalLiveDocs, @@ -190,6 +193,10 @@ public void finish() throws IOException { if (quantizationStateWriter != null) { quantizationStateWriter.writeFooter(); } + if (segmentStateWriter != null) { + segmentStateWriter.writeFooter(); + } + flatVectorsWriter.finish(); } @@ -211,6 +218,9 @@ public void close() throws IOException { if (quantizationStateWriter != null) { quantizationStateWriter.closeOutput(); } + if (segmentStateWriter != null) { + segmentStateWriter.closeOutput(); + } IOUtils.close(flatVectorsWriter); } @@ -243,6 +253,23 @@ private QuantizationState train( return quantizationState; } + private
SegmentProfilerState profile( + final FieldInfo fieldInfo, + final Supplier> knnVectorValuesSupplier, + final int totalLiveDocs + ) throws IOException { + // Stays null when there are no live docs, since nothing is profiled or written in that case + SegmentProfilerState segmentProfilerState = null; + if (totalLiveDocs > 0) { + initSegmentStateWriterIfNecessary(); + String segmentId = segmentWriteState.segmentInfo.name; + segmentProfilerState = SegmentProfilerState.profileVectors(knnVectorValuesSupplier, segmentId); + segmentStateWriter.writeState(fieldInfo.getFieldNumber(), segmentProfilerState); + } + + return segmentProfilerState; + } + /** * The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the * vectorsValues object which you plan to use later @@ -260,11 +287,18 @@ private int getLiveDocs(KNNVectorValues vectorValues) throws IOException { private void initQuantizationStateWriterIfNecessary() throws IOException { if (quantizationStateWriter == null) { - quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); quantizationStateWriter.writeHeader(segmentWriteState); } } + private void initSegmentStateWriterIfNecessary() throws IOException { + if (segmentStateWriter == null) { + segmentStateWriter = new KNN990QuantizationStateWriter(segmentWriteState, KNNConstants.SEGMENT_PROFILE_STATE_FILE_SUFFIX); + segmentStateWriter.writeHeader(segmentWriteState); + } + } + private boolean shouldSkipBuildingVectorDataStructure(final long docCount) { if (approximateThreshold < 0) { return true; diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java b/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java new file mode 100644 index 0000000000..0aaf925331 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.experimental.UtilityClass; +import org.apache.lucene.index.LeafReader; +import org.opensearch.knn.profiler.SegmentProfileKNNCollector; +import org.opensearch.knn.profiler.SegmentProfilerState; + +import java.io.IOException; +import java.util.Locale; + +/** + * Utility class to get segment profiler state for a given field + */ +@UtilityClass +public class SegmentProfilerUtil { + + /** + * Gets the segment profile state for a given field + * @param leafReader The leaf reader to query + * @param fieldName The field name to profile + * @return The segment profiler state + * @throws IOException If there's an error reading the segment + */ + public static SegmentProfilerState getSegmentProfileState(final LeafReader leafReader, String fieldName) throws IOException { + final SegmentProfileKNNCollector tempCollector = new SegmentProfileKNNCollector(); + leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null); + if (tempCollector.getSegmentProfilerState() == null) { + throw new IllegalStateException(String.format(Locale.ROOT, "No segment state found for field %s", fieldName)); + } + return tempCollector.getSegmentProfilerState(); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 4e40b119ac..1e5707e995 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -84,6 +84,9 @@ import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; +import org.opensearch.knn.profiler.RestKNNProfileHandler; +import org.opensearch.knn.plugin.transport.KNNProfileTransportAction; +import 
org.opensearch.knn.plugin.transport.KNNProfileAction; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; @@ -269,6 +272,7 @@ public List getRestHandlers( ) { RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler(); + RestKNNProfileHandler restKNNProfileHandler = new RestKNNProfileHandler(); RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler( settings, restController, @@ -284,6 +288,7 @@ public List getRestHandlers( return ImmutableList.of( restKNNStatsHandler, restKNNWarmupHandler, + restKNNProfileHandler, restGetModelHandler, restDeleteModelHandler, restTrainModelHandler, @@ -300,6 +305,7 @@ public List getRestHandlers( return Arrays.asList( new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), + new ActionHandler<>(KNNProfileAction.INSTANCE, KNNProfileTransportAction.class), new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNIndexShardProfileResult.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNIndexShardProfileResult.java new file mode 100644 index 0000000000..95b724625c --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNIndexShardProfileResult.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.AllArgsConstructor; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.knn.profiler.SegmentProfilerState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +@AllArgsConstructor +public class KNNIndexShardProfileResult implements Writeable { + List segmentProfilerStateList; + String shardId; + + /** + * Constructor to deserialize from StreamInput + * + * @param in StreamInput to read from + * @throws IOException if there's an error reading from the stream + */ + public KNNIndexShardProfileResult(StreamInput in) throws IOException { + this.shardId = in.readString(); + + // Read the number of SegmentProfilerStates + int size = in.readInt(); + this.segmentProfilerStateList = new ArrayList<>(size); + + // Read each SegmentProfilerState + for (int i = 0; i < size; i++) { + byte[] stateBytes = in.readByteArray(); + this.segmentProfilerStateList.add(SegmentProfilerState.fromBytes(stateBytes)); + } + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeString(shardId); + + // Write the segment profiler state list size + streamOutput.writeInt(segmentProfilerStateList.size()); + + // Write each SegmentProfilerState as bytes + for (SegmentProfilerState state : segmentProfilerStateList) { + byte[] stateBytes = state.toByteArray(); + streamOutput.writeByteArray(stateBytes); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java new file mode 100644 index 0000000000..8807348abe --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.core.common.io.stream.Writeable; + 
+public class KNNProfileAction extends ActionType { + + public static final KNNProfileAction INSTANCE = new KNNProfileAction(); + public static final String NAME = "cluster:admin/knn_profile_action"; + + /** + * Constructor + */ + private KNNProfileAction() { + super(NAME, KNNProfileResponse::new); + } + + @Override + public Writeable.Reader getResponseReader() { + return KNNProfileResponse::new; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java new file mode 100644 index 0000000000..0df82cf37b --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.Getter; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.broadcast.BroadcastRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +@Getter +public class KNNProfileRequest extends BroadcastRequest { + + private final String index; + private final String field; + + /** + * Creates a profile request for one vector field of one index. + * + * @param index name of the index to profile; also registered as the broadcast target index + * @param field name of the vector field to profile + */ + public KNNProfileRequest(String index, String field) { + super(index); + this.index = index; + this.field = field; + } + + /** + * Constructor + * + * @param in input stream + * @throws IOException in case of I/O errors + */ + public KNNProfileRequest(StreamInput in) throws IOException { + super(in); + index = in.readString(); + field = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(index); + out.writeString(field); + } +} diff --git
a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java new file mode 100644 index 0000000000..901e8355b6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java @@ -0,0 +1,304 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.commons.math3.stat.descriptive.AggregateSummaryStatistics; +import org.apache.commons.math3.stat.descriptive.StatisticalSummary; +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues; +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.opensearch.action.support.broadcast.BroadcastResponse; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.profiler.SegmentProfilerState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Response object for KNN profile requests that provides statistical information about vectors + * at segment, shard, and cluster levels. + * + * Example response: + * { + * "total_shards": 2, + * "successful_shards": 2, + * "failed_shards": 0, + * "shard_profiles": { + * "0": { + * "segments": [ + * { + * "segment_id": "_0", + * "dimension": 128, + * "vector_statistics": [ + * { + * "dimension_index": 0, + * "statistics": { + * "count": 1000, + * "min": -0.523, + * "max": 0.785, + * "sum": 156.78, + * "mean": 0.157, + * "geometric_mean": 0.145, + * "variance": 0.089, + * "std_deviation": 0.298, + * "sum_of_squares": 245.67 + * } + * } + * // ... additional dimensions + * ] + * } + * // ... 
additional segments + * ], + * "aggregated": { + * "total_segments": 3, + * "dimension": 128, + * "dimensions": [ + * { + * "dimension_id": 0, + * "count": 3000, + * "min": -0.723, + * "max": 0.892, + * "mean": 0.167, + * "std_deviation": 0.312, + * "sum": 501.23, + * "variance": 0.097 + * } + * // ... additional dimensions + * ] + * } + * } + * // ... additional shards + * }, + * "cluster_aggregation": { + * "total_shards": 2, + * "dimension": 128, + * "dimensions": [ + * { + * "dimension_id": 0, + * "count": 6000, + * "min": -0.723, + * "max": 0.892, + * "mean": 0.172, + * "std_deviation": 0.315, + * "sum": 1032.45, + * "variance": 0.099 + * } + * // ... additional dimensions + * ] + * }, + * "failures": [] + * } + */ +public class KNNProfileResponse extends BroadcastResponse { + private static final String FIELD_SHARD_PROFILES = "shard_profiles"; + private static final String FIELD_SEGMENTS = "segments"; + private static final String FIELD_SEGMENT_ID = "segment_id"; + private static final String FIELD_DIMENSION = "dimension"; + private static final String FIELD_VECTOR_STATISTICS = "vector_statistics"; + private static final String FIELD_DIMENSION_INDEX = "dimension_index"; + private static final String FIELD_STATISTICS = "statistics"; + private static final String FIELD_COUNT = "count"; + private static final String FIELD_MIN = "min"; + private static final String FIELD_MAX = "max"; + private static final String FIELD_SUM = "sum"; + private static final String FIELD_MEAN = "mean"; + private static final String FIELD_GEOMETRIC_MEAN = "geometric_mean"; + private static final String FIELD_VARIANCE = "variance"; + private static final String FIELD_STD_DEVIATION = "std_deviation"; + private static final String FIELD_SUM_OF_SQUARES = "sum_of_squares"; + private static final String FIELD_AGGREGATED = "aggregated"; + private static final String FIELD_TOTAL_SEGMENTS = "total_segments"; + private static final String FIELD_DIMENSIONS = "dimensions"; + private static final 
String FIELD_DIMENSION_ID = "dimension_id"; + private static final String FIELD_CLUSTER_AGGREGATION = "cluster_aggregation"; + private static final String FIELD_TOTAL_SHARDS = "total_shards"; + private static final String FIELD_FAILURES = "failures"; + + List shardProfileResults = new ArrayList<>(); + + public KNNProfileResponse() {} + + public KNNProfileResponse(StreamInput in) throws IOException { + super(in); + int size = in.readInt(); + List results = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + results.add(new KNNIndexShardProfileResult(in)); + } + this.shardProfileResults = results; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(shardProfileResults.size()); + for (KNNIndexShardProfileResult result : shardProfileResults) { + result.writeTo(out); + } + } + + public KNNProfileResponse( + List shardProfileResults, + int totalShards, + int successfulShards, + int failedShards, + List shardFailures + ) { + super(totalShards, successfulShards, failedShards, shardFailures); + + this.shardProfileResults = shardProfileResults; + } + + @Override + protected void addCustomXContentFields(XContentBuilder builder, Params params) throws IOException { + addShardProfiles(builder); + addClusterAggregation(builder); + addShardFailures(builder, params); + } + + private void addShardProfiles(XContentBuilder builder) throws IOException { + builder.startObject(FIELD_SHARD_PROFILES); + for (KNNIndexShardProfileResult shardProfileResult : shardProfileResults) { + builder.startObject(shardProfileResult.shardId); + addSegmentStatistics(builder, shardProfileResult); + addShardAggregatedStatistics(builder, shardProfileResult); + builder.endObject(); + } + builder.endObject(); + } + + private void addSegmentStatistics(XContentBuilder builder, KNNIndexShardProfileResult shardProfileResult) throws IOException { + builder.startArray(FIELD_SEGMENTS); + for (SegmentProfilerState state : shardProfileResult.segmentProfilerStateList)
{ + builder.startObject() + .field(FIELD_SEGMENT_ID, state.getSegmentId()) + .field(FIELD_DIMENSION, state.getDimension()) + .startArray(FIELD_VECTOR_STATISTICS); + addDimensionStatistics(builder, state); + builder.endArray().endObject(); + } + builder.endArray(); + } + + private void addDimensionStatistics(XContentBuilder builder, SegmentProfilerState state) throws IOException { + for (int i = 0; i < state.getStatistics().size(); i++) { + SummaryStatistics stats = state.getStatistics().get(i); + builder.startObject() + .field(FIELD_DIMENSION_INDEX, i) + .startObject(FIELD_STATISTICS) + .field(FIELD_COUNT, stats.getN()) + .field(FIELD_MIN, stats.getMin()) + .field(FIELD_MAX, stats.getMax()) + .field(FIELD_SUM, stats.getSum()) + .field(FIELD_MEAN, stats.getMean()) + .field(FIELD_GEOMETRIC_MEAN, stats.getGeometricMean()) + .field(FIELD_VARIANCE, stats.getVariance()) + .field(FIELD_STD_DEVIATION, stats.getStandardDeviation()) + .field(FIELD_SUM_OF_SQUARES, stats.getSumsq()) + .endObject() + .endObject(); + } + } + + private void addShardAggregatedStatistics(XContentBuilder builder, KNNIndexShardProfileResult shardProfileResult) throws IOException { + if (!shardProfileResult.segmentProfilerStateList.isEmpty()) { + SegmentProfilerState firstState = shardProfileResult.segmentProfilerStateList.get(0); + int dimensionCount = firstState.getDimension(); + + builder.startObject(FIELD_AGGREGATED) + .field(FIELD_TOTAL_SEGMENTS, shardProfileResult.segmentProfilerStateList.size()) + .field(FIELD_DIMENSION, dimensionCount) + .startArray(FIELD_DIMENSIONS); + + addAggregatedDimensionStatistics(builder, shardProfileResult, dimensionCount); + + builder.endArray().endObject(); + } + } + + private void addClusterAggregation(XContentBuilder builder) throws IOException { + if (!shardProfileResults.isEmpty() && !shardProfileResults.get(0).segmentProfilerStateList.isEmpty()) { + SegmentProfilerState firstState = shardProfileResults.get(0).segmentProfilerStateList.get(0); + int 
dimensionCount = firstState.getDimension(); + + builder.startObject(FIELD_CLUSTER_AGGREGATION) + .field(FIELD_TOTAL_SHARDS, getSuccessfulShards()) + .field(FIELD_DIMENSION, dimensionCount) + .startArray(FIELD_DIMENSIONS); + + addClusterDimensionStatistics(builder, dimensionCount); + + builder.endArray().endObject(); + } + } + + private void addAggregatedDimensionStatistics( + XContentBuilder builder, + KNNIndexShardProfileResult shardProfileResult, + int dimensionCount + ) throws IOException { + for (int dim = 0; dim < dimensionCount; dim++) { + List dimensionStats = collectDimensionStats(shardProfileResult.segmentProfilerStateList, dim); + addAggregatedStats(builder, dim, dimensionStats); + } + } + + private void addClusterDimensionStatistics(XContentBuilder builder, int dimensionCount) throws IOException { + for (int dim = 0; dim < dimensionCount; dim++) { + List dimensionStats = collectClusterDimensionStats(dim); + addAggregatedStats(builder, dim, dimensionStats); + } + } + + private List collectDimensionStats(List states, int dimension) { + List stats = new ArrayList<>(); + for (SegmentProfilerState state : states) { + if (dimension < state.getStatistics().size()) { + stats.add(state.getStatistics().get(dimension)); + } + } + return stats; + } + + private List collectClusterDimensionStats(int dimension) { + List stats = new ArrayList<>(); + for (KNNIndexShardProfileResult shardResult : shardProfileResults) { + for (SegmentProfilerState state : shardResult.segmentProfilerStateList) { + if (dimension < state.getStatistics().size()) { + stats.add(state.getStatistics().get(dimension)); + } + } + } + return stats; + } + + private void addAggregatedStats(XContentBuilder builder, int dimension, List stats) throws IOException { + StatisticalSummaryValues aggregatedStats = AggregateSummaryStatistics.aggregate(stats); + builder.startObject() + .field(FIELD_DIMENSION_ID, dimension) + .field(FIELD_COUNT, aggregatedStats.getN()) + .field(FIELD_MIN, aggregatedStats.getMin()) 
+ .field(FIELD_MAX, aggregatedStats.getMax()) + .field(FIELD_MEAN, aggregatedStats.getMean()) + .field(FIELD_STD_DEVIATION, Math.sqrt(aggregatedStats.getVariance())) + .field(FIELD_SUM, aggregatedStats.getSum()) + .field(FIELD_VARIANCE, aggregatedStats.getVariance()) + .endObject(); + } + + private void addShardFailures(XContentBuilder builder, Params params) throws IOException { + if (getShardFailures() != null && getShardFailures().length > 0) { + builder.startArray(FIELD_FAILURES); + for (DefaultShardOperationFailedException failure : getShardFailures()) { + failure.toXContent(builder, params); + } + builder.endArray(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java new file mode 100644 index 0000000000..e3eb4d6f7a --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.broadcast.node.TransportBroadcastByNodeAction; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.indices.IndicesService; +import org.opensearch.knn.index.KNNIndexShard; +import 
org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +@Log4j2 +public class KNNProfileTransportAction extends TransportBroadcastByNodeAction< + KNNProfileRequest, + KNNProfileResponse, + KNNIndexShardProfileResult> { + + private IndicesService indicesService; + + @Inject + public KNNProfileTransportAction( + ClusterService clusterService, + TransportService transportService, + IndicesService indicesService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + KNNProfileAction.NAME, + clusterService, + transportService, + actionFilters, + indexNameExpressionResolver, + KNNProfileRequest::new, + ThreadPool.Names.SEARCH + ); + this.indicesService = indicesService; + } + + // @Override + // protected KNNIndexShardProfileResult readShardResult(StreamInput in) throws IOException { + // return new KNNIndexShardProfileResult(null, null); + // } + + @Override + protected KNNIndexShardProfileResult readShardResult(StreamInput in) throws IOException { + return new KNNIndexShardProfileResult(in); + } + + @Override + protected KNNProfileResponse newResponse( + KNNProfileRequest request, + int totalShards, + int successfulShards, + int failedShards, + List profileResults, + List shardFailures, + ClusterState clusterState + ) { + return new KNNProfileResponse(profileResults, totalShards, successfulShards, failedShards, shardFailures); + } + + @Override + protected KNNProfileRequest readRequestFrom(StreamInput in) throws IOException { + return new KNNProfileRequest(in); + } + + @Override + protected KNNIndexShardProfileResult shardOperation(KNNProfileRequest request, ShardRouting shardRouting) throws IOException { + KNNIndexShard knnIndexShard = new KNNIndexShard( + indicesService.indexServiceSafe(shardRouting.shardId().getIndex()).getShard(shardRouting.shardId().id()) + ); + + return knnIndexShard.profile(request.getField()); + } + + 
@Override + protected ShardsIterator shards(ClusterState state, KNNProfileRequest request, String[] concreteIndices) { + return state.routingTable().allShards(concreteIndices); + } + + @Override + protected ClusterBlockException checkGlobalBlock(ClusterState state, KNNProfileRequest request) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ); + } + + @Override + protected ClusterBlockException checkRequestBlock(ClusterState state, KNNProfileRequest request, String[] concreteIndices) { + return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java b/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java new file mode 100644 index 0000000000..d5a23c0d3f --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.common.KNNConstants; + +import java.io.IOException; + +/** + * Reader class for segment profiler states + */ +@Log4j2 +public final class KNN990ProfileStateReader { + + /** + * Reads a segment profiler state for a given field + * + * @param readConfig config for reading the profiler state + * @return SegmentProfilerState object + * @throws IOException if there's an error reading the state + */ + public static SegmentProfilerState read(SegmentProfileStateReadConfig readConfig) throws IOException { + SegmentReadState segmentReadState = readConfig.getSegmentReadState(); + 
String field = readConfig.getField(); + String stateFileName = getProfileStateFileName(segmentReadState); + int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); + + try (IndexInput input = segmentReadState.directory.openInput(stateFileName, IOContext.DEFAULT)) { + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section, break when correct field is found + for (int i = 0; i < numFields; i++) { + int tempFieldNumber = input.readInt(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldNumber == fieldNumber) { + position = tempPosition; + length = tempLength; + break; + } + } + + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", field)); + } + + byte[] stateBytes = readStateBytes(input, position, length); + return SegmentProfilerState.fromBytes(stateBytes); + } + } + + @VisibleForTesting + static int getNumFields(IndexInput input) throws IOException { + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.seek(indexStartPosition); + return input.readInt(); + } + + @VisibleForTesting + static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + return stateBytes; + } + + @VisibleForTesting + static String getProfileStateFileName(SegmentReadState state) { + return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.SEGMENT_PROFILE_STATE_FILE_SUFFIX); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java 
b/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java new file mode 100644 index 0000000000..c4e639d259 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.KNNProfileAction; +import org.opensearch.knn.plugin.transport.KNNProfileRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import java.util.List; + +/** + * RestHandler for the k-NN profile API. API provides the ability for a user to request the profile of a vector field + * in a specific index. 
+ */ +public class RestKNNProfileHandler extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestKNNProfileHandler.class); + private static final String URL_PATH = "/profile/{index}/{field}"; + public static String NAME = "knn_profile_action"; + private IndexNameExpressionResolver indexNameExpressionResolver; + private ClusterService clusterService; + + public RestKNNProfileHandler() {} + + @Override + public String getName() { + return NAME; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, KNNPlugin.KNN_BASE_URI + URL_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + KNNProfileRequest knnProfileRequest = createKNNProfileRequest(request); + return channel -> client.execute(KNNProfileAction.INSTANCE, knnProfileRequest, new RestToXContentListener<>(channel)); + } + + private KNNProfileRequest createKNNProfileRequest(RestRequest request) { + String indexName = request.param("index"); + String fieldName = request.param("field"); + + return new KNNProfileRequest(indexName, fieldName); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java new file mode 100644 index 0000000000..9ba8a750d3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; + +/** + * Segment profiler collector for KNN plugin which is used to + * collect profiling information for a segment. 
+ */ +@Setter +@Getter +public class SegmentProfileKNNCollector implements KnnCollector { + + private SegmentProfilerState segmentProfilerState; + + private final String NATIVE_ENGINE_SEARCH_ERROR_MESSAGE = "Search functionality using codec is not supported with Native Engine Reader"; + + @Override + public boolean earlyTerminated() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public void incVisitedCount(int i) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitedCount() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitLimit() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public int k() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public boolean collect(int i, float v) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public float minCompetitiveSimilarity() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public TopDocs topDocs() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java new file mode 100644 index 0000000000..a13c7a1479 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.index.SegmentReadState; + +@Getter +@AllArgsConstructor +public class SegmentProfileStateReadConfig { + private 
SegmentReadState segmentReadState; + private String field; +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java index 5fe88018e3..c520ecb82d 100644 --- a/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java @@ -5,13 +5,19 @@ package org.opensearch.knn.profiler; +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.commons.math3.stat.descriptive.SummaryStatistics; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; @@ -23,19 +29,17 @@ * This class calculates statistical measurements for each dimension of the vectors in a segment. 
*/ @Log4j2 -public class SegmentProfilerState { +@AllArgsConstructor +public class SegmentProfilerState implements Serializable { - // Stores statistical summaries for each dimension of the vectors @Getter private final List statistics; - /** - * Constructor to initialize the SegmentProfilerState - * @param statistics - */ - public SegmentProfilerState(final List statistics) { - this.statistics = statistics; - } + @Getter + private final int dimension; + + @Getter + private final String segmentId; /** * Profiles vectors in a segment by analyzing their statistical values @@ -43,12 +47,13 @@ public SegmentProfilerState(final List statistics) { * @return SegmentProfilerState * @throws IOException */ - public static SegmentProfilerState profileVectors(final Supplier> knnVectorValuesSupplier) throws IOException { + public static SegmentProfilerState profileVectors(final Supplier> knnVectorValuesSupplier, final String segmentId) + throws IOException { KNNVectorValues vectorValues = knnVectorValuesSupplier.get(); if (vectorValues == null) { log.info("No vector values available"); - return new SegmentProfilerState(new ArrayList<>()); + return new SegmentProfilerState(new ArrayList<>(), 0, segmentId); } // Initialize vector values @@ -58,11 +63,11 @@ public static SegmentProfilerState profileVectors(final Supplier sta * @param statistics * @param dimension */ - private static void logDimensionStatistics(final List statistics, final int dimension) { + private static void logDimensionStatistics(final List statistics, final int dimension, final String segmentId) { for (int i = 0; i < dimension; i++) { SummaryStatistics stats = statistics.get(i); log.info( - "Dimension {} stats: mean={}, std={}, min={}, max={}", + "Segment {} - Dimension {} stats: mean={}, std={}, min={}, max={}", + segmentId, i, stats.getMean(), stats.getStandardDeviation(), @@ -137,4 +144,31 @@ private static void logDimensionStatistics(final List statist ); } } + + /** + * Serializes a SegmentProfilerState 
to a byte array + * @return the Java-serialized bytes of this state + */ + public byte[] toByteArray() { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)) { + + oos.writeObject(this); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize SegmentProfilerStates", e); + } + } + + /** + * Deserializes a SegmentProfilerState from a byte array + * @param bytes Java-serialized SegmentProfilerState bytes, e.g. the output of {@link #toByteArray()} + * @return the deserialized SegmentProfilerState + */ + public static SegmentProfilerState fromBytes(byte[] bytes) { + try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ObjectInputStream ois = new ObjectInputStream(bais)) { + return (SegmentProfilerState) ois.readObject(); + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException("Failed to deserialize SegmentProfilerState", e); + } + } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionIT.java b/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionIT.java new file mode 100644 index 0000000000..4dee801717 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionIT.java @@ -0,0 +1,282 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.plugin.transport; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.TreeMap; +import java.util.List; +import java.util.Random; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.hamcrest.Matchers.closeTo; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; + +@AllArgsConstructor +public class KNNProfileTransportActionIT extends KNNRestTestCase { + private String description; + private SpaceType spaceType; + private static final String FIELD_NAME = "test-field"; + private static final String INDEX_NAME = "test-index"; + + @Before + public void setup() throws Exception {} + + @ParametersFactory(argumentFormatting = "description:%1$s; spaceType:%2$s") + public static 
Collection parameters() { + return Arrays.asList( + $$( + $("SpaceType L2", SpaceType.L2), + $("SpaceType INNER_PRODUCT", SpaceType.INNER_PRODUCT), + $("SpaceType COSINESIMIL", SpaceType.COSINESIMIL) + ) + ); + } + + @SneakyThrows + public void testProfileEndToEnd() { + final int dimension = 128; + final int numDocs = 1000; + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + final Settings knnIndexSettings = Settings.builder() + .put("number_of_shards", 2) + .put("number_of_replicas", 0) + .put(KNN_INDEX, true) + .build(); + + createKnnIndex(INDEX_NAME, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + indexTestData(INDEX_NAME, FIELD_NAME, dimension, numDocs); + + Response profileResponse = executeProfileRequest(INDEX_NAME, FIELD_NAME); + validateProfileResponse(profileResponse, dimension); + } + + @SneakyThrows + public void testProfileWithMultipleSegments() { + final int dimension = 4; + final int numDocs = 100; + + createKnnIndex(INDEX_NAME, getIndexSettings(), createKnnIndexMapping(FIELD_NAME, dimension)); + + for (int i = 0; i < numDocs; i++) { + float[] vector = new float[dimension]; + Arrays.fill(vector, (float) i); + addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector); + + if (i % 10 == 0) { + refreshIndex(INDEX_NAME); + } + } + + refreshIndex(INDEX_NAME); + + Response profileResponse = executeProfileRequest(INDEX_NAME, FIELD_NAME); + Map responseMap = parseResponse(profileResponse); + + Map shardProfiles = (Map) 
responseMap.get("shard_profiles"); + for (Object shardProfile : shardProfiles.values()) { + List> segments = (List>) ((Map) shardProfile).get("segments"); + assertTrue("Should have multiple segments", segments.size() > 1); + } + } + + private void indexTestData(String indexName, String fieldName, int dimension, int numDocs) throws Exception { + for (int i = 0; i < numDocs; i++) { + float[] vector = new float[dimension]; + Arrays.fill(vector, (float) i); + addKnnDoc(indexName, Integer.toString(i), fieldName, vector); + } + + refreshAllNonSystemIndices(); + assertEquals(numDocs, getDocCount(indexName)); + } + + @SneakyThrows + protected Response executeProfileRequest(String indexName, String fieldName) throws IOException { + Request request = new Request("GET", "/_plugins/_knn/profile/" + indexName + "/" + fieldName); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + + private void validateProfileResponse(Response response, int dimension) throws IOException, ParseException { + Map responseMap = parseResponse(response); + + assertNotNull(responseMap.get("total_shards")); + assertNotNull(responseMap.get("successful_shards")); + assertNotNull(responseMap.get("failed_shards")); + + Map shardProfiles = (Map) responseMap.get("shard_profiles"); + assertFalse(shardProfiles.isEmpty()); + + Map clusterAgg = (Map) responseMap.get("cluster_aggregation"); + assertEquals(dimension, clusterAgg.get("dimension")); + + List> dimensions = (List>) clusterAgg.get("dimensions"); + assertEquals(dimension, dimensions.size()); + + for (Map dimStats : dimensions) { + validateDimensionStatistics(dimStats); + } + } + + private void validateDimensionStatistics(Map dimStats) { + assertNotNull("Count should be present", dimStats.get("count")); + assertNotNull("Min should be present", dimStats.get("min")); + assertNotNull("Max should be present", dimStats.get("max")); + 
assertNotNull("Mean should be present", dimStats.get("mean")); + assertNotNull("Standard deviation should be present", dimStats.get("std_deviation")); + assertNotNull("Variance should be present", dimStats.get("variance")); + } + + private Map parseResponse(Response response) throws IOException, ParseException { + return createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); + } + + private Settings getIndexSettings() { + return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put(KNN_INDEX, true).build(); + } + + @SneakyThrows + public void testProfileWithNormalizedDistribution() { + final int dimension = 4; + final int numDocs = 1000; + final double expectedMean = 0.5; + final double allowedDeviation = 0.05; + + createKnnIndex(INDEX_NAME, getIndexSettings(), createKnnIndexMapping(FIELD_NAME, dimension)); + + Random random = new Random(42); + for (int i = 0; i < numDocs; i++) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = random.nextFloat(); + } + addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector); + } + + refreshIndex(INDEX_NAME); + forceMerge(INDEX_NAME); + + Response profileResponse = executeProfileRequest(INDEX_NAME, FIELD_NAME); + Map responseMap = parseResponse(profileResponse); + + Map clusterAgg = (Map) responseMap.get("cluster_aggregation"); + List> dimensions = (List>) clusterAgg.get("dimensions"); + + for (Map dimStats : dimensions) { + double mean = (Double) dimStats.get("mean"); + double min = (Double) dimStats.get("min"); + double max = (Double) dimStats.get("max"); + double stdDev = (Double) dimStats.get("std_deviation"); + + assertTrue(Math.abs(mean - expectedMean) < allowedDeviation); + + assertTrue(min >= 0.0 && min < 1.0); + assertTrue(max > 0.0 && max <= 1.0); + + assertTrue(Math.abs(stdDev - 0.289) < allowedDeviation); + } + } + + @SneakyThrows + public void testExplicitSegmentAggregation() { + final int dimension = 4; 
+ + createKnnIndex(INDEX_NAME, getIndexSettings(), createKnnIndexMapping(FIELD_NAME, dimension)); + + float[] firstSegmentValue = new float[] { 0.1f, 0.2f, 0.3f, 0.4f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, firstSegmentValue); + refreshIndex(INDEX_NAME); + + float[] secondSegmentValue = new float[] { 0.5f, 0.6f, 0.7f, 0.8f }; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, secondSegmentValue); + refreshIndex(INDEX_NAME); + + Response profileResponse = executeProfileRequest(INDEX_NAME, FIELD_NAME); + Map responseMap = parseResponse(profileResponse); + + Map shardProfiles = (Map) responseMap.get("shard_profiles"); + Map shardProfile = (Map) shardProfiles.get("0"); + List> segments = (List>) shardProfile.get("segments"); + + assertEquals(2, segments.size()); + + Map aggregated = (Map) shardProfile.get("aggregated"); + List> dimensions = (List>) aggregated.get("dimensions"); + + for (int i = 0; i < dimension; i++) { + Map dimStats = dimensions.get(i); + + double expectedMin = Math.min(firstSegmentValue[i], secondSegmentValue[i]); + double expectedMax = Math.max(firstSegmentValue[i], secondSegmentValue[i]); + double expectedMean = (firstSegmentValue[i] + secondSegmentValue[i]) / 2.0; + + assertEquals(2L, dimStats.get("count")); + assertThat((Double) dimStats.get("min"), closeTo(expectedMin, 0.0001)); + assertThat((Double) dimStats.get("max"), closeTo(expectedMax, 0.0001)); + assertThat((Double) dimStats.get("mean"), closeTo(expectedMean, 0.0001)); + } + } + + private void forceMerge(String indexName) throws IOException { + Request request = new Request("POST", "/" + indexName + "/_forcemerge"); + request.addParameter("max_num_segments", "1"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionTests.java 
b/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionTests.java new file mode 100644 index 0000000000..1afa2a8b5f --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/KNNProfileTransportActionTests.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.junit.Before; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.profiler.SegmentProfilerState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.closeTo; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +public class KNNProfileTransportActionTests extends KNNTestCase { + + private static final int DIMENSION = 4; + private static final double DELTA = 0.0001; + + private List shardProfileResults; + private KNNProfileResponse response; + + @Before + public void setup() { + shardProfileResults = new ArrayList<>(); + + for (int shardId = 0; shardId < 2; shardId++) { + List segmentStates = new ArrayList<>(); + + for (int segId = 0; segId < 2; segId++) { + List stats = new ArrayList<>(); + + for (int dim = 0; dim < DIMENSION; dim++) { + SummaryStatistics dimStats = new SummaryStatistics(); + dimStats.addValue(1.0 + shardId + segId); + dimStats.addValue(2.0 + shardId + segId); + dimStats.addValue(3.0 + shardId + segId); + stats.add(dimStats); + } + + segmentStates.add(new SegmentProfilerState(stats, DIMENSION, "_" + segId)); + } + + shardProfileResults.add(new KNNIndexShardProfileResult(segmentStates, String.valueOf(shardId))); + } + + response = new 
KNNProfileResponse(shardProfileResults, 2, 2, 0, new ArrayList<>()); + } + + public void testShardAggregation() throws IOException { + XContentBuilder builder = jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map responseMap = createParser(builder).map(); + + Map shardProfiles = (Map) responseMap.get("shard_profiles"); + + for (int shardId = 0; shardId < 2; shardId++) { + Map shardProfile = (Map) shardProfiles.get(String.valueOf(shardId)); + Map aggregated = (Map) shardProfile.get("aggregated"); + List> dimensions = (List>) aggregated.get("dimensions"); + + for (int dim = 0; dim < DIMENSION; dim++) { + Map dimStats = dimensions.get(dim); + + double expectedMin = 1.0 + shardId; + double expectedMax = 3.0 + shardId + 1; + double expectedMean = (expectedMin + expectedMax) / 2.0; + + assertEquals(6, dimStats.get("count")); + assertThat((Double) dimStats.get("min"), closeTo(expectedMin, DELTA)); + assertThat((Double) dimStats.get("max"), closeTo(expectedMax, DELTA)); + assertThat((Double) dimStats.get("mean"), closeTo(expectedMean, DELTA)); + } + } + } + + public void testClusterAggregation() throws IOException { + XContentBuilder builder = jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map responseMap = createParser(builder).map(); + + Map clusterAgg = (Map) responseMap.get("cluster_aggregation"); + List> dimensions = (List>) clusterAgg.get("dimensions"); + + for (int dim = 0; dim < DIMENSION; dim++) { + Map dimStats = dimensions.get(dim); + assertEquals(12, dimStats.get("count")); + assertThat((Double) dimStats.get("min"), closeTo(1.0, DELTA)); + assertThat((Double) dimStats.get("max"), closeTo(5.0, DELTA)); + assertThat((Double) dimStats.get("mean"), closeTo(3.0, DELTA)); + } + } + + public void testEmptyResponse() throws IOException { + KNNProfileResponse emptyResponse = new KNNProfileResponse(new ArrayList<>(), 0, 0, 0, new ArrayList<>()); + + XContentBuilder builder = jsonBuilder(); + 
emptyResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map responseMap = createParser(builder).map(); + + assertEquals(Integer.valueOf(0), responseMap.get("total_shards")); + assertEquals(Integer.valueOf(0), responseMap.get("successful_shards")); + assertEquals(Integer.valueOf(0), responseMap.get("failed_shards")); + assertTrue(((Map) responseMap.get("shard_profiles")).isEmpty()); + } + + public void testFailedShards() throws IOException { + List failures = new ArrayList<>(); + failures.add(new DefaultShardOperationFailedException("test_index", 0, new RuntimeException("Test failure"))); + + KNNProfileResponse responseWithFailures = new KNNProfileResponse(shardProfileResults, 2, 1, 1, failures); + + XContentBuilder builder = jsonBuilder(); + responseWithFailures.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map responseMap = createParser(builder).map(); + + assertEquals(Integer.valueOf(2), responseMap.get("total_shards")); + assertEquals(Integer.valueOf(1), responseMap.get("successful_shards")); + assertEquals(Integer.valueOf(1), responseMap.get("failed_shards")); + + List> failuresList = (List>) responseMap.get("failures"); + assertEquals(1, failuresList.size()); + assertEquals("test_index", failuresList.get(0).get("index")); + } +} diff --git a/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java b/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java index bed64f8701..e29023e0f2 100644 --- a/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java +++ b/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java @@ -23,6 +23,7 @@ public class SegmentProfilerStateTests extends OpenSearchTestCase { private KNNVectorValues mockVectorValues; private Supplier> mockSupplier; + private static final String TEST_SEGMENT_ID = "test_segment"; @Before public void setUp() throws Exception { @@ -34,36 +35,47 @@ public void setUp() throws Exception { public void testConstructor() { List statistics = 
new ArrayList<>(); statistics.add(new SummaryStatistics()); + int dimension = 1; - SegmentProfilerState state = new SegmentProfilerState(statistics); + SegmentProfilerState state = new SegmentProfilerState(statistics, dimension, TEST_SEGMENT_ID); assertEquals(statistics, state.getStatistics()); + assertEquals(dimension, state.getDimension()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); } public void testProfileVectorsWithNullVectorValues() throws IOException { Supplier> nullSupplier = () -> null; - SegmentProfilerState state = SegmentProfilerState.profileVectors(nullSupplier); + SegmentProfilerState state = SegmentProfilerState.profileVectors(nullSupplier, TEST_SEGMENT_ID); assertTrue(state.getStatistics().isEmpty()); + assertEquals(0, state.getDimension()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); } public void testProfileVectorsWithNoDocuments() throws IOException { when(mockVectorValues.docId()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(mockVectorValues.dimension()).thenReturn(3); - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier, TEST_SEGMENT_ID); assertTrue(state.getStatistics().isEmpty()); + assertEquals(3, state.getDimension()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); } public void testProfileVectorsWithSingleFloatVector() throws IOException { float[] vector = new float[] { 1.0f, 2.0f, 3.0f }; + int dimension = 3; when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(3); + when(mockVectorValues.dimension()).thenReturn(dimension); when(mockVectorValues.getVector()).thenReturn(vector); when(mockVectorValues.nextDoc()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier, TEST_SEGMENT_ID); - assertEquals(3, 
state.getStatistics().size()); + assertEquals(dimension, state.getDimension()); + assertEquals(dimension, state.getStatistics().size()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); assertEquals(1.0, state.getStatistics().get(0).getMean(), 0.001); assertEquals(2.0, state.getStatistics().get(1).getMean(), 0.001); assertEquals(3.0, state.getStatistics().get(2).getMean(), 0.001); @@ -71,86 +83,55 @@ public void testProfileVectorsWithSingleFloatVector() throws IOException { public void testProfileVectorsWithSingleByteVector() throws IOException { byte[] vector = new byte[] { 1, 2, 3 }; + int dimension = 3; when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(3); + when(mockVectorValues.dimension()).thenReturn(dimension); when(mockVectorValues.getVector()).thenReturn(vector); when(mockVectorValues.nextDoc()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier, TEST_SEGMENT_ID); - assertEquals(3, state.getStatistics().size()); + assertEquals(dimension, state.getDimension()); + assertEquals(dimension, state.getStatistics().size()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); assertEquals(1.0, state.getStatistics().get(0).getMean(), 0.001); assertEquals(2.0, state.getStatistics().get(1).getMean(), 0.001); assertEquals(3.0, state.getStatistics().get(2).getMean(), 0.001); } - public void testProfileVectorsWithMultipleFloatVectors() throws IOException { - float[] vector1 = new float[] { 1.0f, 2.0f }; - float[] vector2 = new float[] { 3.0f, 4.0f }; - - when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(2); - when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2); - when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(DocIdSetIterator.NO_MORE_DOCS); - - SegmentProfilerState state = 
SegmentProfilerState.profileVectors(mockSupplier); - - assertEquals(2, state.getStatistics().size()); - assertEquals(2.0, state.getStatistics().get(0).getMean(), 0.001); - assertEquals(3.0, state.getStatistics().get(1).getMean(), 0.001); - } - - public void testProfileVectorsWithMultipleByteVectors() throws IOException { - byte[] vector1 = new byte[] { 1, 2 }; - byte[] vector2 = new byte[] { 3, 4 }; - - when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(2); - when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2); - when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(DocIdSetIterator.NO_MORE_DOCS); - - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); - - assertEquals(2, state.getStatistics().size()); - assertEquals(2.0, state.getStatistics().get(0).getMean(), 0.001); - assertEquals(3.0, state.getStatistics().get(1).getMean(), 0.001); + public void testSerializationDeserialization() { + List statistics = new ArrayList<>(); + SummaryStatistics stats = new SummaryStatistics(); + stats.addValue(1.0); + stats.addValue(2.0); + statistics.add(stats); + + SegmentProfilerState originalState = new SegmentProfilerState(statistics, 1, TEST_SEGMENT_ID); + byte[] serialized = originalState.toByteArray(); + SegmentProfilerState deserializedState = SegmentProfilerState.fromBytes(serialized); + + assertEquals(originalState.getDimension(), deserializedState.getDimension()); + assertEquals(originalState.getSegmentId(), deserializedState.getSegmentId()); + assertEquals(originalState.getStatistics().size(), deserializedState.getStatistics().size()); + assertEquals(originalState.getStatistics().get(0).getMean(), deserializedState.getStatistics().get(0).getMean(), 0.001); } public void testProfileVectorsStatisticalValues() throws IOException { float[] vector1 = new float[] { 1.0f, 2.0f }; float[] vector2 = new float[] { 3.0f, 4.0f }; float[] vector3 = new float[] { 5.0f, 6.0f }; + int 
dimension = 2; when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(2); + when(mockVectorValues.dimension()).thenReturn(dimension); when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2).thenReturn(vector3); when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(2).thenReturn(DocIdSetIterator.NO_MORE_DOCS); - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); - - assertEquals(3.0, state.getStatistics().get(0).getMean(), 0.001); - assertEquals(2.0, state.getStatistics().get(0).getStandardDeviation(), 0.001); - assertEquals(1.0, state.getStatistics().get(0).getMin(), 0.001); - assertEquals(5.0, state.getStatistics().get(0).getMax(), 0.001); - - assertEquals(4.0, state.getStatistics().get(1).getMean(), 0.001); - assertEquals(2.0, state.getStatistics().get(1).getStandardDeviation(), 0.001); - assertEquals(2.0, state.getStatistics().get(1).getMin(), 0.001); - assertEquals(6.0, state.getStatistics().get(1).getMax(), 0.001); - } - - public void testProfileVectorsWithByteStatisticalValues() throws IOException { - byte[] vector1 = new byte[] { 1, 2 }; - byte[] vector2 = new byte[] { 3, 4 }; - byte[] vector3 = new byte[] { 5, 6 }; - - when(mockVectorValues.docId()).thenReturn(0); - when(mockVectorValues.dimension()).thenReturn(2); - when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2).thenReturn(vector3); - when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(2).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier, TEST_SEGMENT_ID); - SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + assertEquals(dimension, state.getDimension()); + assertEquals(TEST_SEGMENT_ID, state.getSegmentId()); assertEquals(3.0, state.getStatistics().get(0).getMean(), 0.001); assertEquals(2.0, state.getStatistics().get(0).getStandardDeviation(), 0.001);