Skip to content

Commit 59e2fbf

Browse files
committed
Remove abstract pre/postprocess from BaseModel
1 parent ad69dc5 commit 59e2fbf

File tree

10 files changed

+18
-22
lines changed

10 files changed

+18
-22
lines changed

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

-4
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,4 @@ abstract class BaseModel<Input, Output>(
5252
}
5353

5454
abstract fun runModel(input: Input): Output
55-
56-
protected abstract fun preprocess(input: Input): EValue
57-
58-
protected abstract fun postprocess(output: Array<EValue>): Output
5955
}

android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class StyleTransferModel(
2020
return Size(height.toDouble(), width.toDouble())
2121
}
2222

23-
override fun preprocess(input: Mat): EValue {
23+
fun preprocess(input: Mat): EValue {
2424
originalSize = input.size()
2525
Imgproc.resize(input, input, getModelImageSize())
2626
return ImageProcessor.matToEValue(input, module.getInputShape(0))
2727
}
2828

29-
override fun postprocess(output: Array<EValue>): Mat {
29+
fun postprocess(output: Array<EValue>): Mat {
3030
val tensor = output[0].toTensor()
3131
val modelShape = getModelImageSize()
3232
val result = ImageProcessor.eValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())

android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ class ClassificationModel(
1919
return Size(height.toDouble(), width.toDouble())
2020
}
2121

22-
override fun preprocess(input: Mat): EValue {
22+
fun preprocess(input: Mat): EValue {
2323
Imgproc.resize(input, input, getModelImageSize())
2424
return ImageProcessor.matToEValue(input, module.getInputShape(0))
2525
}
2626

27-
override fun postprocess(output: Array<EValue>): Map<String, Float> {
27+
fun postprocess(output: Array<EValue>): Map<String, Float> {
2828
val tensor = output[0].toTensor()
2929
val probabilities = softmax(tensor.dataAsFloatArray.toTypedArray())
3030

android/src/main/java/com/swmansion/rnexecutorch/models/objectDetection/SSDLiteLargeModel.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class SSDLiteLargeModel(
2929
return Size(height.toDouble(), width.toDouble())
3030
}
3131

32-
override fun preprocess(input: Mat): EValue {
32+
fun preprocess(input: Mat): EValue {
3333
this.widthRatio = (input.size().width / getModelImageSize().width).toFloat()
3434
this.heightRatio = (input.size().height / getModelImageSize().height).toFloat()
3535
Imgproc.resize(input, input, getModelImageSize())
@@ -42,7 +42,7 @@ class SSDLiteLargeModel(
4242
return postprocess(modelOutput)
4343
}
4444

45-
override fun postprocess(output: Array<EValue>): Array<Detection> {
45+
fun postprocess(output: Array<EValue>): Array<Detection> {
4646
val scoresTensor = output[1].toTensor()
4747
val numel = scoresTensor.numel()
4848
val bboxes = output[0].toTensor().dataAsFloatArray

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Detector(
2525
return modelImageSize
2626
}
2727

28-
override fun preprocess(input: Mat): EValue {
28+
fun preprocess(input: Mat): EValue {
2929
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
3030
val resizedImage =
3131
ImageProcessor.resizeWithPadding(
@@ -42,7 +42,7 @@ class Detector(
4242
)
4343
}
4444

45-
override fun postprocess(output: Array<EValue>): List<OCRbBox> {
45+
fun postprocess(output: Array<EValue>): List<OCRbBox> {
4646
val outputTensor = output[0].toTensor()
4747
val outputArray = outputTensor.dataAsFloatArray
4848
val modelImageSize = getModelImageSize()

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class Recognizer(
1919
return Size(height.toDouble(), width.toDouble())
2020
}
2121

22-
override fun preprocess(input: Mat): EValue = ImageProcessor.matToEValueGray(input)
22+
fun preprocess(input: Mat): EValue = ImageProcessor.matToEValueGray(input)
2323

24-
override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
24+
fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
2525
val modelOutputHeight = getModelOutputSize().height.toInt()
2626
val tensor = output[0].toTensor().dataAsFloatArray
2727
val numElements = tensor.size

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class VerticalDetector(
2626
return modelImageSize
2727
}
2828

29-
override fun preprocess(input: Mat): EValue {
29+
fun preprocess(input: Mat): EValue {
3030
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
3131
val resizedImage =
3232
ImageProcessor.resizeWithPadding(
@@ -43,7 +43,7 @@ class VerticalDetector(
4343
)
4444
}
4545

46-
override fun postprocess(output: Array<EValue>): List<OCRbBox> {
46+
fun postprocess(output: Array<EValue>): List<OCRbBox> {
4747
val outputTensor = output[0].toTensor()
4848
val outputArray = outputTensor.dataAsFloatArray
4949
val modelImageSize = getModelImageSize()

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TDecoder.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ abstract class BaseS2TDecoder(
2828

2929
abstract fun getInputShape(inputLength: Int): LongArray
3030

31-
override fun preprocess(input: ReadableArray): EValue {
31+
fun preprocess(input: ReadableArray): EValue {
3232
val inputArray = input.getArray(0)!!
3333
val preprocessorInputShape = this.getInputShape(inputArray.size())
3434
return EValue.from(Tensor.fromBlob(createFloatArray(inputArray), preprocessorInputShape))
3535
}
3636

37-
override fun postprocess(output: Array<EValue>): Int {
37+
fun postprocess(output: Array<EValue>): Int {
3838
TODO("Not yet implemented")
3939
}
4040
}

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/MoonshineEncoder.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ class MoonshineEncoder(
1414
) : BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
1515
override fun runModel(input: ReadableArray): WritableArray = this.postprocess(this.module.forward(this.preprocess(input)))
1616

17-
override fun preprocess(input: ReadableArray): EValue {
17+
fun preprocess(input: ReadableArray): EValue {
1818
val size = input.size()
1919
val preprocessorInputShape = longArrayOf(1, size.toLong())
2020
return EValue.from(Tensor.fromBlob(createFloatArray(input), preprocessorInputShape))
2121
}
2222

23-
public override fun postprocess(output: Array<EValue>): WritableArray {
23+
public fun postprocess(output: Array<EValue>): WritableArray {
2424
val outputWritableArray: WritableArray = Arguments.createArray()
2525
output[0].toTensor().dataAsFloatArray.map {
2626
outputWritableArray.pushDouble(

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperEncoder.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class WhisperEncoder(
2424
return this.postprocess(hiddenState)
2525
}
2626

27-
override fun preprocess(input: ReadableArray): EValue {
27+
fun preprocess(input: ReadableArray): EValue {
2828
val waveformFloatArray = ArrayUtils.createFloatArray(input)
2929

3030
val stftResult = this.stft.fromWaveform(waveformFloatArray)
@@ -33,7 +33,7 @@ class WhisperEncoder(
3333
return EValue.from(inputTensor)
3434
}
3535

36-
public override fun postprocess(output: Array<EValue>): WritableArray {
36+
public fun postprocess(output: Array<EValue>): WritableArray {
3737
val outputWritableArray: WritableArray = Arguments.createArray()
3838

3939
output[0].toTensor().dataAsFloatArray.map {

0 commit comments

Comments
 (0)