diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 709aa683c81..f845937be41 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -59,6 +59,17 @@ public LlmModule(int modelType, String modulePath, String tokenizerPath, float t
     mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
   }
 
+  /** Constructs an LLM Module for a model with the given LlmModuleConfig */
+  public LlmModule(LlmModuleConfig config) {
+    mHybridData =
+        initHybrid(
+            config.getModelType(),
+            config.getModulePath(),
+            config.getTokenizerPath(),
+            config.getTemperature(),
+            config.getDataPath());
+  }
+
   public void resetNative() {
     mHybridData.resetNative();
   }
@@ -107,6 +118,19 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean
     return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
   }
 
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param config the config for generation
+   * @param llmCallback callback object to receive results
+   */
+  public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) {
+    int seqLen = config.getSeqLen();
+    boolean echo = config.isEcho();
+    return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
+  }
+
   /**
    * Start generating tokens from the module.
    *
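
For context, here is a minimal sketch of how the two new entry points might be called from application code. The diff confirms only the getters the new APIs read (getModelType(), getModulePath(), getTokenizerPath(), getTemperature(), getDataPath(), getSeqLen(), isEcho()); the builder calls, file paths, and LlmCallback method shapes below are assumptions for illustration, not APIs confirmed by this patch.

```java
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmGenerationConfig;
import org.pytorch.executorch.extension.llm.LlmModule;
import org.pytorch.executorch.extension.llm.LlmModuleConfig;

public class LlmModuleUsageSketch {
  public static void main(String[] args) {
    // Per-model settings travel in one config object; the new constructor
    // forwards them to initHybrid, including getDataPath() where the older
    // four-argument constructor hard-coded null.
    LlmModuleConfig moduleConfig =
        LlmModuleConfig.newBuilder() // hypothetical builder; not shown in the diff
            .modulePath("/data/local/tmp/llm/model.pte") // illustrative path
            .tokenizerPath("/data/local/tmp/llm/tokenizer.bin") // illustrative path
            .temperature(0.8f)
            .build();
    LlmModule module = new LlmModule(moduleConfig);

    // Per-call settings: the new generate() overload reads getSeqLen() and
    // isEcho(), then delegates to the existing
    // generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo).
    LlmGenerationConfig genConfig =
        LlmGenerationConfig.newBuilder() // hypothetical builder; not shown in the diff
            .seqLen(128)
            .echo(false)
            .build();

    int status =
        module.generate(
            "Tell me a story.",
            genConfig,
            new LlmCallback() { // assumed callback shape
              @Override
              public void onResult(String token) {
                System.out.print(token);
              }

              @Override
              public void onStats(float tokensPerSecond) {
                System.out.println("\ntokens/sec: " + tokensPerSecond);
              }
            });
    System.out.println("generate() returned " + status);

    module.resetNative(); // release native resources when done
  }
}
```

The design choice here is conventional: bundling options into LlmModuleConfig and LlmGenerationConfig lets new parameters (such as the data path, which the older constructor silently passed as null) be added later without multiplying positional-argument overloads on LlmModule.
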