feat: implemented vertical ocr #109

Merged: 12 commits, Mar 4, 2025
android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

@@ -25,10 +25,11 @@ class RnExecutorchPackage : TurboReactPackage() {
ObjectDetection(reactContext)
} else if (name == SpeechToText.NAME) {
SpeechToText(reactContext)
- } else if (name == OCR.NAME){
+ } else if (name == OCR.NAME) {
OCR(reactContext)
- }
- else {
+ } else if (name == VerticalOCR.NAME) {
+ VerticalOCR(reactContext)
+ } else {
null
}

@@ -44,54 +45,49 @@ class RnExecutorchPackage : TurboReactPackage() {
true,
)
moduleInfos[ETModule.NAME] = ReactModuleInfo(
- ETModule.NAME,
- ETModule.NAME,
- false, // canOverrideExistingModule
+ ETModule.NAME, ETModule.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

moduleInfos[StyleTransfer.NAME] = ReactModuleInfo(
- StyleTransfer.NAME,
- StyleTransfer.NAME,
- false, // canOverrideExistingModule
+ StyleTransfer.NAME, StyleTransfer.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

moduleInfos[Classification.NAME] = ReactModuleInfo(
- Classification.NAME,
- Classification.NAME,
- false, // canOverrideExistingModule
+ Classification.NAME, Classification.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
- ObjectDetection.NAME,
- ObjectDetection.NAME,
- false, // canOverrideExistingModule
+ ObjectDetection.NAME, ObjectDetection.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

moduleInfos[SpeechToText.NAME] = ReactModuleInfo(
- SpeechToText.NAME,
- SpeechToText.NAME,
- false, // canOverrideExistingModule
+ SpeechToText.NAME, SpeechToText.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

moduleInfos[OCR.NAME] = ReactModuleInfo(
- OCR.NAME,
- OCR.NAME,
- false, // canOverrideExistingModule
+ OCR.NAME, OCR.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)

+ moduleInfos[VerticalOCR.NAME] = ReactModuleInfo(
+ VerticalOCR.NAME, VerticalOCR.NAME, false, // canOverrideExistingModule
+ false, // needsEagerInit
+ false, // isCxxModule
+ true
android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt (new file, 173 additions, 0 deletions)
@@ -0,0 +1,173 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.ocr.Recognizer
import com.swmansion.rnexecutorch.models.ocr.VerticalDetector
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import org.opencv.core.Core
import org.opencv.core.Mat

class VerticalOCR(reactContext: ReactApplicationContext) :
NativeVerticalOCRSpec(reactContext) {

private lateinit var detectorLarge: VerticalDetector
private lateinit var detectorNarrow: VerticalDetector
private lateinit var recognizer: Recognizer
private lateinit var converter: CTCLabelConverter
private var independentCharacters = true

companion object {
const val NAME = "VerticalOCR"
}

init {
if (!OpenCVLoader.initLocal()) {
Log.d("rn_executorch", "OpenCV not loaded")
} else {
Log.d("rn_executorch", "OpenCV loaded")
}
}

override fun loadModule(
detectorLargeSource: String,
detectorNarrowSource: String,
recognizerSource: String,
symbols: String,
independentCharacters: Boolean,
promise: Promise
) {
try {
this.independentCharacters = independentCharacters
detectorLarge = VerticalDetector(false, reactApplicationContext)
detectorLarge.loadModel(detectorLargeSource)
detectorNarrow = VerticalDetector(true, reactApplicationContext)
detectorNarrow.loadModel(detectorNarrowSource)
recognizer = Recognizer(reactApplicationContext)
recognizer.loadModel(recognizerSource)

converter = CTCLabelConverter(symbols)

promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

override fun forward(input: String, promise: Promise) {
Member:
I would put some of this logic inside something like the recognitionHandler in the horizontal OCR to make it more maintainable.

Contributor Author:
I was thinking about it; however, the whole OCR already consists of many files, and I felt that adding another file that would only be a wrapper would not help much.

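For context, a minimal sketch of the kind of extraction the reviewer suggests, assuming the same imports and types as VerticalOCR.kt; the helper name recognizeIndependentCharacters and its exact signature are hypothetical and not part of this PR:

// Hypothetical helper, not in this PR: pulls the per-character recognition loop
// out of forward() so the main loop only handles detection and bookkeeping.
private fun recognizeIndependentCharacters(
  characterCrops: List<Mat>,
  recognizer: Recognizer,
  converter: CTCLabelConverter
): Pair<String, Double> {
  var text = ""
  var confidence = 0.0
  for (crop in characterCrops) {
    // Same helpers the PR already uses in forward().
    var prepared = RecognizerUtils.cropSingleCharacter(crop)
    prepared = RecognizerUtils.normalizeForRecognizer(prepared, 0.0, true)
    val recognitionResult = recognizer.runModel(prepared)
    val predIndex = recognitionResult.first
    text += converter.decodeGreedy(predIndex, predIndex.size)[0]
    confidence += recognitionResult.second
  }
  if (characterCrops.isNotEmpty()) confidence /= characterCrops.size
  return text to confidence
}
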
try {
val inputImage = ImageProcessor.readImage(input)
val result = detectorLarge.runModel(inputImage)
val largeDetectorSize = detectorLarge.getModelImageSize()
val resizedImage = ImageProcessor.resizeWithPadding(
inputImage,
largeDetectorSize.width.toInt(),
largeDetectorSize.height.toInt()
)
val predictions = Arguments.createArray()
for (box in result) {
val cords = box.bBox
val boxWidth = cords[2].x - cords[0].x
val boxHeight = cords[2].y - cords[0].y

val boundingBox = RecognizerUtils.extractBoundingBox(cords)
val croppedImage = Mat(resizedImage, boundingBox)

val paddings = RecognizerUtils.calculateResizeRatioAndPaddings(
inputImage.width(),
inputImage.height(),
largeDetectorSize.width.toInt(),
largeDetectorSize.height.toInt()
)

var text = ""
var confidenceScore = 0.0
val boxResult = detectorNarrow.runModel(croppedImage)
val narrowDetectorSize = detectorNarrow.getModelImageSize()

val croppedCharacters = mutableListOf<Mat>()

for (characterBox in boxResult) {
val boxCords = characterBox.bBox
val paddingsBox = RecognizerUtils.calculateResizeRatioAndPaddings(
boxWidth.toInt(),
boxHeight.toInt(),
narrowDetectorSize.width.toInt(),
narrowDetectorSize.height.toInt()
)

var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
inputImage,
boxCords,
cords,
paddingsBox,
paddings
)

if (this.independentCharacters) {
croppedCharacter = RecognizerUtils.cropSingleCharacter(croppedCharacter)
croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true)
val recognitionResult = recognizer.runModel(croppedCharacter)
val predIndex = recognitionResult.first
val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
text += decodedText[0]
confidenceScore += recognitionResult.second
} else {
croppedCharacters.add(croppedCharacter)
}
}

if (this.independentCharacters) {
confidenceScore /= boxResult.size
} else {
var mergedCharacters = Mat()
Core.hconcat(croppedCharacters, mergedCharacters)
mergedCharacters = ImageProcessor.resizeWithPadding(
mergedCharacters,
Constants.LARGE_MODEL_WIDTH,
Constants.MODEL_HEIGHT
)
mergedCharacters = RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0)

val recognitionResult = recognizer.runModel(mergedCharacters)
val predIndex = recognitionResult.first
val decodedText = converter.decodeGreedy(predIndex, predIndex.size)

text = decodedText[0]
confidenceScore = recognitionResult.second
}

for (bBox in box.bBox) {
bBox.x =
(bBox.x - paddings["left"] as Int) * paddings["resizeRatio"] as Float
bBox.y =
(bBox.y - paddings["top"] as Int) * paddings["resizeRatio"] as Float
}

val resMap = Arguments.createMap()

resMap.putString("text", text)
resMap.putArray("bbox", box.toWritableArray())
resMap.putDouble("confidence", confidenceScore)

predictions.pushMap(resMap)
}

promise.resolve(predictions)
} catch (e: Exception) {
Log.d("rn_executorch", "Error running model: ${e.message}")
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String {
return NAME
}
}
android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
@@ -1,25 +1,24 @@
package com.swmansion.rnexecutorch.models.ocr

import android.util.Log
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Scalar
import org.opencv.core.Size
import org.pytorch.executorch.EValue

- class Detector(reactApplicationContext: ReactApplicationContext) :
- BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
+ class Detector(
+ reactApplicationContext: ReactApplicationContext
+ ) : BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
private lateinit var originalSize: Size

fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
- val width = inputShape[inputShape.lastIndex]
- val height = inputShape[inputShape.lastIndex - 1]
+ val width = inputShape[inputShape.lastIndex - 1]
+ val height = inputShape[inputShape.lastIndex]

val modelImageSize = Size(height.toDouble(), width.toDouble())

@@ -29,16 +28,11 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
override fun preprocess(input: Mat): EValue {
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
val resizedImage = ImageProcessor.resizeWithPadding(
- input,
- getModelImageSize().width.toInt(),
- getModelImageSize().height.toInt()
+ input, getModelImageSize().width.toInt(), getModelImageSize().height.toInt()
)

return ImageProcessor.matToEValue(
- resizedImage,
- module.getInputShape(0),
- Constants.MEAN,
- Constants.VARIANCE
+ resizedImage, module.getInputShape(0), Constants.MEAN, Constants.VARIANCE
)
}

@@ -48,8 +42,7 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
val modelImageSize = getModelImageSize()

val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
- outputArray,
- Size(modelImageSize.width / 2, modelImageSize.height / 2)
+ outputArray, Size(modelImageSize.width / 2, modelImageSize.height / 2)
)
var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
scoreText,
@@ -58,8 +51,10 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
Constants.LINK_THRESHOLD,
Constants.LOW_TEXT_THRESHOLD
)

bBoxesList =
DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())

bBoxesList = DetectorUtils.groupTextBoxes(
bBoxesList,
Constants.CENTER_THRESHOLD,