Skip to content

Commit 80721ec

Browse files
authored
feat: Whisper (#101)
## Description This PR provides an almost-ready implementation of Whisper. The only thing left to do should be to align the TS implementation with Moonshine, Hookless API and to create a hook for this. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 985b8a0 commit 80721ec

40 files changed

+51096
-276
lines changed

LICENSE

+25
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
5252
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
5353
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
5454
SOFTWARE.
55+
56+
This software includes components from the JTransforms library. The license and copyright notice for this library are as follows:
57+
JTransforms
58+
Copyright (c) 2007 onward, Piotr Wendykier
59+
All rights reserved.
60+
61+
Redistribution and use in source and binary forms, with or without
62+
modification, are permitted provided that the following conditions are met:
63+
64+
1. Redistributions of source code must retain the above copyright notice, this
65+
list of conditions and the following disclaimer.
66+
2. Redistributions in binary form must reproduce the above copyright notice,
67+
this list of conditions and the following disclaimer in the documentation
68+
and/or other materials provided with the distribution.
69+
70+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
71+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
72+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
73+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
74+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
75+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
76+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
77+
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
78+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
79+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

android/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ dependencies {
9999
// For < 0.71, this will be from the local maven repo
100100
// For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin
101101
//noinspection GradleDynamicVersion
102+
implementation 'com.github.wendykierp:JTransforms:3.1'
102103
implementation "com.facebook.react:react-android:+"
103104
implementation 'org.opencv:opencv:4.10.0'
104105
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class RnExecutorchPackage : TurboReactPackage() {
2323
Classification(reactContext)
2424
} else if (name == ObjectDetection.NAME) {
2525
ObjectDetection(reactContext)
26+
} else if (name == SpeechToText.NAME) {
27+
SpeechToText(reactContext)
2628
}
2729
else {
2830
null
@@ -74,6 +76,15 @@ class RnExecutorchPackage : TurboReactPackage() {
7476
false, // isCxxModule
7577
true
7678
)
79+
80+
moduleInfos[SpeechToText.NAME] = ReactModuleInfo(
81+
SpeechToText.NAME,
82+
SpeechToText.NAME,
83+
false, // canOverrideExistingModule
84+
false, // needsEagerInit
85+
false, // isCxxModule
86+
true
87+
)
7788
moduleInfos
7889
}
7990
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import com.facebook.react.bridge.Promise
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.facebook.react.bridge.ReadableArray
6+
import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder
7+
import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
8+
import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor
9+
import com.swmansion.rnexecutorch.utils.ArrayUtils
10+
import com.swmansion.rnexecutorch.utils.ETError
11+
12+
class SpeechToText(reactContext: ReactApplicationContext) :
13+
NativeSpeechToTextSpec(reactContext) {
14+
private var whisperPreprocessor = WhisperPreprocessor(reactContext)
15+
private var whisperEncoder = WhisperEncoder(reactContext)
16+
private var whisperDecoder = WhisperDecoder(reactContext)
17+
private var START_TOKEN = 50257
18+
private var EOS_TOKEN = 50256
19+
20+
companion object {
21+
const val NAME = "SpeechToText"
22+
}
23+
24+
override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) {
25+
try {
26+
this.whisperPreprocessor.loadModel(preprocessorSource)
27+
this.whisperEncoder.loadModel(encoderSource)
28+
this.whisperDecoder.loadModel(decoderSource)
29+
promise.resolve(0)
30+
} catch (e: Exception) {
31+
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
32+
}
33+
}
34+
35+
override fun generate(waveform: ReadableArray, promise: Promise) {
36+
val logMel = this.whisperPreprocessor.runModel(waveform)
37+
val encoding = this.whisperEncoder.runModel(logMel)
38+
val generatedTokens = mutableListOf(this.START_TOKEN)
39+
var lastToken = 0
40+
Thread {
41+
while (lastToken != this.EOS_TOKEN) {
42+
this.whisperDecoder.setGeneratedTokens(generatedTokens)
43+
lastToken = this.whisperDecoder.runModel(encoding)
44+
emitOnToken(lastToken.toDouble())
45+
generatedTokens.add(lastToken)
46+
}
47+
val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
48+
promise.resolve(generatedTokensReadableArray)
49+
}.start()
50+
}
51+
52+
override fun getName(): String {
53+
return NAME
54+
}
55+
}

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

+15-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,21 @@ abstract class BaseModel<Input, Output>(val context: Context) {
2222
//The error is thrown when transformation to Tensor fails
2323
throw Error(ETError.InvalidArgument.code.toString())
2424
} catch (e: Exception) {
25-
throw Error(e.message!!)
25+
throw Error(e.message)
26+
}
27+
}
28+
29+
protected fun forward(inputs: Array<FloatArray>, shapes: Array<LongArray>) : Array<EValue> {
30+
// We want to convert each input to EValue, a data structure accepted by ExecuTorch's
31+
// Module. The array below keeps track of that values.
32+
try {
33+
val executorchInputs = inputs.mapIndexed { index, _ -> EValue.from(Tensor.fromBlob(inputs[index], shapes[index]))}
34+
val forwardResult = module.forward(*executorchInputs.toTypedArray())
35+
return forwardResult
36+
} catch (e: IllegalArgumentException) {
37+
throw Error(ETError.InvalidArgument.code.toString())
38+
} catch (e: Exception) {
39+
throw Error(e.message)
2640
}
2741
}
2842

android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor
55
import org.opencv.core.Mat
66
import org.opencv.core.Size
77
import org.opencv.imgproc.Imgproc
8-
import org.pytorch.executorch.Tensor
98
import org.pytorch.executorch.EValue
109
import com.swmansion.rnexecutorch.models.BaseModel
1110

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.swmansion.rnexecutorch.models.speechToText
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.swmansion.rnexecutorch.models.BaseModel
5+
import org.pytorch.executorch.EValue
6+
import org.pytorch.executorch.Tensor
7+
8+
class WhisperDecoder(
9+
reactApplicationContext: ReactApplicationContext,
10+
) : BaseModel<EValue, Int>(reactApplicationContext) {
11+
private var generatedTokens: MutableList<Int> = mutableListOf()
12+
13+
fun setGeneratedTokens(tokens: MutableList<Int>) {
14+
this.generatedTokens = tokens
15+
}
16+
17+
override fun runModel(input: EValue): Int {
18+
val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong())))
19+
return this.module
20+
.forward(tokensEValue, input)[0]
21+
.toTensor()
22+
.dataAsLongArray[0]
23+
.toInt()
24+
}
25+
26+
override fun preprocess(input: EValue): EValue {
27+
TODO("Not yet implemented")
28+
}
29+
30+
override fun postprocess(output: Array<EValue>): Int {
31+
TODO("Not yet implemented")
32+
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.swmansion.rnexecutorch.models.speechToText
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.swmansion.rnexecutorch.models.BaseModel
5+
import org.pytorch.executorch.EValue
6+
import org.pytorch.executorch.Tensor
7+
8+
class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
9+
BaseModel<EValue, EValue>(reactApplicationContext) {
10+
private val encoderInputShape = longArrayOf(1L, 80L, 3000L)
11+
12+
override fun runModel(input: EValue): EValue {
13+
val inputEValue = this.preprocess(input)
14+
val hiddenState = this.module.forward(inputEValue)
15+
return hiddenState[0]
16+
}
17+
18+
override fun preprocess(input: EValue): EValue {
19+
val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape)
20+
return EValue.from(inputTensor)
21+
}
22+
23+
override fun postprocess(output: Array<EValue>): EValue {
24+
TODO("Not yet implemented")
25+
}
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.swmansion.rnexecutorch.models.speechToText
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.facebook.react.bridge.ReadableArray
5+
import com.swmansion.rnexecutorch.models.BaseModel
6+
import com.swmansion.rnexecutorch.utils.STFT
7+
import org.pytorch.executorch.EValue
8+
import org.pytorch.executorch.Tensor
9+
10+
class WhisperPreprocessor(reactApplicationContext: ReactApplicationContext) :
11+
BaseModel<ReadableArray, EValue>(reactApplicationContext) {
12+
private val fftSize = 512
13+
private val hopLength = 160
14+
private val stft = STFT(fftSize, hopLength)
15+
16+
override fun runModel(input: ReadableArray): EValue {
17+
val size = input.size()
18+
val inputFloatArray = FloatArray(size)
19+
for (i in 0 until size) {
20+
inputFloatArray[i] = input.getDouble(i).toFloat()
21+
}
22+
val stftResult = this.stft.fromWaveform(inputFloatArray)
23+
val numStftFrames = stftResult.size / (this.fftSize / 2)
24+
val preprocessorInputShape = longArrayOf(numStftFrames.toLong(), (this.fftSize / 2).toLong())
25+
val melSpectrogram = this.module.forward(EValue.from(Tensor.fromBlob(stftResult, preprocessorInputShape)))
26+
return melSpectrogram[0]
27+
}
28+
29+
override fun preprocess(input: ReadableArray): EValue {
30+
TODO("Not yet implemented")
31+
}
32+
33+
override fun postprocess(output: Array<EValue>): EValue {
34+
TODO("Not yet implemented")
35+
}
36+
}

android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
package com.swmansion.rnexecutorch.utils
22

3+
import android.util.Log
34
import com.facebook.react.bridge.Arguments
45
import com.facebook.react.bridge.ReadableArray
56
import org.pytorch.executorch.DType
67
import org.pytorch.executorch.Tensor
78

89
class ArrayUtils {
910
companion object {
10-
private inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
11+
inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
1112
return Array(input.size()) { index -> transform(input, index) }
1213
}
1314

15+
fun createByteArray(input: ReadableArray): ByteArray {
16+
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
17+
}
18+
19+
fun createCharArray(input: ReadableArray): CharArray {
20+
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toChar() }.toCharArray()
21+
}
1422
fun createByteArray(input: ReadableArray): ByteArray {
1523
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
1624
}
@@ -62,5 +70,18 @@ class ArrayUtils {
6270

6371
return resultArray
6472
}
73+
74+
fun createReadableArrayFromFloatArray(input: FloatArray): ReadableArray {
75+
val resultArray = Arguments.createArray()
76+
input.forEach { resultArray.pushDouble(it.toDouble()) }
77+
return resultArray
78+
}
79+
80+
fun createReadableArrayFromIntArray(input: IntArray): ReadableArray {
81+
val resultArray = Arguments.createArray()
82+
input.forEach { resultArray.pushInt(it) }
83+
return resultArray
84+
}
85+
6586
}
6687
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.swmansion.rnexecutorch.utils
2+
3+
import java.util.Vector
4+
import kotlin.math.cos
5+
import kotlin.math.PI
6+
import org.jtransforms.fft.FloatFFT_1D
7+
import kotlin.math.sqrt
8+
9+
class STFT public constructor(var fftSize: Int = 512, var hopLength: Int = 160) {
10+
private val fftModule = FloatFFT_1D(this.fftSize.toLong())
11+
private val magnitudeScale = 1.0 / this.fftSize
12+
// https://www.mathworks.com/help/signal/ref/hann.html
13+
private val hannWindow = FloatArray(this.fftSize) { i ->0.5f - 0.5f * cos(2f * PI.toFloat() * i / this.fftSize) }
14+
15+
16+
fun fromWaveform(signal: FloatArray): FloatArray {
17+
val numFftFrames = (signal.size - this.fftSize) / this.hopLength
18+
// The output of FFT is always 2x smaller
19+
val stft = FloatArray(numFftFrames * (this.fftSize / 2))
20+
21+
var windowStartIdx = 0
22+
var outputIndex = 0
23+
// TODO: i dont think the substraction at the end is needed
24+
while (windowStartIdx + this.fftSize < signal.size - this.fftSize) {
25+
val currentWindow = signal.copyOfRange(windowStartIdx, windowStartIdx + this.fftSize)
26+
// Apply Hann window to the current slice
27+
for (i in currentWindow.indices) currentWindow[i] *= this.hannWindow[i]
28+
29+
// Perform in-place FFT
30+
this.fftModule.realForward(currentWindow)
31+
32+
stft[outputIndex++] = kotlin.math.abs(currentWindow[0])
33+
for (i in 1 until this.fftSize / 2 - 1) {
34+
val real = currentWindow[2 * i]
35+
val imag = currentWindow[2 * i + 1]
36+
37+
val currentMagnitude = (sqrt(real * real + imag * imag) * this.magnitudeScale).toFloat()
38+
// FIXME: we don't need that, but if we remove this we have to get rid of
39+
// reversing this operation in the preprocessing part
40+
stft[outputIndex++] = 20 * kotlin.math.log10(currentMagnitude)
41+
}
42+
// Nyquist frequency
43+
stft[outputIndex++] = kotlin.math.abs(currentWindow[1])
44+
windowStartIdx += this.hopLength
45+
}
46+
return stft
47+
}
48+
}

ios/ExecutorchLib.xcframework/Info.plist

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
<key>BinaryPath</key>
99
<string>ExecutorchLib.framework/ExecutorchLib</string>
1010
<key>LibraryIdentifier</key>
11-
<string>ios-arm64</string>
11+
<string>ios-arm64-simulator</string>
1212
<key>LibraryPath</key>
1313
<string>ExecutorchLib.framework</string>
1414
<key>SupportedArchitectures</key>
@@ -17,12 +17,14 @@
1717
</array>
1818
<key>SupportedPlatform</key>
1919
<string>ios</string>
20+
<key>SupportedPlatformVariant</key>
21+
<string>simulator</string>
2022
</dict>
2123
<dict>
2224
<key>BinaryPath</key>
2325
<string>ExecutorchLib.framework/ExecutorchLib</string>
2426
<key>LibraryIdentifier</key>
25-
<string>ios-arm64-simulator</string>
27+
<string>ios-arm64</string>
2628
<key>LibraryPath</key>
2729
<string>ExecutorchLib.framework</string>
2830
<key>SupportedArchitectures</key>
@@ -31,8 +33,6 @@
3133
</array>
3234
<key>SupportedPlatform</key>
3335
<string>ios</string>
34-
<key>SupportedPlatformVariant</key>
35-
<string>simulator</string>
3636
</dict>
3737
</array>
3838
<key>CFBundlePackageType</key>

ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Headers/ETModel.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
- (NSNumber *)loadModel:(NSString *)filePath;
99
- (NSNumber *)loadMethod:(NSString *)methodName;
1010
- (NSNumber *)loadForward;
11-
- (NSArray *)forward:(NSArray *)input
12-
shape:(NSArray *)shape
13-
inputType:(NSNumber *)inputType;
11+
- (NSArray *)forward:(NSArray *)inputs
12+
shapes:(NSArray *)shapes
13+
inputTypes: (NSArray *)inputTypes;
1414
- (NSNumber *)getNumberOfInputs;
1515
- (NSNumber *)getInputType:(NSNumber *)index;
1616
- (NSArray *)getInputShape:(NSNumber *)index;

0 commit comments

Comments
 (0)