Skip to content

Commit bcac318

Browse files
committed
Remove abstract pre/postprocess from BaseModel
1 parent e01599f commit bcac318

File tree

12 files changed

+30
-28
lines changed

12 files changed

+30
-28
lines changed

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

-4
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,4 @@ abstract class BaseModel<Input, Output>(val context: Context) {
4646
}
4747

4848
abstract fun runModel(input: Input): Output
49-
50-
protected abstract fun preprocess(input: Input): EValue
51-
52-
protected abstract fun postprocess(output: Array<EValue>): Output
5349
}

android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import org.pytorch.executorch.Tensor
99
import org.pytorch.executorch.EValue
1010

1111

12-
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Mat>(reactApplicationContext) {
12+
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) :
13+
BaseModel<Mat, Mat>(reactApplicationContext) {
1314
private lateinit var originalSize: Size
1415

1516
private fun getModelImageSize(): Size {
@@ -20,13 +21,13 @@ class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : Bas
2021
return Size(height.toDouble(), width.toDouble())
2122
}
2223

23-
override fun preprocess(input: Mat): EValue {
24+
fun preprocess(input: Mat): EValue {
2425
originalSize = input.size()
2526
Imgproc.resize(input, input, getModelImageSize())
2627
return ImageProcessor.matToEValue(input, module.getInputShape(0))
2728
}
2829

29-
override fun postprocess(output: Array<EValue>): Mat {
30+
fun postprocess(output: Array<EValue>): Mat {
3031
val tensor = output[0].toTensor()
3132
val modelShape = getModelImageSize()
3233
val result = ImageProcessor.EValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())

android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import org.pytorch.executorch.EValue
99
import com.swmansion.rnexecutorch.models.BaseModel
1010

1111

12-
class ClassificationModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Map<String, Float>>(reactApplicationContext) {
12+
class ClassificationModel(reactApplicationContext: ReactApplicationContext) :
13+
BaseModel<Mat, Map<String, Float>>(reactApplicationContext) {
1314
private fun getModelImageSize(): Size {
1415
val inputShape = module.getInputShape(0)
1516
val width = inputShape[inputShape.lastIndex]
@@ -18,12 +19,12 @@ class ClassificationModel(reactApplicationContext: ReactApplicationContext) : Ba
1819
return Size(height.toDouble(), width.toDouble())
1920
}
2021

21-
override fun preprocess(input: Mat): EValue {
22+
fun preprocess(input: Mat): EValue {
2223
Imgproc.resize(input, input, getModelImageSize())
2324
return ImageProcessor.matToEValue(input, module.getInputShape(0))
2425
}
2526

26-
override fun postprocess(output: Array<EValue>): Map<String, Float> {
27+
fun postprocess(output: Array<EValue>): Map<String, Float> {
2728
val tensor = output[0].toTensor()
2829
val probabilities = softmax(tensor.dataAsFloatArray.toTypedArray())
2930

android/src/main/java/com/swmansion/rnexecutorch/models/objectDetection/SSDLiteLargeModel.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import org.pytorch.executorch.EValue
1515
const val detectionScoreThreshold = .7f
1616
const val iouThreshold = .55f
1717

18-
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
18+
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) :
19+
BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
1920
private var heightRatio: Float = 1.0f
2021
private var widthRatio: Float = 1.0f
2122

@@ -27,7 +28,7 @@ class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : Base
2728
return Size(height.toDouble(), width.toDouble())
2829
}
2930

30-
override fun preprocess(input: Mat): EValue {
31+
fun preprocess(input: Mat): EValue {
3132
this.widthRatio = (input.size().width / getModelImageSize().width).toFloat()
3233
this.heightRatio = (input.size().height / getModelImageSize().height).toFloat()
3334
Imgproc.resize(input, input, getModelImageSize())
@@ -40,7 +41,7 @@ class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : Base
4041
return postprocess(modelOutput)
4142
}
4243

43-
override fun postprocess(output: Array<EValue>): Array<Detection> {
44+
fun postprocess(output: Array<EValue>): Array<Detection> {
4445
val scoresTensor = output[1].toTensor()
4546
val numel = scoresTensor.numel()
4647
val bboxes = output[0].toTensor().dataAsFloatArray

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Detector(
2525
return modelImageSize
2626
}
2727

28-
override fun preprocess(input: Mat): EValue {
28+
fun preprocess(input: Mat): EValue {
2929
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
3030
val resizedImage = ImageProcessor.resizeWithPadding(
3131
input, getModelImageSize().width.toInt(), getModelImageSize().height.toInt()
@@ -36,7 +36,7 @@ class Detector(
3636
)
3737
}
3838

39-
override fun postprocess(output: Array<EValue>): List<OCRbBox> {
39+
fun postprocess(output: Array<EValue>): List<OCRbBox> {
4040
val outputTensor = output[0].toTensor()
4141
val outputArray = outputTensor.dataAsFloatArray
4242
val modelImageSize = getModelImageSize()

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ class Recognizer(reactApplicationContext: ReactApplicationContext) :
1919
return Size(height.toDouble(), width.toDouble())
2020
}
2121

22-
override fun preprocess(input: Mat): EValue {
22+
fun preprocess(input: Mat): EValue {
2323
return ImageProcessor.matToEValueGray(input)
2424
}
2525

26-
override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
26+
fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
2727
val modelOutputHeight = getModelOutputSize().height.toInt()
2828
val tensor = output[0].toTensor().dataAsFloatArray
2929
val numElements = tensor.size

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class VerticalDetector(
2727
return modelImageSize
2828
}
2929

30-
override fun preprocess(input: Mat): EValue {
30+
fun preprocess(input: Mat): EValue {
3131
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
3232
val resizedImage = ImageProcessor.resizeWithPadding(
3333
input,
@@ -43,7 +43,7 @@ class VerticalDetector(
4343
)
4444
}
4545

46-
override fun postprocess(output: Array<EValue>): List<OCRbBox> {
46+
fun postprocess(output: Array<EValue>): List<OCRbBox> {
4747
val outputTensor = output[0].toTensor()
4848
val outputArray = outputTensor.dataAsFloatArray
4949
val modelImageSize = getModelImageSize()

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TDecoder.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import com.facebook.react.bridge.ReadableArray
77
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
88
import org.pytorch.executorch.Tensor
99

10-
abstract class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext): BaseModel<ReadableArray, Int>(reactApplicationContext) {
10+
abstract class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext):
11+
BaseModel<ReadableArray, Int>(reactApplicationContext) {
1112
protected abstract var methodName: String
1213

1314
abstract fun setGeneratedTokens(tokens: ReadableArray)
@@ -25,13 +26,13 @@ abstract class BaseS2TDecoder(reactApplicationContext: ReactApplicationContext):
2526

2627
abstract fun getInputShape(inputLength: Int): LongArray
2728

28-
override fun preprocess(input: ReadableArray): EValue {
29+
fun preprocess(input: ReadableArray): EValue {
2930
val inputArray = input.getArray(0)!!
3031
val preprocessorInputShape = this.getInputShape(inputArray.size())
3132
return EValue.from(Tensor.fromBlob(createFloatArray(inputArray), preprocessorInputShape))
3233
}
3334

34-
override fun postprocess(output: Array<EValue>): Int {
35+
fun postprocess(output: Array<EValue>): Int {
3536
TODO("Not yet implemented")
3637
}
3738
}

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/MoonshineDecoder.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import com.swmansion.rnexecutorch.utils.ArrayUtils
66
import org.pytorch.executorch.EValue
77
import org.pytorch.executorch.Tensor
88

9-
class MoonshineDecoder(reactApplicationContext: ReactApplicationContext) : BaseS2TDecoder(reactApplicationContext) {
9+
class MoonshineDecoder(reactApplicationContext: ReactApplicationContext) :
10+
BaseS2TDecoder(reactApplicationContext) {
1011
private lateinit var generatedTokens: LongArray
1112
private var innerDim: Long = 288;
1213

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/MoonshineEncoder.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ class MoonshineEncoder(reactApplicationContext: ReactApplicationContext) :
1616
return this.postprocess(this.module.forward(this.preprocess(input)))
1717
}
1818

19-
override fun preprocess(input: ReadableArray): EValue {
19+
fun preprocess(input: ReadableArray): EValue {
2020
val size = input.size()
2121
val preprocessorInputShape = longArrayOf(1, size.toLong())
2222
return EValue.from(Tensor.fromBlob(createFloatArray(input), preprocessorInputShape))
2323
}
2424

25-
public override fun postprocess(output: Array<EValue>): WritableArray {
25+
public fun postprocess(output: Array<EValue>): WritableArray {
2626
val outputWritableArray: WritableArray = Arguments.createArray()
2727
output[0].toTensor().dataAsFloatArray.map {outputWritableArray.pushDouble(
2828
it.toDouble()

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperDecoder.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import com.swmansion.rnexecutorch.utils.ArrayUtils
66
import org.pytorch.executorch.EValue
77
import org.pytorch.executorch.Tensor
88

9-
class WhisperDecoder(reactApplicationContext: ReactApplicationContext) : BaseS2TDecoder(reactApplicationContext) {
9+
class WhisperDecoder(reactApplicationContext: ReactApplicationContext) :
10+
BaseS2TDecoder(reactApplicationContext) {
1011
private lateinit var generatedTokens: IntArray
1112
override var methodName: String
1213
get() = "forward"

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperEncoder.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
2525
return this.postprocess(hiddenState)
2626
}
2727

28-
override fun preprocess(input: ReadableArray): EValue {
28+
fun preprocess(input: ReadableArray): EValue {
2929
val waveformFloatArray = ArrayUtils.createFloatArray(input)
3030

3131
val stftResult = this.stft.fromWaveform(waveformFloatArray)
@@ -34,7 +34,7 @@ class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
3434
return EValue.from(inputTensor)
3535
}
3636

37-
public override fun postprocess(output: Array<EValue>): WritableArray {
37+
public fun postprocess(output: Array<EValue>): WritableArray {
3838
val outputWritableArray: WritableArray = Arguments.createArray()
3939

4040
output[0].toTensor().dataAsFloatArray.map {

0 commit comments

Comments
 (0)