Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x)
### Features
* [Vector Profiler] Adding basic generic vector profiler implementation and tests. [#2624](https://github.com/opensearch-project/k-NN/pull/2624)

### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ dependencies {
testFixturesImplementation group: 'net.minidev', name: 'json-smart', version: "${versions.json_smart}"
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
implementation 'org.apache.commons:commons-math3:3.6.1'
api "net.java.dev.jna:jna:${versions.jna}"
api "net.java.dev.jna:jna-platform:${versions.jna}"
// OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class KNNConstants {
public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";
public static final String SEGMENT_PROFILE_STATE_FILE_SUFFIX = "segpfstate";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
46 changes: 46 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.math3.stat.descriptive.AggregateSummaryStatistics;
import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -28,6 +31,8 @@
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.profiler.SegmentProfilerState;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -82,6 +87,47 @@ public String getIndexName() {
return indexShard.shardId().getIndexName();
}

/**
 * Profiles the vector statistics of every segment in this shard and logs the
 * per-dimension aggregate for the whole shard.
 *
 * NOTE(review): the searcher is acquired under the "knn-warmup" source label
 * (copied from warmup) — confirm whether a dedicated "knn-profile" label should
 * be used once this moves to its own transport action.
 */
public void profile() {
    try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) {

        // TODO: Must specify field name as API input - leave it as a constant for now
        final String fieldName = "my_vector_field";

        // For each leaf (segment), collect that segment's profile state
        List<SegmentProfilerState> segmentLevelProfilerStates = new ArrayList<>();
        searcher.getIndexReader().leaves().forEach(leaf -> {
            try {
                segmentLevelProfilerStates.add(SegmentLevelQuantizationUtil.getSegmentProfileState(leaf.reader(), fieldName));
            } catch (IOException e) {
                // TODO: Better exception handling
                throw new RuntimeException(e);
            }
        });

        // Guard: an empty shard (no segments) has nothing to aggregate, and the
        // dimension lookup below would otherwise throw on the empty list.
        if (segmentLevelProfilerStates.isEmpty()) {
            log.info("No segment profile states found for shard of index [{}]; skipping profile aggregation", getIndexName());
            return;
        }

        // Aggregate profile per dimension for the shard.
        // TODO: See if there's a better way to get the dimension other than the first element
        // (get(0) instead of getFirst() keeps this compatible with pre-21 JDKs)
        List<StatisticalSummaryValues> shardVectorProfile = new ArrayList<>();
        for (int i = 0; i < segmentLevelProfilerStates.get(0).getDimension(); i++) {
            final int dimensionId = i;
            // Transpose: pick dimension i from every segment's statistics list
            List<SummaryStatistics> transposed = segmentLevelProfilerStates.stream()
                .map(state -> state.getStatistics().get(dimensionId))
                .toList();
            shardVectorProfile.add(AggregateSummaryStatistics.aggregate(transposed));
        }

        // TODO: Return this as an API response instead of logging
        for (StatisticalSummaryValues statisticalSummaryValues : shardVectorProfile) {
            // Use the toString for now
            log.info(statisticalSummaryValues.toString());
        }

        // TODO: Write unit tests to ensure that the segment statistic aggregation is correct.
    }
}

/**
* Load all of the k-NN segments for this shard into the cache.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.profiler.SegmentProfileStateReadConfig;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
Expand Down Expand Up @@ -94,6 +96,45 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
}
}

/**
 * Reads the serialized {@link SegmentProfilerState} for a single field out of the
 * segment's quantization state file (profile states are co-located in that file —
 * see the matching writeState overload in the writer).
 *
 * TODO: Refactor to separate non-quantization stuff
 *
 * @param readConfig carries the {@link SegmentReadState} and the target field name
 * @return the deserialized profiler state for the requested field
 * @throws IOException if the state file cannot be opened or read
 * @throws IllegalArgumentException if the field has no stored profile state
 */
public static SegmentProfilerState read(SegmentProfileStateReadConfig readConfig) throws IOException {
SegmentReadState segmentReadState = readConfig.getSegmentReadState();
String field = readConfig.getField();
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.DEFAULT)) {

// Validate file integrity (footer checksum) before trusting any contents
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

// Sentinels meaning "field not found" until the index scan locates it
long position = -1;
int length = 0;

// Read each field's metadata from the index section, break when correct field is found
// (each entry is: fieldNumber as int, length as int, position as vlong — must stay
// in sync with the writer's index layout)
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}

if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}

// Seek to the recorded offset and slice out this field's serialized state
byte[] stateBytes = readStateBytes(input, position, length);

return SegmentProfilerState.fromBytes(stateBytes);
}
}

@VisibleForTesting
static int getNumFields(IndexInput input) throws IOException {
long footerStart = input.length() - CodecUtil.footerLength();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -85,6 +86,20 @@ public void writeState(int fieldNumber, QuantizationState quantizationState) thr
fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position));
}

/**
 * Writes a segment profile state as bytes at the current output position and
 * records its (field number, bytes, position) entry so the index section written
 * later can locate it.
 *
 * NOTE(review): profile entries are appended to the same fieldQuantizationStates
 * index as quantization entries — confirm readers can distinguish the two kinds.
 *
 * @param fieldNumber field number the state belongs to
 * @param segmentProfilerState segment profiler state to serialize
 * @throws IOException could be thrown while writing
 */
public void writeState(int fieldNumber, SegmentProfilerState segmentProfilerState) throws IOException {
byte[] stateBytes = segmentProfilerState.toByteArray();
long position = output.getFilePointer();
output.writeBytes(stateBytes, stateBytes.length);
fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position));
}

/**
* Writes index footer and other index information for parsing later
* @throws IOException could be thrown while writing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory;
import org.opensearch.knn.profiler.SegmentProfileKNNCollector;
import org.opensearch.knn.profiler.SegmentProfileStateReadConfig;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
Expand Down Expand Up @@ -163,6 +166,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
return;
}

if (knnCollector instanceof SegmentProfileKNNCollector) {
SegmentProfilerState segmentProfileState = KNN990QuantizationStateReader.read(new SegmentProfileStateReadConfig(segmentReadState, field));
((SegmentProfileKNNCollector) knnCollector).setSegmentProfilerState(segmentProfileState);
return;
}

if (trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, true)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

Expand Down Expand Up @@ -107,6 +108,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
profile(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
Expand Down Expand Up @@ -150,6 +152,10 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);

//Write the segment profile state to the directory
profile(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);

// should skip graph building only for non quantization use case and if threshold is met
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
Expand Down Expand Up @@ -188,6 +194,7 @@ public void finish() throws IOException {
if (quantizationStateWriter != null) {
quantizationStateWriter.writeFooter();
}

flatVectorsWriter.finish();
}

Expand Down Expand Up @@ -241,6 +248,23 @@ private QuantizationState train(
return quantizationState;
}

/**
 * Profiles the vectors of a field and persists the resulting state alongside the
 * quantization states for this segment.
 *
 * Fixes a dead-variable bug: the original declared a local that was never
 * assigned and always returned {@code null}, discarding the computed state.
 *
 * @param fieldInfo field being profiled
 * @param knnVectorValuesSupplier supplies the vectors to profile
 * @param totalLiveDocs number of live documents; nothing is profiled when none
 * @return the written {@link SegmentProfilerState}, or {@code null} if the field
 *         has no live documents
 * @throws IOException if writing the state fails
 */
private SegmentProfilerState profile(
    final FieldInfo fieldInfo,
    final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
    final int totalLiveDocs
) throws IOException {
    // Nothing to profile for an empty field
    if (totalLiveDocs <= 0) {
        return null;
    }

    // TODO: Refactor writer initialization out of the quantization path
    initQuantizationStateWriterIfNecessary();
    final SegmentProfilerState profileResultForSegment = SegmentProfilerState.profileVectors(knnVectorValuesSupplier);
    quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), profileResultForSegment);
    return profileResultForSegment;
}

/**
* The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the
* vectorsValues object which you plan to use later
Expand All @@ -263,6 +287,7 @@ private void initQuantizationStateWriterIfNecessary() throws IOException {
}
}


private boolean shouldSkipBuildingVectorDataStructure(final long docCount) {
if (approximateThreshold < 0) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import org.apache.lucene.index.LeafReader;
import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.profiler.SegmentProfileKNNCollector;
import org.opensearch.knn.profiler.SegmentProfilerState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
Expand Down Expand Up @@ -57,4 +59,24 @@ static QuantizationState getQuantizationState(final LeafReader leafReader, Strin
}
return tempCollector.getQuantizationState();
}

/**
 * A utility function to get {@link SegmentProfilerState} for a given segment and field.
 * This needs to be public as we are accessing this on a transport action.
 * TODO: move this out of this Util class and into another one.
 *
 * @param leafReader {@link LeafReader} for the segment
 * @param fieldName  name of the vector field whose profile state is requested
 * @return {@link SegmentProfilerState} stored for the field in this segment
 * @throws IOException exception during reading the {@link SegmentProfilerState}
 * @throws IllegalStateException if the segment holds no profile state for the field
 */
public static SegmentProfilerState getSegmentProfileState(final LeafReader leafReader, String fieldName) throws IOException {
    // The collector is a side channel: the codec reader recognizes it and fills in
    // the profiler state instead of running an actual nearest-neighbor search.
    final SegmentProfileKNNCollector profileCollector = new SegmentProfileKNNCollector();
    leafReader.searchNearestVectors(fieldName, new float[0], profileCollector, null);

    final SegmentProfilerState profilerState = profileCollector.getSegmentProfilerState();
    if (profilerState == null) {
        throw new IllegalStateException(String.format(Locale.ROOT, "No segment state found for field %s", fieldName));
    }
    return profilerState;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ protected EmptyResult shardOperation(KNNWarmupRequest request, ShardRouting shar
KNNIndexShard knnIndexShard = new KNNIndexShard(
indicesService.indexServiceSafe(shardRouting.shardId().getIndex()).getShard(shardRouting.shardId().id())
);

knnIndexShard.warmup();

//TODO: Move this to our own transport action.
knnIndexShard.profile();

//TODO: Move this to another TransportAction and don't use warmup

return EmptyResult.INSTANCE;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.profiler;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;

/**
 * A sentinel {@link KnnCollector} used to fetch a {@link SegmentProfilerState} from
 * the codec reader: the reader detects this collector type, sets the profiler state
 * on it, and returns without searching. Every actual collection method therefore
 * throws, since this collector must never take part in a real vector search.
 */
@Setter
@Getter
public class SegmentProfileKNNCollector implements KnnCollector {

    private SegmentProfilerState segmentProfilerState;

    // Shared constant (was a per-instance field named like a constant)
    private static final String NATIVE_ENGINE_SEARCH_ERROR_MESSAGE =
        "Search functionality using codec is not supported with Native Engine Reader";

    @Override
    public boolean earlyTerminated() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public void incVisitedCount(int i) {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public long visitedCount() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public long visitLimit() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public int k() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public boolean collect(int i, float v) {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public float minCompetitiveSimilarity() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }

    @Override
    public TopDocs topDocs() {
        throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.profiler;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.index.SegmentReadState;

/**
 * Immutable configuration holder for reading a {@link SegmentProfilerState}:
 * the segment to read from and the vector field whose state is wanted.
 */
@Getter
@AllArgsConstructor
public class SegmentProfileStateReadConfig {
    // final: config objects should be immutable after construction
    private final SegmentReadState segmentReadState;
    private final String field;
}
Loading
Loading