Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove abstract pre/postprocess from BaseModel #123

Merged
merged 2 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,4 @@ abstract class BaseModel<Input, Output>(
}

abstract fun runModel(input: Input): Output

protected abstract fun preprocess(input: Input): EValue

protected abstract fun postprocess(output: Array<EValue>): Output
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class StyleTransferModel(
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
fun preprocess(input: Mat): EValue {
originalSize = input.size()
Imgproc.resize(input, input, getModelImageSize())
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(output: Array<EValue>): Mat {
fun postprocess(output: Array<EValue>): Mat {
val tensor = output[0].toTensor()
val modelShape = getModelImageSize()
val result = ImageProcessor.eValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ class ClassificationModel(
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
fun preprocess(input: Mat): EValue {
Imgproc.resize(input, input, getModelImageSize())
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(output: Array<EValue>): Map<String, Float> {
fun postprocess(output: Array<EValue>): Map<String, Float> {
val tensor = output[0].toTensor()
val probabilities = softmax(tensor.dataAsFloatArray.toTypedArray())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SSDLiteLargeModel(
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
fun preprocess(input: Mat): EValue {
this.widthRatio = (input.size().width / getModelImageSize().width).toFloat()
this.heightRatio = (input.size().height / getModelImageSize().height).toFloat()
Imgproc.resize(input, input, getModelImageSize())
Expand All @@ -42,7 +42,7 @@ class SSDLiteLargeModel(
return postprocess(modelOutput)
}

override fun postprocess(output: Array<EValue>): Array<Detection> {
fun postprocess(output: Array<EValue>): Array<Detection> {
val scoresTensor = output[1].toTensor()
val numel = scoresTensor.numel()
val bboxes = output[0].toTensor().dataAsFloatArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Detector(
return modelImageSize
}

override fun preprocess(input: Mat): EValue {
fun preprocess(input: Mat): EValue {
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
val resizedImage =
ImageProcessor.resizeWithPadding(
Expand All @@ -42,7 +42,7 @@ class Detector(
)
}

override fun postprocess(output: Array<EValue>): List<OCRbBox> {
fun postprocess(output: Array<EValue>): List<OCRbBox> {
val outputTensor = output[0].toTensor()
val outputArray = outputTensor.dataAsFloatArray
val modelImageSize = getModelImageSize()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class Recognizer(
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue = ImageProcessor.matToEValueGray(input)
fun preprocess(input: Mat): EValue = ImageProcessor.matToEValueGray(input)

override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
val modelOutputHeight = getModelOutputSize().height.toInt()
val tensor = output[0].toTensor().dataAsFloatArray
val numElements = tensor.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class VerticalDetector(
return modelImageSize
}

override fun preprocess(input: Mat): EValue {
fun preprocess(input: Mat): EValue {
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
val resizedImage =
ImageProcessor.resizeWithPadding(
Expand All @@ -43,7 +43,7 @@ class VerticalDetector(
)
}

override fun postprocess(output: Array<EValue>): List<OCRbBox> {
fun postprocess(output: Array<EValue>): List<OCRbBox> {
val outputTensor = output[0].toTensor()
val outputArray = outputTensor.dataAsFloatArray
val modelImageSize = getModelImageSize()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,9 @@ abstract class BaseS2TDecoder(

abstract fun getInputShape(inputLength: Int): LongArray

override fun preprocess(input: ReadableArray): EValue {
fun preprocess(input: ReadableArray): EValue {
val inputArray = input.getArray(0)!!
val preprocessorInputShape = this.getInputShape(inputArray.size())
return EValue.from(Tensor.fromBlob(createFloatArray(inputArray), preprocessorInputShape))
}

override fun postprocess(output: Array<EValue>): Int {
TODO("Not yet implemented")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ class MoonshineEncoder(
) : BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
override fun runModel(input: ReadableArray): WritableArray = this.postprocess(this.module.forward(this.preprocess(input)))

override fun preprocess(input: ReadableArray): EValue {
fun preprocess(input: ReadableArray): EValue {
val size = input.size()
val preprocessorInputShape = longArrayOf(1, size.toLong())
return EValue.from(Tensor.fromBlob(createFloatArray(input), preprocessorInputShape))
}

public override fun postprocess(output: Array<EValue>): WritableArray {
fun postprocess(output: Array<EValue>): WritableArray {
val outputWritableArray: WritableArray = Arguments.createArray()
output[0].toTensor().dataAsFloatArray.map {
outputWritableArray.pushDouble(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class WhisperEncoder(
return this.postprocess(hiddenState)
}

override fun preprocess(input: ReadableArray): EValue {
fun preprocess(input: ReadableArray): EValue {
val waveformFloatArray = ArrayUtils.createFloatArray(input)

val stftResult = this.stft.fromWaveform(waveformFloatArray)
Expand All @@ -33,7 +33,7 @@ class WhisperEncoder(
return EValue.from(inputTensor)
}

public override fun postprocess(output: Array<EValue>): WritableArray {
fun postprocess(output: Array<EValue>): WritableArray {
val outputWritableArray: WritableArray = Arguments.createArray()

output[0].toTensor().dataAsFloatArray.map {
Expand Down