-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathSpeechToText.kt
78 lines (68 loc) · 3.08 KB
/
SpeechToText.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package com.swmansion.rnexecutorch
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.models.speechtotext.BaseS2TModule
import com.swmansion.rnexecutorch.models.speechtotext.Moonshine
import com.swmansion.rnexecutorch.models.speechtotext.MoonshineDecoder
import com.swmansion.rnexecutorch.models.speechtotext.MoonshineEncoder
import com.swmansion.rnexecutorch.models.speechtotext.Whisper
import com.swmansion.rnexecutorch.models.speechtotext.WhisperDecoder
import com.swmansion.rnexecutorch.models.speechtotext.WhisperEncoder
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.writableArrayToEValue
import com.swmansion.rnexecutorch.utils.ETError
/**
 * Native module implementing [NativeSpeechToTextSpec] on top of a [BaseS2TModule]
 * (either [Moonshine] or [Whisper], selected by name in [loadModule]).
 *
 * Every method reports its outcome through the supplied [Promise]; failures are
 * delivered as rejections rather than being thrown out of the native call.
 */
class SpeechToText(reactContext: ReactApplicationContext) : NativeSpeechToTextSpec(reactContext) {
    // Assigned only after a model has been loaded successfully; read it through requireModule().
    private lateinit var speechToTextModule: BaseS2TModule

    /**
     * Instantiates the model family named by [modelName], loads it from [modelSources],
     * and resolves [promise] with `0` on success.
     *
     * The active module is replaced only after `loadModel` returns, so a failed load
     * leaves any previously loaded module untouched.
     *
     * @param modelName either `"moonshine"` or `"whisper"`; anything else is rejected.
     * @param modelSources the first two entries are forwarded, in order, to `loadModel`.
     * @param promise resolved with `0`, or rejected with `ETError.InvalidModelSource`.
     */
    override fun loadModule(modelName: String, modelSources: ReadableArray, promise: Promise) {
        try {
            val module: BaseS2TModule = when (modelName) {
                "moonshine" -> Moonshine().also {
                    it.encoder = MoonshineEncoder(reactApplicationContext)
                    it.decoder = MoonshineDecoder(reactApplicationContext)
                }
                "whisper" -> Whisper().also {
                    it.encoder = WhisperEncoder(reactApplicationContext)
                    it.decoder = WhisperDecoder(reactApplicationContext)
                }
                // Previously an unknown name fell through and silently reused whatever
                // module was loaded before (or hit the uninitialized lateinit property).
                else -> throw IllegalArgumentException("Unsupported speech-to-text model: $modelName")
            }
            val firstSource = requireNotNull(modelSources.getString(0)) { "modelSources[0] is null" }
            val secondSource = requireNotNull(modelSources.getString(1)) { "modelSources[1] is null" }
            module.loadModel(firstSource, secondSource)
            speechToTextModule = module
            promise.resolve(0)
        } catch (e: Exception) {
            rejectWith(promise, e, ETError.InvalidModelSource.toString())
        }
    }

    /**
     * Encodes [waveform], then runs the token loop on a background thread, emitting each
     * token through `emitOnToken` and finally resolving [promise] with all collected tokens.
     *
     * @param waveform input passed to the loaded module's `encode`.
     * @param promise resolved with the generated tokens as a readable array, or rejected
     *   if no module is loaded or any step throws.
     */
    override fun generate(waveform: ReadableArray, promise: Promise) {
        val module: BaseS2TModule
        val generatedTokens: MutableList<Int>
        try {
            module = requireModule()
            // Not read until the decode call below is re-enabled; kept so that the
            // encoder still runs (and can fail early) exactly as before.
            val encoding = writableArrayToEValue(module.encode(waveform))
            generatedTokens = mutableListOf(module.START_TOKEN)
        } catch (e: Exception) {
            rejectWith(promise, e, "SpeechToText.generate failed")
            return
        }

        Thread {
            try {
                var lastToken = 0
                // NOTE(review): while the decode call is commented out, lastToken never
                // changes, so this loop only terminates when EOS_TOKEN is 0. Re-enable
                // decoding before shipping.
                while (lastToken != module.EOS_TOKEN) {
                    // TODO uncomment, for now
                    // lastToken = module.decode(generatedTokens, encoding)
                    emitOnToken(lastToken.toDouble())
                    generatedTokens.add(lastToken)
                }
                promise.resolve(ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray()))
            } catch (e: Exception) {
                // Without this, an exception would kill the worker thread and leave the
                // promise pending forever.
                rejectWith(promise, e, "SpeechToText.generate failed")
            }
        }.start()
    }

    /**
     * Resolves [promise] with the loaded module's `encode` result for [waveform];
     * rejects if no module is loaded or encoding throws.
     */
    override fun encode(waveform: ReadableArray, promise: Promise) {
        try {
            promise.resolve(requireModule().encode(waveform))
        } catch (e: Exception) {
            rejectWith(promise, e, "SpeechToText.encode failed")
        }
    }

    /**
     * Resolves [promise] with the loaded module's `decode` result for [prevTokens] and
     * [encoderOutput]; rejects if no module is loaded or decoding throws.
     */
    override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray, promise: Promise) {
        try {
            promise.resolve(requireModule().decode(prevTokens, encoderOutput))
        } catch (e: Exception) {
            rejectWith(promise, e, "SpeechToText.decode failed")
        }
    }

    /** Returns the module name, [NAME]. */
    override fun getName(): String = NAME

    /** Returns the loaded module, or throws [IllegalStateException] if [loadModule] has not succeeded yet. */
    private fun requireModule(): BaseS2TModule {
        check(::speechToTextModule.isInitialized) {
            "SpeechToText module is not loaded; call loadModule first"
        }
        return speechToTextModule
    }

    /**
     * Rejects [promise] using the argument order this module has always used: the exception
     * text goes first and [detail] second. Falls back to `toString()` when the exception
     * carries no message, instead of crashing on a null message.
     */
    private fun rejectWith(promise: Promise, cause: Exception, detail: String) {
        promise.reject(cause.message ?: cause.toString(), detail)
    }

    companion object {
        const val NAME = "SpeechToText"
    }
}