|
1 | 1 | package com.swmansion.rnexecutorch
|
2 | 2 |
|
| 3 | +import android.util.Log |
3 | 4 | import com.facebook.react.bridge.Promise
|
4 | 5 | import com.facebook.react.bridge.ReactApplicationContext
|
5 | 6 | import com.facebook.react.bridge.ReadableArray
|
6 |
| -import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder |
| 7 | +import com.swmansion.rnexecutorch.models.speechToText.BaseS2TDecoder |
| 8 | +import com.swmansion.rnexecutorch.models.speechToText.BaseS2TModule |
| 9 | +import com.swmansion.rnexecutorch.models.speechToText.Moonshine |
| 10 | +import com.swmansion.rnexecutorch.models.speechToText.MoonshineEncoder |
| 11 | +import com.swmansion.rnexecutorch.models.speechToText.Whisper |
7 | 12 | import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
|
8 |
| -import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor |
9 | 13 | import com.swmansion.rnexecutorch.utils.ArrayUtils
|
10 | 14 | import com.swmansion.rnexecutorch.utils.ETError
|
| 15 | +import org.pytorch.executorch.EValue |
| 16 | +import org.pytorch.executorch.Tensor |
11 | 17 |
|
12 |
| -class SpeechToText(reactContext: ReactApplicationContext) : |
13 |
| - NativeSpeechToTextSpec(reactContext) { |
14 |
| - private var whisperPreprocessor = WhisperPreprocessor(reactContext) |
15 |
| - private var whisperEncoder = WhisperEncoder(reactContext) |
16 |
| - private var whisperDecoder = WhisperDecoder(reactContext) |
17 |
| - private var START_TOKEN = 50257 |
18 |
| - private var EOS_TOKEN = 50256 |
| 18 | +class SpeechToText(reactContext: ReactApplicationContext) : NativeSpeechToTextSpec(reactContext) { |
| 19 | + |
| 20 | + private lateinit var speechToTextModule: BaseS2TModule; |
19 | 21 |
|
companion object {
  // Name of this native module; it is the value returned by getName().
  const val NAME = "SpeechToText"
}
|
23 | 25 |
|
24 |
/**
 * Creates the speech-to-text module selected by [modelName] and loads its models.
 *
 * Supported names are `"moonshine"` and `"whisper"`. [modelSources] is read as two
 * entries: index 0 is passed to the module as the encoder source and index 1 as the
 * decoder source.
 *
 * Resolves [promise] with `0` once loading succeeds. Any failure (unknown model name,
 * missing source entry, or an exception while loading) rejects [promise] with the error
 * text and [ETError.InvalidModelSource]. The module is stored only after a successful
 * load, so a failed call does not leave a half-initialised module behind.
 */
override fun loadModule(modelName: String, modelSources: ReadableArray, promise: Promise) {
  try {
    // Read the sources inside the try block so a malformed array rejects the promise
    // instead of throwing out of the native method.
    val encoderSource = requireNotNull(modelSources.getString(0)) { "Missing encoder source" }
    val decoderSource = requireNotNull(modelSources.getString(1)) { "Missing decoder source" }
    Log.i("rn_executorch", "model: $modelName, encoder: $encoderSource, decoder: $decoderSource")

    val module: BaseS2TModule = when (modelName) {
      "moonshine" -> Moonshine(modelName).also { it.encoder = MoonshineEncoder(reactApplicationContext) }
      "whisper" -> Whisper(modelName).also { it.encoder = WhisperEncoder(reactApplicationContext) }
      else -> throw IllegalArgumentException("Unsupported speech-to-text model: $modelName")
    }
    module.decoder = BaseS2TDecoder(reactApplicationContext)
    module.loadModel(encoderSource, decoderSource)

    speechToTextModule = module
    promise.resolve(0)
    Log.i("rn_executorch", "loaded")
  } catch (e: Exception) {
    Log.e("rn_executorch", "loadModule failed", e)
    // e.message may be null, so fall back to the exception's string form.
    promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString())
  }
}
|
34 | 55 |
|
/**
 * Transcribes [waveform] into token ids, emitting each token as soon as it is decoded.
 *
 * The waveform is encoded on the calling thread; decoding then runs on a new background
 * thread. The token sequence starts with the module's `START_TOKEN`; every decoded token
 * is sent through `emitOnToken` and appended to the sequence, and the loop ends once the
 * module's `EOS_TOKEN` has been decoded. [promise] is resolved with the whole sequence
 * (start and end tokens included), or rejected if encoding or decoding throws.
 */
override fun generate(waveform: ReadableArray, promise: Promise) {
  val (module, encoding) = try {
    // Reading the lateinit property throws when loadModule() has not succeeded yet;
    // that case is reported through the promise as well.
    val loadedModule = speechToTextModule
    loadedModule to loadedModule.encode(waveform)
  } catch (e: Exception) {
    Log.e("rn_executorch", "generate: encoding failed", e)
    promise.reject("SpeechToTextError", e)
    return
  }

  Thread {
    try {
      val generatedTokens = mutableListOf(module.START_TOKEN)
      var lastToken = 0
      while (lastToken != module.EOS_TOKEN) {
        lastToken = module.decode(generatedTokens, encoding)
        emitOnToken(lastToken.toDouble())
        generatedTokens.add(lastToken)
      }
      promise.resolve(ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray()))
    } catch (e: Exception) {
      // Without this handler the failure would escape the worker thread and the
      // promise would never settle.
      Log.e("rn_executorch", "generate: decoding failed", e)
      promise.reject("SpeechToTextError", e)
    }
  }.start()
}
|
51 | 71 |
|
/**
 * Encodes [waveform] with the loaded module and resolves [promise] with the result of
 * calling `toDoubleList()` on the encoder output. Rejects [promise] if the module has not
 * been loaded yet or if encoding throws.
 */
override fun encode(waveform: ReadableArray, promise: Promise) {
  try {
    // NOTE(review): confirm that the value returned by toDoubleList() is a type that
    // Promise.resolve can marshal to JS; if it is not, the failure is caught below.
    promise.resolve(speechToTextModule.encode(waveform).toDoubleList())
  } catch (e: Exception) {
    Log.e("rn_executorch", "encode failed", e)
    promise.reject("SpeechToTextError", e)
  }
}
| 75 | + |
/**
 * Runs one decoder step and resolves [promise] with the token the module returns.
 *
 * [encoderOutput] is a flat list of numbers that is repacked into a float tensor of shape
 * `(1, size / 288, 288)`; [prevTokens] holds the token ids produced so far. Rejects
 * [promise] if the inputs cannot be converted, the module has not been loaded, or the
 * decoder throws.
 */
override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray, promise: Promise) {
  try {
    // NOTE(review): 288 is hard-coded as the last tensor dimension — presumably the
    // encoder's hidden size for one specific model; confirm it holds for every model.
    val hiddenSize = 288
    val size = encoderOutput.size()
    // The tensor data must come from encoderOutput, whose length also fixes the shape.
    val encoderValues = FloatArray(size) { i -> encoderOutput.getDouble(i).toFloat() }
    val encoderOutputEValue = EValue.from(
      Tensor.fromBlob(encoderValues, longArrayOf(1, (size / hiddenSize).toLong(), hiddenSize.toLong()))
    )
    val previousTokens = MutableList(prevTokens.size()) { i -> prevTokens.getLong(i).toInt() }
    promise.resolve(speechToTextModule.decode(previousTokens, encoderOutputEValue))
  } catch (e: Exception) {
    Log.e("rn_executorch", "decode failed", e)
    promise.reject("SpeechToTextError", e)
  }
}
| 90 | + |
/** Returns the name of this native module, taken from the companion constant [NAME]. */
override fun getName(): String = NAME
|
|
0 commit comments