Skip to content

Commit ca02d17

Browse files
mkopcinsMateusz Kopcińskichmjkb
authored
feat: moonshine and whisper streaming (#110)
## Description Moonshine and whisper with streaming mode in IOS ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: Mateusz Kopciński <[email protected]> Co-authored-by: chmjkb <[email protected]>
1 parent 0a2ef9a commit ca02d17

File tree

138 files changed

+27653
-50738
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

138 files changed

+27653
-50738
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ The minimal supported version is 17.0 for iOS and Android 13.
6969
https://github.com/user-attachments/assets/27ab3406-c7f1-4618-a981-6c86b53547ee
7070

7171
We currently host two example apps demonstrating use cases of our library:
72+
- examples/speech-to-text - Whisper and Moonshine models ready for transcription tasks
7273
- examples/computer-vision - computer vision related tasks
7374
- examples/llama - chat applications showcasing use of LLMs
7475

android/src/main/java/com/swmansion/rnexecutorch/SpeechToText.kt

+44-21
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,75 @@ package com.swmansion.rnexecutorch
33
import com.facebook.react.bridge.Promise
44
import com.facebook.react.bridge.ReactApplicationContext
55
import com.facebook.react.bridge.ReadableArray
6-
import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder
7-
import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
8-
import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor
6+
import com.swmansion.rnexecutorch.models.speechtotext.BaseS2TModule
7+
import com.swmansion.rnexecutorch.models.speechtotext.Moonshine
8+
import com.swmansion.rnexecutorch.models.speechtotext.MoonshineDecoder
9+
import com.swmansion.rnexecutorch.models.speechtotext.MoonshineEncoder
10+
import com.swmansion.rnexecutorch.models.speechtotext.Whisper
11+
import com.swmansion.rnexecutorch.models.speechtotext.WhisperDecoder
12+
import com.swmansion.rnexecutorch.models.speechtotext.WhisperEncoder
913
import com.swmansion.rnexecutorch.utils.ArrayUtils
14+
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.writableArrayToEValue
1015
import com.swmansion.rnexecutorch.utils.ETError
1116

12-
class SpeechToText(reactContext: ReactApplicationContext) :
13-
NativeSpeechToTextSpec(reactContext) {
14-
private var whisperPreprocessor = WhisperPreprocessor(reactContext)
15-
private var whisperEncoder = WhisperEncoder(reactContext)
16-
private var whisperDecoder = WhisperDecoder(reactContext)
17-
private var START_TOKEN = 50257
18-
private var EOS_TOKEN = 50256
17+
class SpeechToText(reactContext: ReactApplicationContext) : NativeSpeechToTextSpec(reactContext) {
18+
19+
private lateinit var speechToTextModule: BaseS2TModule;
1920

2021
companion object {
2122
const val NAME = "SpeechToText"
2223
}
2324

24-
override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) {
25+
override fun loadModule(modelName: String, modelSources: ReadableArray, promise: Promise): Unit {
2526
try {
26-
this.whisperPreprocessor.loadModel(preprocessorSource)
27-
this.whisperEncoder.loadModel(encoderSource)
28-
this.whisperDecoder.loadModel(decoderSource)
27+
if(modelName == "moonshine") {
28+
this.speechToTextModule = Moonshine()
29+
this.speechToTextModule.encoder = MoonshineEncoder(reactApplicationContext)
30+
this.speechToTextModule.decoder = MoonshineDecoder(reactApplicationContext)
31+
}
32+
if(modelName == "whisper") {
33+
this.speechToTextModule = Whisper()
34+
this.speechToTextModule.encoder = WhisperEncoder(reactApplicationContext)
35+
this.speechToTextModule.decoder = WhisperDecoder(reactApplicationContext)
36+
}
37+
} catch(e: Exception){
38+
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
39+
return
40+
}
41+
42+
try {
43+
this.speechToTextModule.loadModel(modelSources.getString(0)!!, modelSources.getString(1)!!)
2944
promise.resolve(0)
3045
} catch (e: Exception) {
3146
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
3247
}
3348
}
3449

3550
override fun generate(waveform: ReadableArray, promise: Promise) {
36-
val logMel = this.whisperPreprocessor.runModel(waveform)
37-
val encoding = this.whisperEncoder.runModel(logMel)
38-
val generatedTokens = mutableListOf(this.START_TOKEN)
51+
val encoding = writableArrayToEValue(this.speechToTextModule.encode(waveform))
52+
val generatedTokens = mutableListOf(this.speechToTextModule.START_TOKEN)
3953
var lastToken = 0
4054
Thread {
41-
while (lastToken != this.EOS_TOKEN) {
42-
this.whisperDecoder.setGeneratedTokens(generatedTokens)
43-
lastToken = this.whisperDecoder.runModel(encoding)
55+
while (lastToken != this.speechToTextModule.EOS_TOKEN) {
56+
// TODO uncomment, for now
57+
// lastToken = this.speechToTextModule.decode(generatedTokens, encoding)
4458
emitOnToken(lastToken.toDouble())
4559
generatedTokens.add(lastToken)
4660
}
47-
val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
61+
val generatedTokensReadableArray =
62+
ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
4863
promise.resolve(generatedTokensReadableArray)
4964
}.start()
5065
}
5166

67+
override fun encode(waveform: ReadableArray, promise: Promise) {
68+
promise.resolve(this.speechToTextModule.encode(waveform))
69+
}
70+
71+
override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray, promise: Promise) {
72+
promise.resolve(this.speechToTextModule.decode(prevTokens, encoderOutput))
73+
}
74+
5275
override fun getName(): String {
5376
return NAME
5477
}

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ abstract class BaseModel<Input, Output>(val context: Context) {
2828
}
2929

3030
protected fun forward(inputs: Array<FloatArray>, shapes: Array<LongArray>) : Array<EValue> {
31+
return this.execute("forward", inputs, shapes);
32+
}
33+
34+
protected fun execute(methodName: String, inputs: Array<FloatArray>, shapes: Array<LongArray>) : Array<EValue> {
3135
// We want to convert each input to EValue, a data structure accepted by ExecuTorch's
3236
// Module. The array below keeps track of that values.
3337
try {
3438
val executorchInputs = inputs.mapIndexed { index, _ -> EValue.from(Tensor.fromBlob(inputs[index], shapes[index]))}
35-
val forwardResult = module.forward(*executorchInputs.toTypedArray())
39+
val forwardResult = module.execute(methodName, *executorchInputs.toTypedArray())
3640
return forwardResult
3741
} catch (e: IllegalArgumentException) {
3842
throw Error(ETError.InvalidArgument.code.toString())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.swmansion.rnexecutorch.models.BaseModel
4+
import org.pytorch.executorch.EValue
5+
import com.facebook.react.bridge.ReactApplicationContext
6+
import com.facebook.react.bridge.ReadableArray
7+
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
8+
import org.pytorch.executorch.Tensor
9+
10+
abstract class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext): BaseModel<ReadableArray, Int>(reactApplicationContext) {
11+
protected abstract var methodName: String
12+
13+
abstract fun setGeneratedTokens(tokens: ReadableArray)
14+
15+
abstract fun getTokensEValue(): EValue
16+
17+
override fun runModel(input: ReadableArray): Int {
18+
val tokensEValue = getTokensEValue()
19+
return this.module
20+
.execute(methodName, tokensEValue, this.preprocess(input))[0]
21+
.toTensor()
22+
.dataAsLongArray.last()
23+
.toInt()
24+
}
25+
26+
abstract fun getInputShape(inputLength: Int): LongArray
27+
28+
override fun preprocess(input: ReadableArray): EValue {
29+
val inputArray = input.getArray(0)!!
30+
val preprocessorInputShape = this.getInputShape(inputArray.size())
31+
return EValue.from(Tensor.fromBlob(createFloatArray(inputArray), preprocessorInputShape))
32+
}
33+
34+
override fun postprocess(output: Array<EValue>): Int {
35+
TODO("Not yet implemented")
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.facebook.react.bridge.ReadableArray
4+
import com.facebook.react.bridge.WritableArray
5+
import com.swmansion.rnexecutorch.models.BaseModel
6+
7+
8+
abstract class BaseS2TModule() {
9+
lateinit var encoder: BaseModel<ReadableArray, WritableArray>
10+
lateinit var decoder: BaseS2TDecoder
11+
abstract var START_TOKEN:Int
12+
abstract var EOS_TOKEN:Int
13+
14+
fun encode(input: ReadableArray): WritableArray {
15+
return this.encoder.runModel(input)
16+
}
17+
18+
abstract fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int
19+
20+
fun loadModel(encoderSource: String, decoderSource: String) {
21+
this.encoder.loadModel(encoderSource)
22+
this.decoder.loadModel(decoderSource)
23+
}
24+
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.facebook.react.bridge.ReadableArray
4+
import com.swmansion.rnexecutorch.utils.ArrayUtils
5+
6+
class Moonshine : BaseS2TModule() {
7+
override var START_TOKEN = 1
8+
override var EOS_TOKEN = 2
9+
override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int {
10+
this.decoder.setGeneratedTokens(prevTokens)
11+
return this.decoder.runModel(encoderOutput)
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.facebook.react.bridge.ReadableArray
5+
import com.swmansion.rnexecutorch.utils.ArrayUtils
6+
import org.pytorch.executorch.EValue
7+
import org.pytorch.executorch.Tensor
8+
9+
class MoonshineDecoder(reactApplicationContext: ReactApplicationContext) : BaseS2TDecoder(reactApplicationContext) {
10+
private lateinit var generatedTokens: LongArray
11+
private var innerDim: Long = 288;
12+
13+
override var methodName: String
14+
get() = "forward_cached"
15+
set(value) {}
16+
17+
override fun setGeneratedTokens(tokens: ReadableArray) {
18+
this.generatedTokens = ArrayUtils.createLongArray(tokens)
19+
}
20+
21+
override fun getTokensEValue(): EValue {
22+
return EValue.from(Tensor.fromBlob(this.generatedTokens, longArrayOf(1, generatedTokens.size.toLong())))
23+
}
24+
25+
override fun getInputShape(inputLength: Int): LongArray {
26+
return longArrayOf(1, inputLength.toLong()/innerDim, innerDim)
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.facebook.react.bridge.Arguments
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.facebook.react.bridge.ReadableArray
6+
import com.facebook.react.bridge.WritableArray
7+
import com.swmansion.rnexecutorch.models.BaseModel
8+
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
9+
import org.pytorch.executorch.EValue
10+
import org.pytorch.executorch.Tensor
11+
12+
class MoonshineEncoder(reactApplicationContext: ReactApplicationContext) :
13+
BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
14+
15+
override fun runModel(input: ReadableArray): WritableArray {
16+
return this.postprocess(this.module.forward(this.preprocess(input)))
17+
}
18+
19+
override fun preprocess(input: ReadableArray): EValue {
20+
val size = input.size()
21+
val preprocessorInputShape = longArrayOf(1, size.toLong())
22+
return EValue.from(Tensor.fromBlob(createFloatArray(input), preprocessorInputShape))
23+
}
24+
25+
public override fun postprocess(output: Array<EValue>): WritableArray {
26+
val outputWritableArray: WritableArray = Arguments.createArray()
27+
output[0].toTensor().dataAsFloatArray.map {outputWritableArray.pushDouble(
28+
it.toDouble()
29+
)}
30+
return outputWritableArray;
31+
}
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.swmansion.rnexecutorch.models.speechtotext
2+
3+
import com.facebook.react.bridge.ReadableArray
4+
import com.swmansion.rnexecutorch.utils.ArrayUtils
5+
6+
class Whisper : BaseS2TModule() {
7+
override var START_TOKEN = 50257
8+
override var EOS_TOKEN = 50256
9+
override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int {
10+
this.decoder.setGeneratedTokens(prevTokens)
11+
return this.decoder.runModel(encoderOutput)
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,27 @@
1-
package com.swmansion.rnexecutorch.models.speechToText
1+
package com.swmansion.rnexecutorch.models.speechtotext
22

33
import com.facebook.react.bridge.ReactApplicationContext
4-
import com.swmansion.rnexecutorch.models.BaseModel
4+
import com.facebook.react.bridge.ReadableArray
5+
import com.swmansion.rnexecutorch.utils.ArrayUtils
56
import org.pytorch.executorch.EValue
67
import org.pytorch.executorch.Tensor
78

8-
class WhisperDecoder(
9-
reactApplicationContext: ReactApplicationContext,
10-
) : BaseModel<EValue, Int>(reactApplicationContext) {
11-
private var generatedTokens: MutableList<Int> = mutableListOf()
9+
class WhisperDecoder(reactApplicationContext: ReactApplicationContext) : BaseS2TDecoder(reactApplicationContext) {
10+
private lateinit var generatedTokens: IntArray
11+
override var methodName: String
12+
get() = "forward"
13+
set(value) {}
1214

13-
fun setGeneratedTokens(tokens: MutableList<Int>) {
14-
this.generatedTokens = tokens
15-
}
1615

17-
override fun runModel(input: EValue): Int {
18-
val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong())))
19-
return this.module
20-
.forward(tokensEValue, input)[0]
21-
.toTensor()
22-
.dataAsLongArray[0]
23-
.toInt()
16+
override fun setGeneratedTokens(tokens: ReadableArray) {
17+
this.generatedTokens = ArrayUtils.createIntArray(tokens)
2418
}
2519

26-
override fun preprocess(input: EValue): EValue {
27-
TODO("Not yet implemented")
20+
override fun getTokensEValue(): EValue {
21+
return EValue.from(Tensor.fromBlob(this.generatedTokens, longArrayOf(1, generatedTokens.size.toLong())))
2822
}
2923

30-
override fun postprocess(output: Array<EValue>): Int {
31-
TODO("Not yet implemented")
24+
override fun getInputShape(inputLength: Int): LongArray {
25+
return longArrayOf(1, 1500, 384)
3226
}
3327
}
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,46 @@
1-
package com.swmansion.rnexecutorch.models.speechToText
1+
package com.swmansion.rnexecutorch.models.speechtotext
22

3+
import android.util.Log
4+
import com.facebook.react.bridge.Arguments
35
import com.facebook.react.bridge.ReactApplicationContext
6+
import com.swmansion.rnexecutorch.utils.ArrayUtils
7+
import com.facebook.react.bridge.ReadableArray
8+
import com.facebook.react.bridge.WritableArray
49
import com.swmansion.rnexecutorch.models.BaseModel
10+
import com.swmansion.rnexecutorch.utils.STFT
511
import org.pytorch.executorch.EValue
612
import org.pytorch.executorch.Tensor
713

814
class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
9-
BaseModel<EValue, EValue>(reactApplicationContext) {
10-
private val encoderInputShape = longArrayOf(1L, 80L, 3000L)
15+
BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
1116

12-
override fun runModel(input: EValue): EValue {
17+
private val fftSize = 512
18+
private val hopLength = 160
19+
private val stftFrameSize = (this.fftSize / 2).toLong()
20+
private val stft = STFT(fftSize, hopLength)
21+
22+
override fun runModel(input: ReadableArray): WritableArray {
1323
val inputEValue = this.preprocess(input)
1424
val hiddenState = this.module.forward(inputEValue)
15-
return hiddenState[0]
25+
return this.postprocess(hiddenState)
1626
}
1727

18-
override fun preprocess(input: EValue): EValue {
19-
val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape)
28+
override fun preprocess(input: ReadableArray): EValue {
29+
val waveformFloatArray = ArrayUtils.createFloatArray(input)
30+
31+
val stftResult = this.stft.fromWaveform(waveformFloatArray)
32+
val numStftFrames = stftResult.size / this.stftFrameSize
33+
val inputTensor = Tensor.fromBlob(stftResult, longArrayOf(numStftFrames, this.stftFrameSize))
2034
return EValue.from(inputTensor)
2135
}
2236

23-
override fun postprocess(output: Array<EValue>): EValue {
24-
TODO("Not yet implemented")
37+
public override fun postprocess(output: Array<EValue>): WritableArray {
38+
val outputWritableArray: WritableArray = Arguments.createArray()
39+
40+
output[0].toTensor().dataAsFloatArray.map {
41+
outputWritableArray.pushDouble(
42+
it.toDouble()
43+
)}
44+
return outputWritableArray
2545
}
2646
}

0 commit comments

Comments
 (0)