
Commit f1a1765

merge main; add analyzer impl

Signed-off-by: zhichao-aws <[email protected]>

1 parent b084838 commit f1a1765

File tree: 10 files changed, +440 -6 lines changed

build.gradle (+2)

@@ -259,6 +259,8 @@ dependencies {
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
     testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
     implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
+    implementation group: 'ai.djl', name: 'api', version: '0.28.0'
+    implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
     // ml-common excluded reflection for runtime so we need to add it by ourselves.
     // https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
     // TODO: Remove following three lines of dependencies if ml-common include them in their jar
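The two new DJL artifacts supply the HuggingFaceTokenizer API used by the files below. As a rough illustration of what they enable (not part of this commit; the model id "bert-base-uncased" is only an example), a standalone call looks like:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class DJLDependencyExample {
    public static void main(String[] args) throws Exception {
        // Downloads the tokenizer definition from the Hugging Face hub on first use.
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-uncased");
        Encoding encoding = tokenizer.encode("hello world");
        System.out.println(String.join(" ", encoding.getTokens()));
    }
}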
DJLUtils.java (new file)

@@ -0,0 +1,94 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.Utils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

public class DJLUtils {
    private static Path ML_CACHE_PATH;
    private static final String ML_CACHE_DIR_NAME = "ml_cache";
    private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
    private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";

    public static void buildDJLCachePath(Path opensearchDataFolder) {
        // the logic to build the cache path is consistent with the ml-commons plugin, see
        // https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
        ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
    }

    public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
        return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
            try {
                System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
                System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
                Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

                return action.call();
            } finally {
                Thread.currentThread().setContextClassLoader(contextClassLoader);
            }
        });
    }

    public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
        try {
            return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to initialize Hugging Face tokenizer.", e);
        }
    }

    public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            Map<String, Float> tokenWeights = new HashMap<>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = line.split("\t");
                if (parts.length != 2) {
                    throw new IllegalArgumentException("Invalid line in token weights file: " + line);
                }
                String token = parts[0];
                float weight = Float.parseFloat(parts[1]);
                tokenWeights.put(token, weight);
            }
            return tokenWeights;
        } catch (IOException e) {
            throw new RuntimeException("Failed to parse token weights file.", e);
        }
    }

    public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
        String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;

        InputStream inputStream;
        try {
            inputStream = withDJLContext(() -> Utils.openUrl(url));
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to download file from " + url, e);
        }

        return parseInputStreamToTokenWeights(inputStream);
    }
}
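Taken together, a hypothetical standalone use of these helpers (not part of this commit; the data path is illustrative) would set the cache path, build a tokenizer, then fetch its token weights file, which the parser above expects as one token<TAB>weight pair per line:

import java.nio.file.Path;
import java.util.Map;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class DJLUtilsUsageExample {
    public static void main(String[] args) throws Exception {
        // In the plugin the cache path comes from the node's data folder; this one is illustrative.
        DJLUtils.buildDJLCachePath(Path.of("/tmp/opensearch-data"));

        String tokenizerId = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
        HuggingFaceTokenizer tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
        // query_token_weights.txt is the weights file the tokenizer factory below uses by default.
        Map<String, Float> weights = DJLUtils.fetchTokenWeights(tokenizerId, "query_token_weights.txt");
        System.out.println("Loaded " + weights.size() + " token weights");
    }
}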
HFModelAnalyzer.java (new file)

@@ -0,0 +1,29 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;

import java.util.function.Supplier;

public class HFModelAnalyzer extends Analyzer {
    public static final String NAME = "hf_model_tokenizer";
    Supplier<Tokenizer> tokenizerSupplier;

    public HFModelAnalyzer() {
        this.tokenizerSupplier = HFModelTokenizerFactory::createDefault;
    }

    HFModelAnalyzer(Supplier<Tokenizer> tokenizerSupplier) {
        this.tokenizerSupplier = tokenizerSupplier;
    }

    @Override
    protected TokenStreamComponents createComponents(String fieldName) {
        final Tokenizer src = tokenizerSupplier.get();
        return new TokenStreamComponents(src, src);
    }
}
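A minimal sketch of consuming this analyzer through the standard Lucene TokenStream API (not part of this commit; note the default constructor triggers the lazy download of the default tokenizer on first use):

import java.io.IOException;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

public class HFModelAnalyzerExample {
    public static void main(String[] args) throws IOException {
        try (HFModelAnalyzer analyzer = new HFModelAnalyzer();
             TokenStream stream = analyzer.tokenStream("field", "hello world")) {
            CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class);
            stream.reset();
            while (stream.incrementToken()) {
                System.out.println(termAtt.toString());
            }
            stream.end();
        }
    }
}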
HFModelAnalyzerProvider.java (new file)

@@ -0,0 +1,25 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractIndexAnalyzerProvider;

public class HFModelAnalyzerProvider extends AbstractIndexAnalyzerProvider<HFModelAnalyzer> {
    private final HFModelAnalyzer analyzer;

    public HFModelAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
        super(indexSettings, name, settings);
        HFModelTokenizerFactory tokenizerFactory = new HFModelTokenizerFactory(indexSettings, environment, name, settings);
        analyzer = new HFModelAnalyzer(tokenizerFactory::create);
    }

    @Override
    public HFModelAnalyzer get() {
        return analyzer;
    }
}
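For orientation (the plugin registration itself is not in this diff): assuming the analyzer is registered under HFModelAnalyzer.NAME, the index settings consumed by this provider, and by the tokenizer factory it wraps, might look like the sketch below; tokenizer_id and token_weights_file are the keys HFModelTokenizerFactory reads, while the exact setting path depends on how the plugin wires the provider in.

import org.opensearch.common.settings.Settings;

public class HFModelAnalyzerSettingsExample {
    public static void main(String[] args) {
        // Hypothetical index settings for a custom analyzer named "my_hf_analyzer".
        Settings settings = Settings.builder()
            .put("index.analysis.analyzer.my_hf_analyzer.type", "hf_model_tokenizer")
            .put(
                "index.analysis.analyzer.my_hf_analyzer.tokenizer_id",
                "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
            )
            .put("index.analysis.analyzer.my_hf_analyzer.token_weights_file", "query_token_weights.txt")
            .build();
        System.out.println(settings);
    }
}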
HFModelTokenizer.java (new file)

@@ -0,0 +1,107 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.CharStreams;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class HFModelTokenizer extends Tokenizer {
    public static final String NAME = "hf_model_tokenizer";
    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private final CharTermAttribute termAtt;
    private final PayloadAttribute payloadAtt;
    private final OffsetAttribute offsetAtt;
    private final HuggingFaceTokenizer tokenizer;
    private final Map<String, Float> tokenWeights;

    private Encoding encoding;
    private int tokenIdx = 0;
    private int overflowingIdx = 0;

    public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
        this(huggingFaceTokenizer, null);
    }

    public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
        termAtt = addAttribute(CharTermAttribute.class);
        offsetAtt = addAttribute(OffsetAttribute.class);
        if (Objects.nonNull(weights)) {
            payloadAtt = addAttribute(PayloadAttribute.class);
        } else {
            payloadAtt = null;
        }
        tokenizer = huggingFaceTokenizer;
        tokenWeights = weights;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        tokenIdx = 0;
        overflowingIdx = -1;
        String inputStr = CharStreams.toString(input);
        encoding = tokenizer.encode(inputStr, false, true);
    }

    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
    }

    public static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    public static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    @Override
    public final boolean incrementToken() throws IOException {
        clearAttributes();
        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
                // finished the current segment; advance to the next overflowing segment
                // until overflowingIdx reaches encoding.getOverflowing().length
                tokenIdx = 0;
                overflowingIdx++;
                if (overflowingIdx >= encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = encoding.getOverflowing()[overflowingIdx];
            } else {
                termAtt.append(curEncoding.getTokens()[tokenIdx]);
                offsetAtt.setOffset(
                    curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
                    curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
                );
                if (Objects.nonNull(tokenWeights)) {
                    // for neural sparse queries, write the token weight to the payload attribute
                    payloadAtt.setPayload(
                        new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
                    );
                }
                tokenIdx++;
                return true;
            }
        }

        return false;
    }
}
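Since a token weight travels in the payload as a raw 4-byte float, the two static helpers round-trip exactly. A small sketch (not part of this commit):

public class PayloadEncodingExample {
    public static void main(String[] args) {
        // A token weight is serialized to 4 big-endian bytes and back.
        byte[] bytes = HFModelTokenizer.floatToBytes(2.5f);
        float weight = HFModelTokenizer.bytesToFloat(bytes);
        System.out.println(bytes.length + " bytes, value " + weight); // 4 bytes, value 2.5
    }
}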
HFModelTokenizerFactory.java (new file)

@@ -0,0 +1,65 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Tokenizer;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractTokenizerFactory;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Map;
import java.util.Objects;

public class HFModelTokenizerFactory extends AbstractTokenizerFactory {
    private final HuggingFaceTokenizer tokenizer;
    private final Map<String, Float> tokenWeights;

    /**
     * Lazily and atomically loads the default HF tokenizer and token weights the first time
     * the holder class is accessed (initialization-on-demand holder idiom).
     */
    private static class DefaultTokenizerHolder {
        static final HuggingFaceTokenizer TOKENIZER;
        static final Map<String, Float> TOKEN_WEIGHTS;
        private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
        private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";

        static {
            try {
                TOKENIZER = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
                TOKEN_WEIGHTS = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
            } catch (Exception e) {
                throw new RuntimeException("Failed to initialize default hf_model_tokenizer", e);
            }
        }
    }

    public static Tokenizer createDefault() {
        return new HFModelTokenizer(DefaultTokenizerHolder.TOKENIZER, DefaultTokenizerHolder.TOKEN_WEIGHTS);
    }

    public HFModelTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
        // For a custom tokenizer, the factory is created during IndexModule.newIndexService
        // and can be accessed via indexService.getIndexAnalyzers()
        super(indexSettings, settings, name);
        String tokenizerId = settings.get("tokenizer_id", null);
        Objects.requireNonNull(tokenizerId, "tokenizer_id is required");
        String tokenWeightsFileName = settings.get("token_weights_file", null);
        tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
        if (tokenWeightsFileName != null) {
            tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
        } else {
            tokenWeights = null;
        }
    }

    @Override
    public Tokenizer create() {
        // create() is called for every analyze request
        return new HFModelTokenizer(tokenizer, tokenWeights);
    }
}
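Because the default tokenizer and weights live in a static holder, createDefault() is cheap after first use: every call wraps the same shared HuggingFaceTokenizer in a fresh Lucene Tokenizer. A hypothetical illustration (not part of this commit):

import org.apache.lucene.analysis.Tokenizer;

public class DefaultTokenizerExample {
    public static void main(String[] args) {
        // Both calls reuse the holder's shared HF tokenizer; only the Lucene wrapper is new.
        Tokenizer first = HFModelTokenizerFactory.createDefault();
        Tokenizer second = HFModelTokenizerFactory.createDefault();
        System.out.println(first != second); // true: distinct wrappers, shared tokenizer
    }
}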
