Skip to content

Commit 765305a

Browse files
feat: implemented vertical ocr (#109)
## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent f1bfd5e commit 765305a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1865
-216
lines changed

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+17-21
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ class RnExecutorchPackage : TurboReactPackage() {
2525
ObjectDetection(reactContext)
2626
} else if (name == SpeechToText.NAME) {
2727
SpeechToText(reactContext)
28-
} else if (name == OCR.NAME){
28+
} else if (name == OCR.NAME) {
2929
OCR(reactContext)
30-
}
31-
else {
30+
} else if (name == VerticalOCR.NAME) {
31+
VerticalOCR(reactContext)
32+
} else {
3233
null
3334
}
3435

@@ -44,54 +45,49 @@ class RnExecutorchPackage : TurboReactPackage() {
4445
true,
4546
)
4647
moduleInfos[ETModule.NAME] = ReactModuleInfo(
47-
ETModule.NAME,
48-
ETModule.NAME,
49-
false, // canOverrideExistingModule
48+
ETModule.NAME, ETModule.NAME, false, // canOverrideExistingModule
5049
false, // needsEagerInit
5150
false, // isCxxModule
5251
true
5352
)
5453

5554
moduleInfos[StyleTransfer.NAME] = ReactModuleInfo(
56-
StyleTransfer.NAME,
57-
StyleTransfer.NAME,
58-
false, // canOverrideExistingModule
55+
StyleTransfer.NAME, StyleTransfer.NAME, false, // canOverrideExistingModule
5956
false, // needsEagerInit
6057
false, // isCxxModule
6158
true
6259
)
6360

6461
moduleInfos[Classification.NAME] = ReactModuleInfo(
65-
Classification.NAME,
66-
Classification.NAME,
67-
false, // canOverrideExistingModule
62+
Classification.NAME, Classification.NAME, false, // canOverrideExistingModule
6863
false, // needsEagerInit
6964
false, // isCxxModule
7065
true
7166
)
7267

7368
moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
74-
ObjectDetection.NAME,
75-
ObjectDetection.NAME,
76-
false, // canOverrideExistingModule
69+
ObjectDetection.NAME, ObjectDetection.NAME, false, // canOverrideExistingModule
7770
false, // needsEagerInit
7871
false, // isCxxModule
7972
true
8073
)
8174

8275
moduleInfos[SpeechToText.NAME] = ReactModuleInfo(
83-
SpeechToText.NAME,
84-
SpeechToText.NAME,
85-
false, // canOverrideExistingModule
76+
SpeechToText.NAME, SpeechToText.NAME, false, // canOverrideExistingModule
8677
false, // needsEagerInit
8778
false, // isCxxModule
8879
true
8980
)
9081

9182
moduleInfos[OCR.NAME] = ReactModuleInfo(
92-
OCR.NAME,
93-
OCR.NAME,
94-
false, // canOverrideExistingModule
83+
OCR.NAME, OCR.NAME, false, // canOverrideExistingModule
84+
false, // needsEagerInit
85+
false, // isCxxModule
86+
true
87+
)
88+
89+
moduleInfos[VerticalOCR.NAME] = ReactModuleInfo(
90+
VerticalOCR.NAME, VerticalOCR.NAME, false, // canOverrideExistingModule
9591
false, // needsEagerInit
9692
false, // isCxxModule
9793
true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.Arguments
5+
import com.facebook.react.bridge.Promise
6+
import com.facebook.react.bridge.ReactApplicationContext
7+
import com.swmansion.rnexecutorch.utils.ETError
8+
import com.swmansion.rnexecutorch.utils.ImageProcessor
9+
import org.opencv.android.OpenCVLoader
10+
import com.swmansion.rnexecutorch.models.ocr.Recognizer
11+
import com.swmansion.rnexecutorch.models.ocr.VerticalDetector
12+
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
13+
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
14+
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
15+
import org.opencv.core.Core
16+
import org.opencv.core.Mat
17+
18+
class VerticalOCR(reactContext: ReactApplicationContext) :
19+
NativeVerticalOCRSpec(reactContext) {
20+
21+
private lateinit var detectorLarge: VerticalDetector
22+
private lateinit var detectorNarrow: VerticalDetector
23+
private lateinit var recognizer: Recognizer
24+
private lateinit var converter: CTCLabelConverter
25+
private var independentCharacters = true
26+
27+
companion object {
28+
const val NAME = "VerticalOCR"
29+
}
30+
31+
init {
32+
if (!OpenCVLoader.initLocal()) {
33+
Log.d("rn_executorch", "OpenCV not loaded")
34+
} else {
35+
Log.d("rn_executorch", "OpenCV loaded")
36+
}
37+
}
38+
39+
override fun loadModule(
40+
detectorLargeSource: String,
41+
detectorNarrowSource: String,
42+
recognizerSource: String,
43+
symbols: String,
44+
independentCharacters: Boolean,
45+
promise: Promise
46+
) {
47+
try {
48+
this.independentCharacters = independentCharacters
49+
detectorLarge = VerticalDetector(false, reactApplicationContext)
50+
detectorLarge.loadModel(detectorLargeSource)
51+
detectorNarrow = VerticalDetector(true, reactApplicationContext)
52+
detectorNarrow.loadModel(detectorNarrowSource)
53+
recognizer = Recognizer(reactApplicationContext)
54+
recognizer.loadModel(recognizerSource)
55+
56+
converter = CTCLabelConverter(symbols)
57+
58+
promise.resolve(0)
59+
} catch (e: Exception) {
60+
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
61+
}
62+
}
63+
64+
override fun forward(input: String, promise: Promise) {
65+
try {
66+
val inputImage = ImageProcessor.readImage(input)
67+
val result = detectorLarge.runModel(inputImage)
68+
val largeDetectorSize = detectorLarge.getModelImageSize()
69+
val resizedImage = ImageProcessor.resizeWithPadding(
70+
inputImage,
71+
largeDetectorSize.width.toInt(),
72+
largeDetectorSize.height.toInt()
73+
)
74+
val predictions = Arguments.createArray()
75+
for (box in result) {
76+
val cords = box.bBox
77+
val boxWidth = cords[2].x - cords[0].x
78+
val boxHeight = cords[2].y - cords[0].y
79+
80+
val boundingBox = RecognizerUtils.extractBoundingBox(cords)
81+
val croppedImage = Mat(resizedImage, boundingBox)
82+
83+
val paddings = RecognizerUtils.calculateResizeRatioAndPaddings(
84+
inputImage.width(),
85+
inputImage.height(),
86+
largeDetectorSize.width.toInt(),
87+
largeDetectorSize.height.toInt()
88+
)
89+
90+
var text = ""
91+
var confidenceScore = 0.0
92+
val boxResult = detectorNarrow.runModel(croppedImage)
93+
val narrowDetectorSize = detectorNarrow.getModelImageSize()
94+
95+
val croppedCharacters = mutableListOf<Mat>()
96+
97+
for (characterBox in boxResult) {
98+
val boxCords = characterBox.bBox
99+
val paddingsBox = RecognizerUtils.calculateResizeRatioAndPaddings(
100+
boxWidth.toInt(),
101+
boxHeight.toInt(),
102+
narrowDetectorSize.width.toInt(),
103+
narrowDetectorSize.height.toInt()
104+
)
105+
106+
var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
107+
inputImage,
108+
boxCords,
109+
cords,
110+
paddingsBox,
111+
paddings
112+
)
113+
114+
if (this.independentCharacters) {
115+
croppedCharacter = RecognizerUtils.cropSingleCharacter(croppedCharacter)
116+
croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true)
117+
val recognitionResult = recognizer.runModel(croppedCharacter)
118+
val predIndex = recognitionResult.first
119+
val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
120+
text += decodedText[0]
121+
confidenceScore += recognitionResult.second
122+
} else {
123+
croppedCharacters.add(croppedCharacter)
124+
}
125+
}
126+
127+
if (this.independentCharacters) {
128+
confidenceScore /= boxResult.size
129+
} else {
130+
var mergedCharacters = Mat()
131+
Core.hconcat(croppedCharacters, mergedCharacters)
132+
mergedCharacters = ImageProcessor.resizeWithPadding(
133+
mergedCharacters,
134+
Constants.LARGE_MODEL_WIDTH,
135+
Constants.MODEL_HEIGHT
136+
)
137+
mergedCharacters = RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0)
138+
139+
val recognitionResult = recognizer.runModel(mergedCharacters)
140+
val predIndex = recognitionResult.first
141+
val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
142+
143+
text = decodedText[0]
144+
confidenceScore = recognitionResult.second
145+
}
146+
147+
for (bBox in box.bBox) {
148+
bBox.x =
149+
(bBox.x - paddings["left"] as Int) * paddings["resizeRatio"] as Float
150+
bBox.y =
151+
(bBox.y - paddings["top"] as Int) * paddings["resizeRatio"] as Float
152+
}
153+
154+
val resMap = Arguments.createMap()
155+
156+
resMap.putString("text", text)
157+
resMap.putArray("bbox", box.toWritableArray())
158+
resMap.putDouble("confidence", confidenceScore)
159+
160+
predictions.pushMap(resMap)
161+
}
162+
163+
promise.resolve(predictions)
164+
} catch (e: Exception) {
165+
Log.d("rn_executorch", "Error running model: ${e.message}")
166+
promise.reject(e.message!!, e.message)
167+
}
168+
}
169+
170+
override fun getName(): String {
171+
return NAME
172+
}
173+
}

android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt

+10-15
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
package com.swmansion.rnexecutorch.models.ocr
22

3-
import android.util.Log
43
import com.facebook.react.bridge.ReactApplicationContext
54
import com.swmansion.rnexecutorch.models.BaseModel
65
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
76
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
87
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
98
import com.swmansion.rnexecutorch.utils.ImageProcessor
109
import org.opencv.core.Mat
11-
import org.opencv.core.Scalar
1210
import org.opencv.core.Size
1311
import org.pytorch.executorch.EValue
1412

15-
class Detector(reactApplicationContext: ReactApplicationContext) :
16-
BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
13+
class Detector(
14+
reactApplicationContext: ReactApplicationContext
15+
) : BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
1716
private lateinit var originalSize: Size
1817

1918
fun getModelImageSize(): Size {
2019
val inputShape = module.getInputShape(0)
21-
val width = inputShape[inputShape.lastIndex]
22-
val height = inputShape[inputShape.lastIndex - 1]
20+
val width = inputShape[inputShape.lastIndex - 1]
21+
val height = inputShape[inputShape.lastIndex]
2322

2423
val modelImageSize = Size(height.toDouble(), width.toDouble())
2524

@@ -29,16 +28,11 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
2928
override fun preprocess(input: Mat): EValue {
3029
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
3130
val resizedImage = ImageProcessor.resizeWithPadding(
32-
input,
33-
getModelImageSize().width.toInt(),
34-
getModelImageSize().height.toInt()
31+
input, getModelImageSize().width.toInt(), getModelImageSize().height.toInt()
3532
)
3633

3734
return ImageProcessor.matToEValue(
38-
resizedImage,
39-
module.getInputShape(0),
40-
Constants.MEAN,
41-
Constants.VARIANCE
35+
resizedImage, module.getInputShape(0), Constants.MEAN, Constants.VARIANCE
4236
)
4337
}
4438

@@ -48,8 +42,7 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
4842
val modelImageSize = getModelImageSize()
4943

5044
val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
51-
outputArray,
52-
Size(modelImageSize.width / 2, modelImageSize.height / 2)
45+
outputArray, Size(modelImageSize.width / 2, modelImageSize.height / 2)
5346
)
5447
var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
5548
scoreText,
@@ -58,8 +51,10 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
5851
Constants.LINK_THRESHOLD,
5952
Constants.LOW_TEXT_THRESHOLD
6053
)
54+
6155
bBoxesList =
6256
DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
57+
6358
bBoxesList = DetectorUtils.groupTextBoxes(
6459
bBoxesList,
6560
Constants.CENTER_THRESHOLD,

0 commit comments

Comments
 (0)