Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x)
### Features
* [Vector Profiler] Adding basic generic vector profiler implementation and tests. [#2624](https://github.com/opensearch-project/k-NN/pull/2624)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we adding these changes? Weren't they already merged?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I updated the branch to align with the changes already implemented in the feature branch

* [Vector Profiler] Adding serializer and api implementation for segment profiler state. [#2687](https://github.com/opensearch-project/k-NN/pull/2687)

### Enhancements
### Bug Fixes
* [BUGFIX] FIX nested vector query at efficient filter scenarios [#2641](https://github.com/opensearch-project/k-NN/pull/2641)
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ dependencies {
testFixturesImplementation group: 'net.minidev', name: 'json-smart', version: "${versions.json_smart}"
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
implementation 'org.apache.commons:commons-math3:3.6.1'
api "net.java.dev.jna:jna:${versions.jna}"
api "net.java.dev.jna:jna-platform:${versions.jna}"
// OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public class KNNConstants {
public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";
public static final String SEGMENT_PROFILE_STATE_FILE_SUFFIX = "segpfstate";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.SegmentProfilerUtil;
import org.opensearch.knn.plugin.transport.KNNIndexShardProfileResult;
import org.opensearch.knn.profiler.SegmentProfilerState;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -82,6 +85,42 @@ public String getIndexName() {
return indexShard.shardId().getIndexName();
}

/**
 * Collects the per-segment profiler state for the given k-NN field across all
 * segments (leaves) of this shard.
 *
 * @param field name of the k-NN vector field to profile
 * @return shard-level result wrapping each segment's {@link SegmentProfilerState}
 *         together with this shard's id
 * @throws RuntimeException if acquiring the searcher or profiling any segment fails
 */
public KNNIndexShardProfileResult profile(final String field) {
    try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-profile")) {
        log.info("[KNN] Profiling field [{}] in index [{}]", field, getIndexName());

        final List<SegmentProfilerState> segmentLevelProfilerStates = new ArrayList<>();

        // For each leaf, collect the profile
        for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
            try {
                SegmentProfilerState state = SegmentProfilerUtil.getSegmentProfileState(leaf.reader(), field);
                segmentLevelProfilerStates.add(state);
                log.debug("[KNN] Successfully profiled segment for field [{}]", field);
            } catch (IOException e) {
                log.error("[KNN] Failed to profile segment for field [{}]: {}", field, e.getMessage(), e);
                throw new RuntimeException("Failed to profile segment: " + e.getMessage(), e);
            }
        }

        if (segmentLevelProfilerStates.isEmpty()) {
            log.warn("[KNN] No segments found with field [{}] in index [{}]", field, getIndexName());
        } else {
            log.info(
                "[KNN] Successfully profiled [{}] segments for field [{}] in index [{}]",
                segmentLevelProfilerStates.size(),
                field,
                getIndexName()
            );
        }

        return new KNNIndexShardProfileResult(segmentLevelProfilerStates, indexShard.shardId().toString());
    } catch (RuntimeException e) {
        // Already logged and wrapped in the per-segment handler above; rethrow as-is
        // instead of double-wrapping and double-logging the same failure.
        throw e;
    } catch (Exception e) {
        log.error("[KNN] Error profiling field [{}] in index [{}]: {}", field, getIndexName(), e.getMessage(), e);
        throw new RuntimeException("Error during profiling: " + e.getMessage(), e);
    }
}

/**
* Load all of the k-NN segments for this shard into the cache.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -85,6 +86,20 @@ public void writeState(int fieldNumber, QuantizationState quantizationState) thr
fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position));
}

/**
 * Serializes a segment profiler state and appends it to the output, recording the
 * field number and the byte offset at which the state begins so the footer index
 * can locate it later.
 *
 * @param fieldNumber field number the profiler state belongs to
 * @param segmentProfilerState segment profiler state to persist
 * @throws IOException could be thrown while writing
 */
public void writeState(int fieldNumber, SegmentProfilerState segmentProfilerState) throws IOException {
    final byte[] serializedState = segmentProfilerState.toByteArray();
    // Capture the start offset before writing so the footer entry points at this state
    final long startOffset = output.getFilePointer();
    output.writeBytes(serializedState, serializedState.length);
    fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, serializedState, startOffset));
}

/**
* Writes index footer and other index information for parsing later
* @throws IOException could be thrown while writing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory;
import org.opensearch.knn.profiler.KNN990ProfileStateReader;
import org.opensearch.knn.profiler.SegmentProfileKNNCollector;
import org.opensearch.knn.profiler.SegmentProfileStateReadConfig;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
Expand Down Expand Up @@ -163,6 +167,14 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
return;
}

if (knnCollector instanceof SegmentProfileKNNCollector) {
SegmentProfilerState segmentProfileState = KNN990ProfileStateReader.read(
new SegmentProfileStateReadConfig(segmentReadState, field)
);
((SegmentProfileKNNCollector) knnCollector).setSegmentProfilerState(segmentProfileState);
return;
}

if (trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, true)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

Expand All @@ -51,6 +52,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final FlatVectorsWriter flatVectorsWriter;
private KNN990QuantizationStateWriter quantizationStateWriter;
private KNN990QuantizationStateWriter segmentStateWriter;
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
private boolean finished;
private final Integer approximateThreshold;
Expand Down Expand Up @@ -107,6 +109,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
profile(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.debug(
Expand Down Expand Up @@ -150,9 +153,13 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);

// Write the segment profile state to the directory
profile(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);

// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.debug(
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
fieldInfo.name,
totalLiveDocs,
Expand Down Expand Up @@ -188,6 +195,10 @@ public void finish() throws IOException {
if (quantizationStateWriter != null) {
quantizationStateWriter.writeFooter();
}
if (segmentStateWriter != null) {
segmentStateWriter.writeFooter();
}

flatVectorsWriter.finish();
}

Expand All @@ -209,6 +220,9 @@ public void close() throws IOException {
if (quantizationStateWriter != null) {
quantizationStateWriter.closeOutput();
}
if (segmentStateWriter != null) {
segmentStateWriter.closeOutput();
}
IOUtils.close(flatVectorsWriter);
}

Expand Down Expand Up @@ -241,6 +255,23 @@ private QuantizationState train(
return quantizationState;
}

/**
 * Computes and persists the profiler state for a field's vectors during flush/merge.
 * No state is written (and null is returned) when the field has no live documents.
 *
 * @param fieldInfo field being profiled
 * @param knnVectorValuesSupplier supplies the vector values to profile; note the
 *        supplied {@link KNNVectorValues} is consumed by profiling
 * @param totalLiveDocs number of live documents for the field
 * @return the profiler state that was written, or {@code null} if there were no live docs
 * @throws IOException could be thrown while writing the state
 */
private SegmentProfilerState profile(
    final FieldInfo fieldInfo,
    final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
    final int totalLiveDocs
) throws IOException {

    SegmentProfilerState segmentProfilerState = null;
    if (totalLiveDocs > 0) {
        initSegmentStateWriterIfNecessary();
        String segmentId = segmentWriteState.segmentInfo.name;
        // Fix: assign the computed state so callers receive the profiled result
        // instead of an always-null return value.
        segmentProfilerState = SegmentProfilerState.profileVectors(knnVectorValuesSupplier, segmentId);
        segmentStateWriter.writeState(fieldInfo.getFieldNumber(), segmentProfilerState);
    }

    return segmentProfilerState;
}

/**
* The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the
* vectorsValues object which you plan to use later
Expand All @@ -263,6 +294,13 @@ private void initQuantizationStateWriterIfNecessary() throws IOException {
}
}

// Lazily creates the writer for segment profiler states and writes its header once.
// NOTE(review): this reuses KNN990QuantizationStateWriter for profiler states —
// confirm it targets a distinct output file (see SEGMENT_PROFILE_STATE_FILE_SUFFIX)
// rather than colliding with the quantization state file when both writers are active.
private void initSegmentStateWriterIfNecessary() throws IOException {
    if (segmentStateWriter == null) {
        segmentStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);
        segmentStateWriter.writeHeader(segmentWriteState);
    }
}

private boolean shouldSkipBuildingVectorDataStructure(final long docCount) {
if (approximateThreshold < 0) {
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.experimental.UtilityClass;
import org.apache.lucene.index.LeafReader;
import org.opensearch.knn.profiler.SegmentProfileKNNCollector;
import org.opensearch.knn.profiler.SegmentProfilerState;

import java.io.IOException;
import java.util.Locale;

/**
 * Utility class to get segment profiler state for a given field
 */
@UtilityClass
public class SegmentProfilerUtil {

    /**
     * Gets the segment profile state for a given field
     * @param leafReader The leaf reader to query
     * @param fieldName The field name to profile
     * @return The segment profiler state
     * @throws IOException If there's an error reading the segment
     * @throws IllegalStateException if no profiler state exists for the field
     */
    public static SegmentProfilerState getSegmentProfileState(final LeafReader leafReader, String fieldName) throws IOException {
        // A profile collector routed through the nearest-vectors API triggers the
        // profiler-state read path in the vectors reader; the query vector is unused.
        final SegmentProfileKNNCollector collector = new SegmentProfileKNNCollector();
        leafReader.searchNearestVectors(fieldName, new float[0], collector, null);
        final SegmentProfilerState profilerState = collector.getSegmentProfilerState();
        if (profilerState == null) {
            throw new IllegalStateException(String.format(Locale.ROOT, "No segment state found for field %s", fieldName));
        }
        return profilerState;
    }
}
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.profiler.RestKNNProfileHandler;
import org.opensearch.knn.plugin.transport.KNNProfileTransportAction;
import org.opensearch.knn.plugin.transport.KNNProfileAction;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
Expand Down Expand Up @@ -269,6 +272,7 @@ public List<RestHandler> getRestHandlers(
) {

RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler();
RestKNNProfileHandler restKNNProfileHandler = new RestKNNProfileHandler();
RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(
settings,
restController,
Expand All @@ -284,6 +288,7 @@ public List<RestHandler> getRestHandlers(
return ImmutableList.of(
restKNNStatsHandler,
restKNNWarmupHandler,
restKNNProfileHandler,
restGetModelHandler,
restDeleteModelHandler,
restTrainModelHandler,
Expand All @@ -300,6 +305,7 @@ public List<RestHandler> getRestHandlers(
return Arrays.asList(
new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class),
new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class),
new ActionHandler<>(KNNProfileAction.INSTANCE, KNNProfileTransportAction.class),
new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class),
new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class),
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.transport;

import lombok.AllArgsConstructor;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.knn.profiler.SegmentProfilerState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
public class KNNIndexShardProfileResult implements Writeable {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need deserialization logic with this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it may be useful especially when compressing vectors.

List<SegmentProfilerState> segmentProfilerStateList;
String shardId;

/**
* Constructor for reading from StreamInput
*/
public KNNIndexShardProfileResult(StreamInput streamInput) throws IOException {
this.shardId = streamInput.readString();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ever use this constructor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it was used previously in KNNProfileResponse, but when removing the logic and the above constructor I was able to validate the output was still working.

int size = streamInput.readInt();

this.segmentProfilerStateList = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
byte[] stateBytes = streamInput.readByteArray();
segmentProfilerStateList.add(SegmentProfilerState.fromBytes(stateBytes));
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
    // Shard id first, so readers can identify the shard before decoding states
    streamOutput.writeString(shardId);

    // Count prefix, then each SegmentProfilerState as an opaque byte array —
    // mirrors the read order in the StreamInput constructor.
    streamOutput.writeInt(segmentProfilerStateList.size());
    for (SegmentProfilerState segmentState : segmentProfilerStateList) {
        streamOutput.writeByteArray(segmentState.toByteArray());
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionType;
import org.opensearch.core.common.io.stream.Writeable;

/**
 * Transport {@link ActionType} for the k-NN profile operation, which returns
 * {@link KNNProfileResponse} instances from the cluster.
 */
public class KNNProfileAction extends ActionType<KNNProfileResponse> {

    // Singleton used to register and invoke the action
    public static final KNNProfileAction INSTANCE = new KNNProfileAction();
    // "cluster:admin" scope restricts this action to admin-level privileges
    public static final String NAME = "cluster:admin/knn_profile_action";

    /**
     * Constructor
     */
    private KNNProfileAction() {
        super(NAME, KNNProfileResponse::new);
    }

    /**
     * @return reader used to deserialize the response from the transport stream.
     * NOTE(review): the super constructor already receives this reader, so this
    * override appears redundant — confirm it is still required by the targeted
     * OpenSearch version before removing.
     */
    @Override
    public Writeable.Reader<KNNProfileResponse> getResponseReader() {
        return KNNProfileResponse::new;
    }
}
Loading