
Commit 97feca1

Author: Mateusz Kopciński
Commit message: android s2t wip
1 parent b36306c commit 97feca1

File tree: 14 files changed (+239 -43 lines)


android/gradle/wrapper/gradle-wrapper.properties

+2-1
@@ -1,6 +1,7 @@
+#Mon Mar 03 14:10:10 CET 2025
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip
 networkTimeout=10000
 validateDistributionUrl=true
 zipStoreBase=GRADLE_USER_HOME

android/src/main/java/com/swmansion/rnexecutorch/SpeechToText.kt

+59-20
@@ -1,54 +1,93 @@
 package com.swmansion.rnexecutorch

+import android.util.Log
 import com.facebook.react.bridge.Promise
 import com.facebook.react.bridge.ReactApplicationContext
 import com.facebook.react.bridge.ReadableArray
-import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder
+import com.swmansion.rnexecutorch.models.speechToText.BaseS2TDecoder
+import com.swmansion.rnexecutorch.models.speechToText.BaseS2TModule
+import com.swmansion.rnexecutorch.models.speechToText.Moonshine
+import com.swmansion.rnexecutorch.models.speechToText.MoonshineEncoder
+import com.swmansion.rnexecutorch.models.speechToText.Whisper
 import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
-import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor
 import com.swmansion.rnexecutorch.utils.ArrayUtils
 import com.swmansion.rnexecutorch.utils.ETError
+import org.pytorch.executorch.EValue
+import org.pytorch.executorch.Tensor

-class SpeechToText(reactContext: ReactApplicationContext) :
-  NativeSpeechToTextSpec(reactContext) {
-  private var whisperPreprocessor = WhisperPreprocessor(reactContext)
-  private var whisperEncoder = WhisperEncoder(reactContext)
-  private var whisperDecoder = WhisperDecoder(reactContext)
-  private var START_TOKEN = 50257
-  private var EOS_TOKEN = 50256
+class SpeechToText(reactContext: ReactApplicationContext) : NativeSpeechToTextSpec(reactContext) {
+
+  private lateinit var speechToTextModule: BaseS2TModule;

   companion object {
     const val NAME = "SpeechToText"
   }

-  override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) {
+  override fun loadModule(modelName: String, modelSources: ReadableArray, promise: Promise): Unit {
+    Log.i("rn_executorch", "encoder: ${modelSources.getString(0)!!}, decoder: ${modelSources.getString(1)!!}")
+    Log.i("rn_executorch", "${modelName}")
+    try {
+      if(modelName == "moonshine") {
+        this.speechToTextModule = Moonshine(modelName)
+        this.speechToTextModule.encoder = MoonshineEncoder(reactApplicationContext)
+      }
+      if(modelName == "whisper") {
+        this.speechToTextModule = Whisper(modelName)
+        this.speechToTextModule.encoder = WhisperEncoder(reactApplicationContext)
+      }
+      this.speechToTextModule.decoder = BaseS2TDecoder(reactApplicationContext)
+    } catch(e: Exception){
+      Log.i("rn_executorch", "${e.message}")
+    }
+
+
     try {
-      this.whisperPreprocessor.loadModel(preprocessorSource)
-      this.whisperEncoder.loadModel(encoderSource)
-      this.whisperDecoder.loadModel(decoderSource)
+      Log.i("rn_executorch", "encoder: ${modelSources.getString(0)!!}, decoder: ${modelSources.getString(1)!!}")
+      Log.i("rn_executorch", this.speechToTextModule.toString())
+      this.speechToTextModule.loadModel(modelSources.getString(0)!!, modelSources.getString(1)!!)
       promise.resolve(0)
+      Log.i("rn_executorch", "loaded")
     } catch (e: Exception) {
+      Log.i("rn_executorch", "error")
       promise.reject(e.message!!, ETError.InvalidModelSource.toString())
     }
   }

   override fun generate(waveform: ReadableArray, promise: Promise) {
-    val logMel = this.whisperPreprocessor.runModel(waveform)
-    val encoding = this.whisperEncoder.runModel(logMel)
-    val generatedTokens = mutableListOf(this.START_TOKEN)
+    val encoding = this.speechToTextModule.encode(waveform)
+    val generatedTokens = mutableListOf(this.speechToTextModule.START_TOKEN)
     var lastToken = 0
     Thread {
-      while (lastToken != this.EOS_TOKEN) {
-        this.whisperDecoder.setGeneratedTokens(generatedTokens)
-        lastToken = this.whisperDecoder.runModel(encoding)
+      while (lastToken != this.speechToTextModule.EOS_TOKEN) {
+        lastToken = this.speechToTextModule.decode(generatedTokens, encoding)
         emitOnToken(lastToken.toDouble())
        generatedTokens.add(lastToken)
       }
-      val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
+      val generatedTokensReadableArray =
+        ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
       promise.resolve(generatedTokensReadableArray)
     }.start()
   }

+  override fun encode(waveform: ReadableArray, promise: Promise) {
+    promise.resolve(this.speechToTextModule.encode(waveform).toDoubleList())
+  }
+
+  override fun decode(prevTokens: ReadableArray, encoderOutput: ReadableArray, promise: Promise): Unit {
+    val size = encoderOutput.size()
+    val inputFloatArray = FloatArray(size)
+    for (i in 0 until size) {
+      inputFloatArray[i] = prevTokens.getDouble(i).toFloat()
+    }
+    val encoderOutputEValue = EValue.from(Tensor.fromBlob(inputFloatArray, longArrayOf(1,
+      (size/288).toLong(), 288)))
+    val preTokensMArray = mutableListOf<Int>()
+    for (i in 0 until prevTokens.size()) {
+      preTokensMArray.add(prevTokens.getLong(i).toInt())
+    }
+    promise.resolve(this.speechToTextModule.decode(preTokensMArray, encoderOutputEValue))
+  }
+
   override fun getName(): String {
     return NAME
   }
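A note on the new decode() bridge above: it rebuilds the encoder hidden state from a flat ReadableArray and hard-codes 288 values per frame, producing a [1, size/288, 288] tensor (288 presumably being the hidden size of the exported encoder; the commit does not say). The committed loop also fills the buffer from prevTokens while sizing it from encoderOutput, which looks like work-in-progress. Below is a minimal sketch of the reshape it appears to intend, with the hidden size passed explicitly; readableArrayToEncoderEValue and hiddenSize are illustrative names, not part of this commit.

// Hypothetical helper, not in this commit: rebuilds the [1, frames, hiddenSize] encoder
// output from the flat ReadableArray handed back over the bridge. Assumes the array
// really is the flattened encoder hidden state.
import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

fun readableArrayToEncoderEValue(encoderOutput: ReadableArray, hiddenSize: Int): EValue {
  val size = encoderOutput.size()
  val flat = FloatArray(size) { i -> encoderOutput.getDouble(i).toFloat() }
  val frames = (size / hiddenSize).toLong()
  return EValue.from(Tensor.fromBlob(flat, longArrayOf(1, frames, hiddenSize.toLong())))
}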
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TDecoder.kt

+31-0

@@ -0,0 +1,31 @@
+package com.swmansion.rnexecutorch.models.speechToText
+
+import com.swmansion.rnexecutorch.models.BaseModel
+import org.pytorch.executorch.EValue
+import com.facebook.react.bridge.ReactApplicationContext
+import org.pytorch.executorch.Tensor
+
+class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext): BaseModel<EValue, Int>(reactApplicationContext) {
+  private lateinit var generatedTokens: MutableList<Int>
+
+  fun setGeneratedTokens(tokens: MutableList<Int>) {
+    this.generatedTokens = tokens
+  }
+
+  override fun runModel(input: EValue): Int {
+    val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong())))
+    return this.module
+      .forward(tokensEValue, input)[0]
+      .toTensor()
+      .dataAsLongArray[0]
+      .toInt()
+  }
+
+  override fun preprocess(input: EValue): EValue {
+    TODO("Not yet implemented")
+  }
+
+  override fun postprocess(output: Array<EValue>): Int {
+    TODO("Not yet implemented")
+  }
+}
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TModule.kt

+44-0

@@ -0,0 +1,44 @@
+package com.swmansion.rnexecutorch.models.speechToText
+
+import android.util.Log
+import com.facebook.react.bridge.ReadableArray
+import com.swmansion.rnexecutorch.models.BaseModel
+import org.pytorch.executorch.EValue
+import org.pytorch.executorch.Module
+import java.net.URL
+
+
+abstract class BaseS2TModule(modelName: String) {
+  lateinit var encoder: BaseModel<ReadableArray, EValue>
+  lateinit var decoder: BaseS2TDecoder
+  abstract var START_TOKEN:Int
+  abstract var EOS_TOKEN:Int
+
+  fun encode(input: ReadableArray): EValue {
+    return this.encoder.runModel(input)
+  }
+
+  fun decode(prevTokens: MutableList<Int>, encoderOutput: EValue): Int {
+    this.decoder.setGeneratedTokens(prevTokens)
+    return this.decoder.runModel(encoderOutput)
+  }
+
+  fun loadModel(encoderSource: String, decoderSource: String) {
+    Log.i("rn_executorch", "encoder $encoderSource ${URL(encoderSource).path} ${Module.load(URL(encoderSource).path)}")
+    try {
+
+      Log.i("rn_executorch", "encoder loaded decoder")
+      Log.i("rn_executorch", "encoder loaded decoder: ${this.decoder}")
+      Log.i("rn_executorch", "encoder loaded encoder: ${this.encoder}")
+      Log.i("rn_executorch", "encoder loaded decoder: ${this.decoder}")
+    } catch(e: Exception){
+      Log.i("rn_executorch", "error: ${e.message}")
+    }
+
+    this.encoder.loadModel(encoderSource)
+    Log.i("rn_executorch", "decoder $decoderSource ${URL(decoderSource).path}")
+    this.decoder.loadModel(decoderSource)
+    Log.i("rn_executorch", "both")
+  }
+
+}
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/Moonshine.kt

+8-0

@@ -0,0 +1,8 @@
+package com.swmansion.rnexecutorch.models.speechToText
+
+class Moonshine(
+  modelName: String,
+) : BaseS2TModule(modelName) {
+  override var START_TOKEN = 1
+  override var EOS_TOKEN = 2
+}
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/MoonshineEncoder.kt

+38-0

@@ -0,0 +1,38 @@
+package com.swmansion.rnexecutorch.models.speechToText
+
+import android.util.Log
+import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.ReadableArray
+import com.swmansion.rnexecutorch.models.BaseModel
+import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createDoubleArray
+import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
+import org.pytorch.executorch.EValue
+import org.pytorch.executorch.Tensor
+
+class MoonshineEncoder(reactApplicationContext: ReactApplicationContext) :
+  BaseModel<ReadableArray, EValue>(reactApplicationContext) {
+
+  override fun runModel(input: ReadableArray): EValue {
+    val size = input.size()
+    val inputFloatArray = FloatArray(size)
+    for (i in 0 until size) {
+      inputFloatArray[i] = input.getDouble(i).toFloat()
+    }
+    val preprocessorInputShape = longArrayOf(1, size.toLong())
+    val doubleInput = createDoubleArray(input);
+    Log.i("rn_executorch", "${EValue.from(Tensor.fromBlob(doubleInput, preprocessorInputShape)).isTensor}")
+    Log.i("rn_executorch", "${EValue.from(Tensor.fromBlob(doubleInput, preprocessorInputShape)).isDoubleList}")
+    Log.i("rn_executorch", "${doubleInput} shape: ${Tensor.fromBlob(doubleInput, preprocessorInputShape).shape().size}")
+
+    val hiddenState = this.module.forward(EValue.from(Tensor.fromBlob(doubleInput, preprocessorInputShape)))
+    return hiddenState[0]
+  }
+
+  override fun preprocess(input: ReadableArray): EValue {
+    TODO("Not yet implemented")
+  }
+
+  override fun postprocess(output: Array<EValue>): EValue {
+    TODO("Not yet implemented")
+  }
+}
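Unlike the Whisper path further down, which runs an STFT first, MoonshineEncoder feeds the raw waveform to the exported encoder as a [1, numSamples] double tensor. A minimal sketch of that packing, using only the ExecuTorch calls already seen in this diff; the helper name is illustrative, not part of the commit.

// Minimal sketch of the input packing done by MoonshineEncoder.runModel above:
// raw samples in, a [1, numSamples] double tensor wrapped in an EValue out.
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

fun waveformToMoonshineInput(waveform: DoubleArray): EValue {
  val shape = longArrayOf(1, waveform.size.toLong())
  return EValue.from(Tensor.fromBlob(waveform, shape))
}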
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/Whisper.kt

+8-0

@@ -0,0 +1,8 @@
+package com.swmansion.rnexecutorch.models.speechToText
+
+class Whisper(
+  modelName: String,
+): BaseS2TModule(modelName) {
+  override var START_TOKEN = 50257
+  override var EOS_TOKEN = 50256
+}

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperEncoder.kt

+26-5
@@ -1,22 +1,43 @@
 package com.swmansion.rnexecutorch.models.speechToText

 import com.facebook.react.bridge.ReactApplicationContext
+import com.swmansion.rnexecutorch.utils.ArrayUtils
+import com.facebook.react.bridge.ReadableArray
 import com.swmansion.rnexecutorch.models.BaseModel
+import com.swmansion.rnexecutorch.utils.STFT
 import org.pytorch.executorch.EValue
 import org.pytorch.executorch.Tensor

 class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
-  BaseModel<EValue, EValue>(reactApplicationContext) {
-  private val encoderInputShape = longArrayOf(1L, 80L, 3000L)
+  BaseModel<ReadableArray, EValue>(reactApplicationContext) {

-  override fun runModel(input: EValue): EValue {
+  private val fftSize = 512
+  private val hopLength = 160
+  private val stftFrameSize = (this.fftSize / 2).toLong()
+  private val stft = STFT(fftSize, hopLength)
+
+  override fun runModel(input: ReadableArray): EValue {
     val inputEValue = this.preprocess(input)
     val hiddenState = this.module.forward(inputEValue)
     return hiddenState[0]
+    // val size = input.size()
+    // val inputFloatArray = FloatArray(size)
+    // for (i in 0 until size) {
+    //   inputFloatArray[i] = input.getDouble(i).toFloat()
+    // }
+    // val stftResult = this.stft.fromWaveform(inputFloatArray)
+    // val numStftFrames = stftResult.size / (this.fftSize / 2)
+    // val preprocessorInputShape = longArrayOf(numStftFrames.toLong(), (this.fftSize / 2).toLong())
+    // val hiddenState = this.module.forward(EValue.from(Tensor.fromBlob(stftResult, preprocessorInputShape)))
+    // return hiddenState[0]
   }

-  override fun preprocess(input: EValue): EValue {
-    val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape)
+  override fun preprocess(input: ReadableArray): EValue {
+    val waveformFloatArray = ArrayUtils.createFloatArray(input)
+
+    val stftResult = this.stft.fromWaveform(waveformFloatArray)
+    val numStftFrames = stftResult.size / this.stftFrameSize
+    val inputTensor = Tensor.fromBlob(stftResult, longArrayOf(numStftFrames, this.stftFrameSize))
     return EValue.from(inputTensor)
   }
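The new preprocess() above replaces the fixed [1, 80, 3000] log-mel input with an on-device STFT (fftSize 512, hopLength 160) reshaped to [numFrames, 256], since it divides the STFT output by fftSize / 2. A back-of-envelope sketch of that shape for a concrete clip follows; it assumes a 16 kHz input and roughly one frame per hopLength samples, neither of which is stated in this commit (the STFT utility itself is not part of the diff).

// Rough shape check for the tensor produced by WhisperEncoder.preprocess() above.
// Assumptions (not from the commit): 16 kHz audio, ~1 frame per hopLength samples,
// fftSize/2 = 256 values per frame.
fun main() {
  val fftSize = 512
  val hopLength = 160
  val stftFrameSize = fftSize / 2                     // 256 values per frame
  val sampleRate = 16_000                             // assumed input sample rate
  val seconds = 5
  val waveformSamples = sampleRate * seconds          // 80 000 samples
  val approxFrames = waveformSamples / hopLength      // ~500 frames
  val stftValues = approxFrames * stftFrameSize       // ~128 000 floats
  // preprocess() reshapes stftValues into [stftValues / 256, 256]:
  println("encoder input shape ~ [${stftValues / stftFrameSize}, $stftFrameSize]")
}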

examples/computer-vision/ios/Podfile.lock

+2-2
@@ -1278,7 +1278,7 @@ PODS:
     - ReactCommon/turbomodule/bridging
     - ReactCommon/turbomodule/core
     - Yoga
-  - react-native-executorch (0.3.149):
+  - react-native-executorch (0.3.151):
     - DoubleConversion
     - glog
     - hermes-engine
@@ -2092,7 +2092,7 @@ SPEC CHECKSUMS:
   React-logger: 26155dc23db5c9038794db915f80bd2044512c2e
   React-Mapbuffer: ad1ba0205205a16dbff11b8ade6d1b3959451658
   React-microtasksnativemodule: e771eb9eb6ace5884ee40a293a0e14a9d7a4343c
-  react-native-executorch: e889cf3ec4616fd3f78b9e1f005d67f7d8b10e89
+  react-native-executorch: 2df97239270ae096a3cf0cecf9e520c9dfd49b9c
   react-native-image-picker: e7331948589e764ecd5a9c715c3fc14d4e6187e6
   react-native-safe-area-context: d6406c2adbd41b2e09ab1c386781dc1c81a90919
   React-nativeconfig: aeed6e2a8ac02b2df54476afcc7c663416c12bf7

examples/computer-vision/package.json

+1-1
@@ -17,7 +17,7 @@
     "react": "18.3.1",
     "react-native": "0.76.3",
     "react-native-audio-api": "0.4.11",
-    "react-native-executorch": "/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.150.tgz",
+    "react-native-executorch": "/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.197.tgz",
     "react-native-image-picker": "^7.2.2",
     "react-native-loading-spinner-overlay": "^3.0.1",
     "react-native-reanimated": "^3.16.3",

examples/computer-vision/screens/SpeechToTextScreen.tsx

+8-2
@@ -9,14 +9,19 @@ export const SpeechToTextScreen = () => {
     transcribe,
     loadAudio,
     downloadProgress,
-  } = useSpeechToText({ modelName: 'moonshine' });
+  } = useSpeechToText({ modelName: 'whisper' });

   return (
     <>
       <View style={styles.imageContainer}>
         <Button
           title="Download"
-          onPress={() => loadAudio('http://localhost:8080/output.mp3')}
+          // onPress={() => loadAudio('http://localhost:8080/output.mp3')}
+          onPress={() =>
+            loadAudio(
+              'https://ai.swmansion.com/storage/moonshine/test_audio.mp3'
+            )
+          }
         />
         <Button title="Transcribe" onPress={async () => await transcribe()} />
         <Text>downloadProgress: {downloadProgress}</Text>
@@ -25,6 +30,7 @@ export const SpeechToTextScreen = () => {
           isGenerating: {isModelGenerating ? 'generating' : 'not generating'}
         </Text>
         <Text>{sequence}</Text>
+        <Text>whisper</Text>
       </View>
     </>
   );

examples/computer-vision/yarn.lock

+5-5
@@ -3489,7 +3489,7 @@ __metadata:
     react: 18.3.1
     react-native: 0.76.3
     react-native-audio-api: 0.4.11
-    react-native-executorch: /Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.150.tgz
+    react-native-executorch: /Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.197.tgz
     react-native-image-picker: ^7.2.2
     react-native-loading-spinner-overlay: ^3.0.1
     react-native-reanimated: ^3.16.3
@@ -6996,17 +6996,17 @@ __metadata:
   languageName: node
   linkType: hard

-"react-native-executorch@file:/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.150.tgz::locator=computer-vision%40workspace%3A.":
-  version: 0.3.150
-  resolution: "react-native-executorch@file:/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.150.tgz::locator=computer-vision%40workspace%3A."
+"react-native-executorch@file:/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.197.tgz::locator=computer-vision%40workspace%3A.":
+  version: 0.3.197
+  resolution: "react-native-executorch@file:/Users/kopcion/swm-ai/react-native-executorch/react-native-executorch-0.3.197.tgz::locator=computer-vision%40workspace%3A."
   dependencies:
     expo-asset: ^11.0.3
     expo-file-system: ^18.0.10
     react-native-audio-api: 0.4.11
   peerDependencies:
     react: "*"
     react-native: "*"
-  checksum: 6bdf4b79dbe44e0a09d656dd0d26d119c6492aa640e755b6abd5e67e36fef121b82de6ae9f4f0a2dfc2289af367a281414f856d7be26d4735f8426d1e5ee4190
+  checksum: 9643a491f1bae4d4c2e8f2a15d9243c935a98fbeacd6f2d5995d4ec1b98a11764c6b355a078503f5069f209bbbc79b5076d9d5e8d0400f4afa373875125efc6d
   languageName: node
   linkType: hard