
Commit a598bd1

impl interface
Signed-off-by: zhichao-aws <[email protected]>
1 parent ec6b30d commit a598bd1

File tree: 6 files changed, +296 −1 lines

build.gradle (+2)

@@ -259,6 +259,8 @@ dependencies {
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
     testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
     implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
+    implementation group: 'ai.djl', name: 'api', version: '0.28.0'
+    implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
     // ml-common excluded reflection for runtime so we need to add it by ourselves.
     // https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
     // TODO: Remove following three lines of dependencies if ml-common include them in their jar
src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java (new file, +94)

@@ -0,0 +1,94 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.Utils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

public class DJLUtils {
    private static Path ML_CACHE_PATH;
    private static final String ML_CACHE_DIR_NAME = "ml_cache";
    private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
    private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";

    public static void buildDJLCachePath(Path opensearchDataFolder) {
        // The logic to build the cache path is consistent with the ml-commons plugin. See
        // https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
        ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
    }

    public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
        return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
            try {
                System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
                System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
                Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

                return action.call();
            } finally {
                Thread.currentThread().setContextClassLoader(contextClassLoader);
            }
        });
    }

    public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
        try {
            return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to initialize Hugging Face tokenizer", e);
        }
    }

    public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            Map<String, Float> tokenWeights = new HashMap<>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = line.split("\t");
                if (parts.length != 2) {
                    throw new IllegalArgumentException("Invalid line in token weights file: " + line);
                }
                String token = parts[0];
                float weight = Float.parseFloat(parts[1]);
                tokenWeights.put(token, weight);
            }
            return tokenWeights;
        } catch (IOException e) {
            throw new RuntimeException("Failed to parse token weights file", e);
        }
    }

    public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
        String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;

        InputStream inputStream;
        try {
            inputStream = withDJLContext(() -> Utils.openUrl(url));
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to download file from " + url, e);
        }

        return parseInputStreamToTokenWeights(inputStream);
    }
}
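
As a reference for the file format parseInputStreamToTokenWeights accepts — one tab-separated token/weight pair per line — here is a minimal, self-contained sketch; the tokens and weights are invented for illustration and are not part of this commit:

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import org.opensearch.neuralsearch.analysis.DJLUtils;

public class TokenWeightsFormatExample {
    public static void main(String[] args) {
        // Two tab-separated token/weight pairs, mirroring the layout that
        // fetchTokenWeights downloads from Hugging Face.
        String fileContents = "hello\t0.45\nworld\t1.2\n";
        Map<String, Float> weights = DJLUtils.parseInputStreamToTokenWeights(
            new ByteArrayInputStream(fileContents.getBytes(StandardCharsets.UTF_8))
        );
        System.out.println(weights.get("hello")); // 0.45
        System.out.println(weights.get("world")); // 1.2
    }
}
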
src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java (new file, +107)

@@ -0,0 +1,107 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.CharStreams;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class HFTokenizer extends Tokenizer {
    public static final String NAME = "hf_tokenizer";
    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private final CharTermAttribute termAtt;
    private final PayloadAttribute payloadAtt;
    private final OffsetAttribute offsetAtt;
    private final HuggingFaceTokenizer tokenizer;
    private final Map<String, Float> tokenWeights;

    private Encoding encoding;
    private int tokenIdx = 0;
    private int overflowingIdx = 0;

    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
        this(huggingFaceTokenizer, null);
    }

    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
        termAtt = addAttribute(CharTermAttribute.class);
        offsetAtt = addAttribute(OffsetAttribute.class);
        if (Objects.nonNull(weights)) {
            payloadAtt = addAttribute(PayloadAttribute.class);
        } else {
            payloadAtt = null;
        }
        tokenizer = huggingFaceTokenizer;
        tokenWeights = weights;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        tokenIdx = 0;
        overflowingIdx = -1;
        String inputStr = CharStreams.toString(input);
        encoding = tokenizer.encode(inputStr, false, true);
    }

    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
    }

    private static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    private static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    @Override
    public final boolean incrementToken() throws IOException {
        clearAttributes();
        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
                // The current segment is exhausted; reset and advance to the next
                // overflowing segment, until overflowingIdx reaches encoding.getOverflowing().length.
                tokenIdx = 0;
                overflowingIdx++;
                if (overflowingIdx >= encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = encoding.getOverflowing()[overflowingIdx];
            } else {
                termAtt.append(curEncoding.getTokens()[tokenIdx]);
                offsetAtt.setOffset(
                    curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
                    curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
                );
                if (Objects.nonNull(tokenWeights)) {
                    // For neural sparse queries, write the token weight to the payload field.
                    payloadAtt.setPayload(
                        new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
                    );
                }
                tokenIdx++;
                return true;
            }
        }

        return false;
    }
}
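
For context, a hedged sketch of how a consumer could drive HFTokenizer and decode the weight payloads it writes. The tokenizer id and the weights map below are illustrative assumptions, not part of this commit; the attribute and payload APIs are standard Lucene:

import java.io.StringReader;
import java.nio.ByteBuffer;
import java.util.Map;

import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.opensearch.neuralsearch.analysis.HFTokenizer;

public class HFTokenizerExample {
    public static void main(String[] args) throws Exception {
        // Hypothetical tokenizer id and hand-built weights map, for illustration only.
        HuggingFaceTokenizer hfTokenizer = HuggingFaceTokenizer.newInstance("bert-base-uncased");
        Map<String, Float> weights = Map.of("hello", 0.8f);
        try (HFTokenizer tokenizer = new HFTokenizer(hfTokenizer, weights)) {
            tokenizer.setReader(new StringReader("hello world"));
            CharTermAttribute termAtt = tokenizer.getAttribute(CharTermAttribute.class);
            PayloadAttribute payloadAtt = tokenizer.getAttribute(PayloadAttribute.class);
            tokenizer.reset();
            while (tokenizer.incrementToken()) {
                // Decode the 4-byte big-endian float written by floatToBytes above.
                BytesRef payload = payloadAtt.getPayload();
                float weight = ByteBuffer.wrap(payload.bytes, payload.offset, payload.length).getFloat();
                System.out.println(termAtt.toString() + " -> " + weight); // unknown tokens fall back to 1.0
            }
            tokenizer.end();
        }
    }
}
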
src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java (new file, +58)

@@ -0,0 +1,58 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Tokenizer;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractTokenizerFactory;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Map;

public class HFTokenizerFactory extends AbstractTokenizerFactory {
    private final HuggingFaceTokenizer tokenizer;
    private final Map<String, Float> tokenWeights;

    private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
    private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";
    private static volatile HuggingFaceTokenizer defaultTokenizer;
    private static volatile Map<String, Float> defaultTokenWeights;

    public static Tokenizer createDefaultTokenizer() {
        // TODO: what if an exception is thrown during initialization?
        if (defaultTokenizer == null) {
            synchronized (HFTokenizerFactory.class) {
                if (defaultTokenizer == null) {
                    defaultTokenizer = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
                    defaultTokenWeights = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
                }
            }
        }
        return new HFTokenizer(defaultTokenizer, defaultTokenWeights);
    }

    public HFTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
        // For a custom tokenizer, the factory is created during IndexModule.newIndexService
        // and can be accessed via indexService.getIndexAnalyzers().
        super(indexSettings, settings, name);
        String tokenizerId = settings.get("tokenizer_id", DEFAULT_TOKENIZER_ID);
        String tokenWeightsFileName = settings.get("token_weights_file", null);
        tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
        if (tokenWeightsFileName != null) {
            tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
        } else {
            tokenWeights = null;
        }
    }

    @Override
    public Tokenizer create() {
        // The create method is called for every single analyze request.
        return new HFTokenizer(tokenizer, tokenWeights);
    }
}
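
For illustration, the two settings the constructor reads (tokenizer_id and token_weights_file) could be supplied under an index's analysis configuration roughly as follows. The tokenizer name my_hf_tokenizer is a hypothetical example; the settings layout follows the standard OpenSearch analysis convention:

import org.opensearch.common.settings.Settings;

// Index settings declaring a custom hf_tokenizer instance; the equivalent JSON
// would sit under "settings.analysis.tokenizer.my_hf_tokenizer" in a create-index request.
Settings indexSettings = Settings.builder()
    .put("index.analysis.tokenizer.my_hf_tokenizer.type", "hf_tokenizer")
    .put("index.analysis.tokenizer.my_hf_tokenizer.tokenizer_id", "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill")
    .put("index.analysis.tokenizer.my_hf_tokenizer.token_weights_file", "query_token_weights.txt")
    .build();
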

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java (+29, −1)

@@ -7,6 +7,7 @@
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
@@ -24,8 +25,14 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.env.Environment;
 import org.opensearch.env.NodeEnvironment;
+import org.opensearch.index.analysis.PreConfiguredTokenizer;
+import org.opensearch.index.analysis.TokenizerFactory;
+import org.opensearch.indices.analysis.AnalysisModule;
 import org.opensearch.ingest.Processor;
 import org.opensearch.ml.client.MachineLearningNodeClient;
+import org.opensearch.neuralsearch.analysis.DJLUtils;
+import org.opensearch.neuralsearch.analysis.HFTokenizer;
+import org.opensearch.neuralsearch.analysis.HFTokenizerFactory;
 import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
@@ -56,6 +63,7 @@
 import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
 import org.opensearch.plugins.ActionPlugin;
+import org.opensearch.plugins.AnalysisPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
 import org.opensearch.plugins.IngestPlugin;
 import org.opensearch.plugins.Plugin;
@@ -77,7 +85,14 @@
  * Neural Search plugin class
  */
 @Log4j2
-public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
+public class NeuralSearch extends Plugin
+    implements
+        ActionPlugin,
+        SearchPlugin,
+        IngestPlugin,
+        ExtensiblePlugin,
+        SearchPipelinePlugin,
+        AnalysisPlugin {
     private MLCommonsClientAccessor clientAccessor;
     private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
     private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
@@ -103,6 +118,7 @@ public Collection<Object> createComponents(
         NeuralSparseQueryBuilder.initialize(clientAccessor);
         HybridQueryExecutor.initialize(threadPool);
         normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
+        DJLUtils.buildDJLCachePath(environment.dataFiles()[0]);
         return List.of(clientAccessor);
     }
 
@@ -200,4 +216,16 @@ public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
             )
         );
     }
+
+    @Override
+    public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokenizers() {
+        return Map.of("hf_tokenizer", HFTokenizerFactory::new);
+    }
+
+    @Override
+    public List<PreConfiguredTokenizer> getPreConfiguredTokenizers() {
+        List<PreConfiguredTokenizer> tokenizers = new ArrayList<>();
+        tokenizers.add(PreConfiguredTokenizer.singleton(HFTokenizer.NAME, HFTokenizerFactory::createDefaultTokenizer));
+        return tokenizers;
+    }
 }
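
A note on the registration above: AnalysisModule.AnalysisProvider<TokenizerFactory> is a functional interface, and — assuming the usual four-argument get(IndexSettings, Environment, String, Settings) signature in this OpenSearch version — the HFTokenizerFactory::new method reference expands to roughly:

// Equivalent lambda form of the HFTokenizerFactory::new registration above.
AnalysisModule.AnalysisProvider<TokenizerFactory> provider = (indexSettings, environment, name, settings) ->
    new HFTokenizerFactory(indexSettings, environment, name, settings);
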

src/main/plugin-metadata/plugin-security.policy (+6)

@@ -4,4 +4,10 @@ grant {
     permission java.lang.RuntimePermission "accessDeclaredMembers";
     permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
     permission java.lang.RuntimePermission "setContextClassLoader";
+
+    permission java.net.SocketPermission "*", "connect,resolve";
+    permission java.lang.RuntimePermission "loadLibrary.*";
+    permission java.lang.RuntimePermission "setContextClassLoader";
+    permission java.util.PropertyPermission "DJL_CACHE_DIR", "read,write";
+    permission java.util.PropertyPermission "java.library.path", "read,write";
 };
