Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: moonshine and whisper streaming #110

Merged
merged 34 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b216529
moonshine impl
Feb 22, 2025
d9b7712
implemented moonshine on native side, refactored whisper
Feb 25, 2025
e3e73fb
Added Moonshine with KV-cache
Feb 25, 2025
3e7a639
moved resources to ai.swmansion, draft of hook
Feb 27, 2025
d87e132
rebased with fixes, added download progress to s2t controller
Feb 27, 2025
b844873
review changes
Feb 27, 2025
bf60daf
moonshine finished
Mar 3, 2025
f0b374a
changed whisper to 2 modules, implemented useS2T hook, implemented ho…
Mar 3, 2025
b36306c
updated modelUrls.ts
Mar 3, 2025
97feca1
android s2t wip
Mar 4, 2025
382af9a
fixed android side
Mar 4, 2025
19e6426
removed tokenzier files
Mar 4, 2025
075799c
review changes
Mar 4, 2025
d950ad3
yarn.lock
Mar 4, 2025
63502dc
fix to android, s2tcontroller small changes
Mar 5, 2025
f24f00e
fixed android
Mar 5, 2025
a257278
speech-to-text app wip
Mar 5, 2025
8e0ccf7
some more s2t changes
Mar 5, 2025
5ef98e1
remove s2t app
Mar 5, 2025
ea0bc4b
added S2T example app
Mar 5, 2025
ec66581
final cleanup
Mar 5, 2025
0ee1490
cleaned up package.json
Mar 5, 2025
0801c82
hopefully final changes
Mar 6, 2025
b12a0f2
maybe now is final
Mar 6, 2025
893505f
Merge branch 'main' into @mkopcins/moonshine
Mar 6, 2025
c32c589
Podfile.lock
Mar 6, 2025
d9b5893
fixed text display in example app
Mar 6, 2025
ad25811
fixed some more styling on example app
Mar 6, 2025
e395d3a
smallest change of all
Mar 6, 2025
84eab7f
Rename bundle identifiers
chmjkb Mar 6, 2025
949359f
Rename bundle identifiers - v2
chmjkb Mar 6, 2025
e25f498
Update yarn.lock
chmjkb Mar 6, 2025
808562f
Rebuild android dir in demo app
chmjkb Mar 6, 2025
f8db120
checked out computer vision app to main
Mar 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion android/gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#Mon Mar 03 14:10:10 CET 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
Expand Down
69 changes: 50 additions & 19 deletions android/src/main/java/com/swmansion/rnexecutorch/SpeechToText.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,83 @@ package com.swmansion.rnexecutorch
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.speechToText.BaseS2TModule
import com.swmansion.rnexecutorch.models.speechToText.Moonshine
import com.swmansion.rnexecutorch.models.speechToText.MoonshineDecoder
import com.swmansion.rnexecutorch.models.speechToText.MoonshineEncoder
import com.swmansion.rnexecutorch.models.speechToText.Whisper
import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder
import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
import com.swmansion.rnexecutorch.utils.ETError
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class SpeechToText(reactContext: ReactApplicationContext) :
NativeSpeechToTextSpec(reactContext) {
private var whisperPreprocessor = WhisperPreprocessor(reactContext)
private var whisperEncoder = WhisperEncoder(reactContext)
private var whisperDecoder = WhisperDecoder(reactContext)
private var START_TOKEN = 50257
private var EOS_TOKEN = 50256
class SpeechToText(reactContext: ReactApplicationContext) : NativeSpeechToTextSpec(reactContext) {

private lateinit var speechToTextModule: BaseS2TModule;

companion object {
// Module name constant; getName() in this class returns it.
const val NAME = "SpeechToText"
}

// Builds the speech-to-text module selected by modelName ("moonshine" or "whisper"), loads its
// encoder and decoder from modelSources[0] and modelSources[1], and resolves the promise with 0;
// any exception raised while loading is turned into a promise rejection.
// NOTE(review): this span is a rendered diff hunk and appears to contain both the removed and the
// added lines (two signatures, and loadModel calls on the old whisper* fields next to the new
// speechToTextModule call) — read it as old/new side by side, not as one compilable body.
override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) {
override fun loadModule(modelName: String, modelSources: ReadableArray, promise: Promise): Unit {
try {
if(modelName == "moonshine") {
this.speechToTextModule = Moonshine()
this.speechToTextModule.encoder = MoonshineEncoder(reactApplicationContext)
this.speechToTextModule.decoder = MoonshineDecoder(reactApplicationContext)
}
if(modelName == "whisper") {
this.speechToTextModule = Whisper()
this.speechToTextModule.encoder = WhisperEncoder(reactApplicationContext)
this.speechToTextModule.decoder = WhisperDecoder(reactApplicationContext)
}
// NOTE(review): the catch below is empty, so a failure while constructing the module is dropped
// silently. If modelName matches neither branch, speechToTextModule is not assigned here; it is
// a lateinit property, so the loadModel call further down would then fail unless an earlier call
// already assigned it.
} catch(e: Exception){
}


try {
this.whisperPreprocessor.loadModel(preprocessorSource)
this.whisperEncoder.loadModel(encoderSource)
this.whisperDecoder.loadModel(decoderSource)
// The two `!!` assert that both entries of modelSources are non-null strings.
this.speechToTextModule.loadModel(modelSources.getString(0)!!, modelSources.getString(1)!!)
promise.resolve(0)
} catch (e: Exception) {
// NOTE(review): e.message is nullable; `!!` throws from inside this catch when the exception has
// no message, and the promise is then not rejected by this line. Also confirm the argument
// order — React Native's Promise.reject(String, String) is documented as (code, message).
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

// Encodes the waveform, then on a new Thread repeatedly produces tokens until the EOS token is
// seen, emitting each token via emitOnToken and finally resolving the promise with all tokens
// (start token included) as a ReadableArray.
// NOTE(review): this span is a rendered diff hunk and appears to interleave removed lines (the
// whisperPreprocessor / whisperEncoder / whisperDecoder path and the first `while`) with their
// replacements — read it as old/new side by side.
override fun generate(waveform: ReadableArray, promise: Promise) {
val logMel = this.whisperPreprocessor.runModel(waveform)
val encoding = this.whisperEncoder.runModel(logMel)
val generatedTokens = mutableListOf(this.START_TOKEN)
// The encoder runs here, before the worker Thread below is created.
val encoding = this.writableArrayToEValue(this.speechToTextModule.encode(waveform))
val generatedTokens = mutableListOf(this.speechToTextModule.START_TOKEN)
var lastToken = 0
Thread {
while (lastToken != this.EOS_TOKEN) {
this.whisperDecoder.setGeneratedTokens(generatedTokens)
lastToken = this.whisperDecoder.runModel(encoding)
// NOTE(review): with the decode call below commented out, nothing in this loop reassigns
// lastToken, so it stays 0 and the loop can only exit if EOS_TOKEN is 0. The visible modules
// use EOS_TOKEN 2 (Moonshine) and 50256 (Whisper), so as written this loop keeps emitting 0
// and growing generatedTokens.
while (lastToken != this.speechToTextModule.EOS_TOKEN) {
// TODO uncomment, for now
// lastToken = this.speechToTextModule.decode(generatedTokens, encoding)
emitOnToken(lastToken.toDouble())
generatedTokens.add(lastToken)
}
val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
val generatedTokensReadableArray =
ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
promise.resolve(generatedTokensReadableArray)
// NOTE(review): nothing inside the Thread body catches exceptions, so a failure in the loop
// would leave the promise unsettled — confirm whether that is acceptable.
}.start()
}

/**
 * Wraps [input] in a float tensor of shape `[1, input.size()]` and returns it as an [EValue].
 */
private fun writableArrayToEValue(input: WritableArray): EValue {
    val shape = longArrayOf(1L, input.size().toLong())
    val tensor = Tensor.fromBlob(createFloatArray(input), shape)
    return EValue.from(tensor)
}

/** Runs the current module's encoder on [waveform] and resolves [promise] with its output. */
override fun encode(waveform: ReadableArray, promise: Promise) {
    val encoderOutput = this.speechToTextModule.encode(waveform)
    promise.resolve(encoderOutput)
}

/**
 * Runs one decode step of the current module with [prevTokens] and [encoderOutput] and
 * resolves [promise] with the token it returns.
 */
override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray, promise: Promise) {
    val nextToken = this.speechToTextModule.decode(prevTokens, encoderOutput)
    promise.resolve(nextToken)
}

/** Returns the module name constant [NAME]. */
override fun getName(): String = NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ abstract class BaseModel<Input, Output>(val context: Context) {
}

/**
 * Convenience wrapper that runs the method named `"forward"` through [execute] with the
 * given [inputs] and their [shapes].
 */
protected fun forward(inputs: Array<FloatArray>, shapes: Array<LongArray>): Array<EValue> =
    this.execute("forward", inputs, shapes)

protected fun execute(methodName: String, inputs: Array<FloatArray>, shapes: Array<LongArray>) : Array<EValue> {
// We want to convert each input to EValue, a data structure accepted by ExecuTorch's
// Module. The array below keeps track of that values.
try {
val executorchInputs = inputs.mapIndexed { index, _ -> EValue.from(Tensor.fromBlob(inputs[index], shapes[index]))}
val forwardResult = module.forward(*executorchInputs.toTypedArray())
val forwardResult = module.execute(methodName, *executorchInputs.toTypedArray())
return forwardResult
} catch (e: IllegalArgumentException) {
throw Error(ETError.InvalidArgument.code.toString())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.swmansion.rnexecutorch.models.speechToText

import android.util.Log
import com.swmansion.rnexecutorch.models.BaseModel
import org.pytorch.executorch.EValue
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
import org.pytorch.executorch.Tensor

/**
 * Common base for speech-to-text decoders.
 *
 * A subclass supplies the name of the module method to execute, the token tensor to feed it
 * and the shape used for the encoder output; [runModel] combines them into one decode step.
 */
abstract class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext) :
    BaseModel<ReadableArray, Int>(reactApplicationContext) {

    /** Name of the module method that [runModel] executes. */
    protected abstract var methodName: String

    /** Stores the tokens that the next [getTokensEValue] call should expose. */
    abstract fun setGeneratedTokens(tokens: ReadableArray)

    /** Returns the token input for the decoder as an [EValue]. */
    abstract fun getTokensEValue(): EValue

    /** Returns the tensor shape to use for an encoder output of [inputLength] elements. */
    abstract fun getInputShape(inputLength: Int): LongArray

    /**
     * Executes [methodName] with the token tensor and the preprocessed [input], then returns
     * the last element of the first output tensor's long data, narrowed to [Int].
     */
    override fun runModel(input: ReadableArray): Int {
        val tokens = getTokensEValue()
        val outputs = this.module.execute(methodName, tokens, this.preprocess(input))
        val tokenIds = outputs[0].toTensor().dataAsLongArray
        return tokenIds.last().toInt()
    }

    /**
     * Takes the array stored at index 0 of [input] and wraps its values in a float tensor
     * shaped by [getInputShape].
     */
    override fun preprocess(input: ReadableArray): EValue {
        val encoderOutput = input.getArray(0)!!
        val shape = this.getInputShape(encoderOutput.size())
        val values = createFloatArray(encoderOutput)
        return EValue.from(Tensor.fromBlob(values, shape))
    }

    /** Not implemented; [runModel] reads the output tensor directly instead. */
    override fun postprocess(output: Array<EValue>): Int = TODO("Not yet implemented")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.BaseModel


/**
 * Pairs an encoder model with a decoder model behind one speech-to-text interface.
 *
 * [encoder] and [decoder] are `lateinit`: they have to be assigned before [loadModel],
 * [encode] or a subclass's [decode] touches them.
 */
abstract class BaseS2TModule {
    lateinit var encoder: BaseModel<ReadableArray, WritableArray>
    lateinit var decoder: BaseS2TDecoder

    /** Token id that starts a generated sequence. */
    abstract var START_TOKEN: Int

    /** Token id that marks the end of a generated sequence. */
    abstract var EOS_TOKEN: Int

    /** Runs [encoder] on [input] and returns its output. */
    fun encode(input: ReadableArray): WritableArray = encoder.runModel(input)

    /** Produces the next token from the tokens so far and the encoder output. */
    abstract fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int

    /** Loads the encoder from [encoderSource], then the decoder from [decoderSource]. */
    fun loadModel(encoderSource: String, decoderSource: String) {
        encoder.loadModel(encoderSource)
        decoder.loadModel(decoderSource)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils

/** Moonshine speech-to-text module: start token 1, end-of-sequence token 2. */
class Moonshine : BaseS2TModule() {
    override var START_TOKEN = 1
    override var EOS_TOKEN = 2

    /** Hands [prevTokens] to the decoder, then runs it on [encoderOutput] and returns the token. */
    override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int {
        val s2tDecoder = this.decoder
        s2tDecoder.setGeneratedTokens(prevTokens)
        return s2tDecoder.runModel(encoderOutput)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

/**
 * Decoder for the Moonshine model. Executes the module method `forward_cached` with long
 * token ids and an encoder output reshaped to `[1, n / 288, 288]`.
 */
class MoonshineDecoder(reactApplicationContext: ReactApplicationContext) :
    BaseS2TDecoder(reactApplicationContext) {

    // Assigned by setGeneratedTokens; reading it earlier fails because it is lateinit.
    private lateinit var generatedTokens: LongArray

    // Always "forward_cached"; the setter intentionally discards whatever is assigned.
    override var methodName: String
        get() = "forward_cached"
        set(value) {}

    /** Replaces the stored tokens with the contents of [tokens] converted to a [LongArray]. */
    override fun setGeneratedTokens(tokens: ReadableArray) {
        val converted = ArrayUtils.createLongArray(tokens)
        this.generatedTokens = converted
    }

    /** Returns the stored tokens as a tensor of shape `[1, tokenCount]`. */
    override fun getTokensEValue(): EValue {
        val tokens = this.generatedTokens
        val shape = longArrayOf(1, tokens.size.toLong())
        return EValue.from(Tensor.fromBlob(tokens, shape))
    }

    /** Shape for an encoder output of [inputLength] floats: `[1, inputLength / 288, 288]`. */
    override fun getInputShape(inputLength: Int): LongArray {
        val rows = inputLength.toLong() / 288
        return longArrayOf(1, rows, 288)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

/**
 * Encoder for the Moonshine model.
 *
 * Takes a flat waveform as a [ReadableArray] and returns the first output tensor of the
 * module's forward call, flattened into a [WritableArray] of doubles.
 */
class MoonshineEncoder(reactApplicationContext: ReactApplicationContext) :
    BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {

    /** Preprocesses [input], runs the module's forward pass and post-processes the result. */
    override fun runModel(input: ReadableArray): WritableArray =
        this.postprocess(this.module.forward(this.preprocess(input)))

    /** Wraps [input] in a float tensor of shape `[1, input.size()]`. */
    override fun preprocess(input: ReadableArray): EValue {
        val shape = longArrayOf(1, input.size().toLong())
        return EValue.from(Tensor.fromBlob(createFloatArray(input), shape))
    }

    /**
     * Copies the float data of the first element of [output] into a new [WritableArray],
     * widening every value to double.
     */
    public override fun postprocess(output: Array<EValue>): WritableArray {
        val outputWritableArray: WritableArray = Arguments.createArray()
        // forEach instead of map: the lambda is run only for its side effect, and map would
        // build and then discard a List<Unit> with one entry per float.
        output[0].toTensor().dataAsFloatArray.forEach { outputWritableArray.pushDouble(it.toDouble()) }
        return outputWritableArray
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils

/** Whisper speech-to-text module: start token 50257, end-of-sequence token 50256. */
class Whisper : BaseS2TModule() {
    override var START_TOKEN = 50257
    override var EOS_TOKEN = 50256

    /** Gives [prevTokens] to the decoder and returns the token it produces for [encoderOutput]. */
    override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray): Int {
        val whisperDecoder = this.decoder
        whisperDecoder.setGeneratedTokens(prevTokens)
        val nextToken = whisperDecoder.runModel(encoderOutput)
        return nextToken
    }
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,27 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class WhisperDecoder(
reactApplicationContext: ReactApplicationContext,
) : BaseModel<EValue, Int>(reactApplicationContext) {
private var generatedTokens: MutableList<Int> = mutableListOf()
class WhisperDecoder(reactApplicationContext: ReactApplicationContext) : BaseS2TDecoder(reactApplicationContext) {
private lateinit var generatedTokens: IntArray
override var methodName: String
get() = "forward"
set(value) {}

fun setGeneratedTokens(tokens: MutableList<Int>) {
this.generatedTokens = tokens
}

override fun runModel(input: EValue): Int {
val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong())))
return this.module
.forward(tokensEValue, input)[0]
.toTensor()
.dataAsLongArray[0]
.toInt()
/** Replaces the stored tokens with the contents of [tokens] converted to an [IntArray]. */
override fun setGeneratedTokens(tokens: ReadableArray) {
    val converted = ArrayUtils.createIntArray(tokens)
    this.generatedTokens = converted
}

override fun preprocess(input: EValue): EValue {
TODO("Not yet implemented")
/** Returns the stored tokens as a tensor of shape `[1, tokenCount]`. */
override fun getTokensEValue(): EValue {
    val tokens = this.generatedTokens
    val shape = longArrayOf(1, tokens.size.toLong())
    return EValue.from(Tensor.fromBlob(tokens, shape))
}

override fun postprocess(output: Array<EValue>): Int {
TODO("Not yet implemented")
/** Always returns the fixed shape `[1, 1500, 384]`; [inputLength] is not consulted. */
override fun getInputShape(inputLength: Int): LongArray = longArrayOf(1, 1500, 384)
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
package com.swmansion.rnexecutorch.models.speechToText

import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.STFT
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
BaseModel<EValue, EValue>(reactApplicationContext) {
private val encoderInputShape = longArrayOf(1L, 80L, 3000L)
BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {

override fun runModel(input: EValue): EValue {
private val fftSize = 512
private val hopLength = 160
private val stftFrameSize = (this.fftSize / 2).toLong()
private val stft = STFT(fftSize, hopLength)

override fun runModel(input: ReadableArray): WritableArray {
val inputEValue = this.preprocess(input)
val hiddenState = this.module.forward(inputEValue)
return hiddenState[0]
val tmp = this.postprocess(hiddenState)
return tmp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename this

}

override fun preprocess(input: EValue): EValue {
val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape)
/**
 * Converts the waveform in [input] to floats, runs it through the STFT helper and wraps the
 * result in a tensor of shape `[resultSize / stftFrameSize, stftFrameSize]`.
 */
override fun preprocess(input: ReadableArray): EValue {
    val samples = ArrayUtils.createFloatArray(input)
    val spectrogram = this.stft.fromWaveform(samples)
    val frameCount = spectrogram.size / this.stftFrameSize
    val shape = longArrayOf(frameCount, this.stftFrameSize)
    return EValue.from(Tensor.fromBlob(spectrogram, shape))
}

override fun postprocess(output: Array<EValue>): EValue {
TODO("Not yet implemented")
/**
 * Copies the float data of the first element of [output] into a new [WritableArray],
 * widening every value to double.
 */
public override fun postprocess(output: Array<EValue>): WritableArray {
    val outputWritableArray: WritableArray = Arguments.createArray()

    // forEach instead of map: the lambda is run only for its side effect, and map would
    // build and then discard a List<Unit> with one entry per float.
    output[0].toTensor().dataAsFloatArray.forEach { outputWritableArray.pushDouble(it.toDouble()) }
    return outputWritableArray
}
}
Loading