Skip to content

Commit 809a6d0

Browse files
audio float API (#15234)
### Summary Some audio inputs are float[] ### Test plan CI cc @cbilgin Co-authored-by: Hansong Zhang <[email protected]>
1 parent b62c555 commit 809a6d0

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,28 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
233233

234234
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
235235

236+
/**
237+
* Prefill a multimodal Module with the given audio input.
238+
*
239+
* @param audio Input preprocessed audio as a float array
240+
* @param batch_size Input batch size
241+
* @param n_bins Input number of bins
242+
* @param n_frames Input number of frames
243+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
244+
* exposed to user.
245+
* @throws RuntimeException if the prefill failed
246+
*/
247+
@Experimental
248+
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
249+
int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
250+
if (nativeResult != 0) {
251+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
252+
}
253+
return 0;
254+
}
255+
256+
private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
257+
236258
/**
237259
* Prefill a multimodal Module with the given raw audio input.
238260
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,29 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
325325
return 0;
326326
}
327327

328+
// Returns status_code
329+
jint append_audio_input_float(
330+
facebook::jni::alias_ref<jfloatArray> data,
331+
jint batch_size,
332+
jint n_bins,
333+
jint n_frames) {
334+
if (data == nullptr) {
335+
return static_cast<jint>(Error::EndOfMethod);
336+
}
337+
auto data_size = data->size();
338+
if (data_size != 0) {
339+
std::vector<jfloat> data_jfloat(data_size);
340+
std::vector<float> data_f(data_size);
341+
data->getRegion(0, data_size, data_jfloat.data());
342+
for (int i = 0; i < data_size; i++) {
343+
data_f[i] = data_jfloat[i];
344+
}
345+
llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
346+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
347+
}
348+
return 0;
349+
}
350+
328351
// Returns status_code
329352
jint append_raw_audio_input(
330353
facebook::jni::alias_ref<jbyteArray> data,
@@ -388,6 +411,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
388411
ExecuTorchLlmJni::append_normalized_images_input),
389412
makeNativeMethod(
390413
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
414+
makeNativeMethod(
415+
"appendAudioInputFloat",
416+
ExecuTorchLlmJni::append_audio_input_float),
391417
makeNativeMethod(
392418
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
393419
makeNativeMethod(

0 commit comments

Comments
 (0)