Skip to content

Commit deffbe5

Browse files
feat: ocr(ios) (#84)
## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 4bc1986 commit deffbe5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3365
-20
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.Promise
5+
import com.facebook.react.bridge.ReactApplicationContext
6+
import com.swmansion.rnexecutorch.utils.ETError
7+
import com.swmansion.rnexecutorch.utils.ImageProcessor
8+
import org.opencv.android.OpenCVLoader
9+
import com.swmansion.rnexecutorch.models.ocr.Detector
10+
import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler
11+
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
12+
import org.opencv.imgproc.Imgproc
13+
14+
/**
 * React Native module exposing the OCR pipeline: a [Detector] that finds text boxes and a
 * [RecognitionHandler] that reads the text inside them.
 *
 * [loadModule] has to succeed before [forward] is used; calling [forward] earlier makes the
 * access to the uninitialised `lateinit` properties throw, which is reported as a promise rejection.
 */
class OCR(reactContext: ReactApplicationContext) :
  NativeOCRSpec(reactContext) {

  private lateinit var detector: Detector
  private lateinit var recognitionHandler: RecognitionHandler

  companion object {
    const val NAME = "OCR"
  }

  init {
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  /**
   * Loads the detector and the three recognizers, then settles [promise].
   *
   * Resolves with `0` on success. Any failure (detector or recognizers) rejects the promise with
   * the failure message and [ETError.InvalidModelSource], in the same argument order the
   * original code used.
   *
   * @param detectorSource source of the detector model
   * @param recognizerSourceLarge source of the large recognizer model
   * @param recognizerSourceMedium source of the medium recognizer model
   * @param recognizerSourceSmall source of the small recognizer model
   * @param symbols symbol set handed to [RecognitionHandler]
   * @param promise promise settled once loading finished or failed
   */
  override fun loadModule(
    detectorSource: String,
    recognizerSourceLarge: String,
    recognizerSourceMedium: String,
    recognizerSourceSmall: String,
    symbols: String,
    promise: Promise
  ) {
    try {
      detector = Detector(reactApplicationContext)
      detector.loadModel(detectorSource)

      recognitionHandler = RecognitionHandler(
        symbols,
        reactApplicationContext
      )

      recognitionHandler.loadRecognizers(
        recognizerSourceLarge,
        recognizerSourceMedium,
        recognizerSourceSmall
      ) { _, errorRecognizer ->
        if (errorRecognizer != null) {
          // Reject here instead of throwing: a thrown `Error` is not an `Exception`, so the
          // catch block below would not see it and the failure would escape the module call.
          promise.reject(
            errorRecognizer.message ?: errorRecognizer.toString(),
            ETError.InvalidModelSource.toString()
          )
        } else {
          promise.resolve(0)
        }
      }
    } catch (e: Exception) {
      // `message` may be null; fall back to toString() so the handler itself cannot throw.
      promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString())
    }
  }

  /**
   * Runs detection and recognition on the image identified by [input] and resolves [promise]
   * with the recognition result; any exception rejects the promise with its message.
   *
   * @param input image source understood by [ImageProcessor.readImage]
   * @param promise promise settled with the result or the failure
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val bBoxesList = detector.runModel(inputImage)
      val detectorSize = detector.getModelImageSize()
      // The detector has already consumed the colour image; convert it in place to the
      // grayscale image the recognition step takes.
      Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY)
      val result = recognitionHandler.recognize(
        bBoxesList,
        inputImage,
        (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(),
        (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt()
      )
      promise.resolve(result)
    } catch (e: Exception) {
      // `message` may be null; fall back to toString() so the promise is always settled.
      val message = e.message ?: e.toString()
      Log.d("rn_executorch", "Error running model: $message")
      promise.reject(message, message)
    }
  }

  override fun getName(): String {
    return NAME
  }
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class RnExecutorchPackage : TurboReactPackage() {
2525
ObjectDetection(reactContext)
2626
} else if (name == SpeechToText.NAME) {
2727
SpeechToText(reactContext)
28+
} else if (name == OCR.NAME){
29+
OCR(reactContext)
2830
}
2931
else {
3032
null
@@ -85,6 +87,15 @@ class RnExecutorchPackage : TurboReactPackage() {
8587
false, // isCxxModule
8688
true
8789
)
90+
91+
moduleInfos[OCR.NAME] = ReactModuleInfo(
92+
OCR.NAME,
93+
OCR.NAME,
94+
false, // canOverrideExistingModule
95+
false, // needsEagerInit
96+
false, // isCxxModule
97+
true
98+
)
8899
moduleInfos
89100
}
90101
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.swmansion.rnexecutorch.models.ocr
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.swmansion.rnexecutorch.models.BaseModel
6+
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
7+
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
8+
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
9+
import com.swmansion.rnexecutorch.utils.ImageProcessor
10+
import org.opencv.core.Mat
11+
import org.opencv.core.Scalar
12+
import org.opencv.core.Size
13+
import org.pytorch.executorch.EValue
14+
15+
/**
 * OCR detection model: turns an image [Mat] into a list of [OCRbBox] by running the loaded
 * module and post-processing its score maps with [DetectorUtils].
 */
class Detector(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {

  /**
   * Returns a [Size] built from the last two dimensions of the module's first input shape.
   *
   * NOTE(review): OpenCV's constructor is `Size(width, height)`, but the arguments are passed
   * here as `(height, width)`, so the returned `width` holds the second-to-last dimension and
   * `height` the last one. This only matters for non-square models. The order is preserved
   * because callers and helpers may rely on it; confirm before changing.
   */
  fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    return Size(height.toDouble(), width.toDouble())
  }

  /**
   * Resizes [input] with padding to the model image size and converts it to an [EValue]
   * using [Constants.MEAN] and [Constants.VARIANCE].
   */
  override fun preprocess(input: Mat): EValue {
    // Query the model size once instead of once per dimension.
    val modelImageSize = getModelImageSize()
    val resizedImage = ImageProcessor.resizeWithPadding(
      input,
      modelImageSize.width.toInt(),
      modelImageSize.height.toInt()
    )

    return ImageProcessor.matToEValue(
      resizedImage,
      module.getInputShape(0),
      Constants.MEAN,
      Constants.VARIANCE
    )
  }

  /**
   * Converts the raw model output into grouped text boxes.
   *
   * Steps: split the interleaved output into text and link score maps, extract boxes from
   * them, rescale the boxes by `RECOGNIZER_RATIO * 2`, then group neighbouring boxes.
   */
  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
    val outputArray = output[0].toTensor().dataAsFloatArray
    val modelImageSize = getModelImageSize()

    // The score maps are built at half the model image size (see the `/ 2` below).
    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
      outputArray,
      Size(modelImageSize.width / 2, modelImageSize.height / 2)
    )
    val detectedBoxes = DetectorUtils.getDetBoxesFromTextMap(
      scoreText,
      scoreLink,
      Constants.TEXT_THRESHOLD,
      Constants.LINK_THRESHOLD,
      Constants.LOW_TEXT_THRESHOLD
    )
    val rescaledBoxes =
      DetectorUtils.restoreBoxRatio(detectedBoxes, (Constants.RECOGNIZER_RATIO * 2).toFloat())
    val groupedBoxes = DetectorUtils.groupTextBoxes(
      rescaledBoxes,
      Constants.CENTER_THRESHOLD,
      Constants.DISTANCE_THRESHOLD,
      Constants.HEIGHT_THRESHOLD,
      Constants.MIN_SIDE_THRESHOLD,
      Constants.MAX_SIDE_THRESHOLD,
      Constants.MAX_WIDTH
    )

    return groupedBoxes.toList()
  }

  /** Runs the full pipeline: preprocess, forward pass, postprocess. */
  override fun runModel(input: Mat): List<OCRbBox> {
    return postprocess(forward(preprocess(input)))
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package com.swmansion.rnexecutorch.models.ocr
2+
3+
import com.facebook.react.bridge.Arguments
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.facebook.react.bridge.WritableArray
6+
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
7+
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
8+
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
9+
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
10+
import com.swmansion.rnexecutorch.utils.ImageProcessor
11+
import org.opencv.core.Core
12+
import org.opencv.core.Mat
13+
14+
/**
 * Owns the three [Recognizer] instances (large, medium, small) and the [CTCLabelConverter],
 * and turns detected boxes into recognised text entries.
 *
 * @param symbols symbol set passed to [CTCLabelConverter]
 * @param reactApplicationContext context passed to every [Recognizer]
 */
class RecognitionHandler(
  symbols: String,
  reactApplicationContext: ReactApplicationContext
) {
  private val recognizerLarge = Recognizer(reactApplicationContext)
  private val recognizerMedium = Recognizer(reactApplicationContext)
  private val recognizerSmall = Recognizer(reactApplicationContext)
  private val converter = CTCLabelConverter(symbols)

  /** Picks the recognizer by the crop's width (column count) and runs it. */
  private fun runModel(croppedImage: Mat): Pair<List<Int>, Double> {
    val width = croppedImage.cols()
    return when {
      width >= Constants.LARGE_MODEL_WIDTH -> recognizerLarge.runModel(croppedImage)
      width >= Constants.MEDIUM_MODEL_WIDTH -> recognizerMedium.runModel(croppedImage)
      else -> recognizerSmall.runModel(croppedImage)
    }
  }

  /**
   * Loads the three recognizer models and reports the outcome through [onComplete] exactly once:
   * `(0, null)` on success, `(1, exception)` when loading threw.
   *
   * The callback is invoked outside the `try` so that an exception thrown by the callback itself
   * is not mistaken for a load failure and does not trigger a second callback invocation; such
   * an exception propagates to the caller instead.
   */
  fun loadRecognizers(
    largeRecognizerPath: String,
    mediumRecognizerPath: String,
    smallRecognizerPath: String,
    onComplete: (Int, Exception?) -> Unit
  ) {
    val failure: Exception? = try {
      recognizerLarge.loadModel(largeRecognizerPath)
      recognizerMedium.loadModel(mediumRecognizerPath)
      recognizerSmall.loadModel(smallRecognizerPath)
      null
    } catch (e: Exception) {
      e
    }

    if (failure == null) {
      onComplete(0, null)
    } else {
      onComplete(1, failure)
    }
  }

  /**
   * Recognises the text inside every box of [bBoxesList].
   *
   * [imgGray] is resized with padding to [desiredWidth] x [desiredHeight]; each box is cropped
   * from the resized image, normalised and run through a recognizer. When the confidence is below
   * [Constants.LOW_CONFIDENCE_THRESHOLD] the crop is rotated by 180 degrees and retried, keeping
   * whichever result scores higher. Boxes whose crop is empty are skipped.
   *
   * Note: the corner points of each box in [bBoxesList] are modified in place when they are
   * mapped back using the padding offsets and resize ratio.
   *
   * @return array of maps with the keys `text`, `bbox` and `confidence`
   */
  fun recognize(
    bBoxesList: List<OCRbBox>,
    imgGray: Mat,
    desiredWidth: Int,
    desiredHeight: Int
  ): WritableArray {
    val res: WritableArray = Arguments.createArray()
    val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
      imgGray.width(),
      imgGray.height(),
      desiredWidth,
      desiredHeight
    )

    val left = ratioAndPadding["left"] as Int
    val top = ratioAndPadding["top"] as Int
    val resizeRatio = ratioAndPadding["resizeRatio"] as Float
    val resizedImg = ImageProcessor.resizeWithPadding(
      imgGray,
      desiredWidth,
      desiredHeight
    )

    for (box in bBoxesList) {
      val rawCrop = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT)
      if (rawCrop.empty()) {
        continue
      }

      val croppedImage = RecognizerUtils.normalizeForRecognizer(rawCrop, Constants.ADJUST_CONTRAST)

      var result = runModel(croppedImage)
      var confidenceScore = result.second

      if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) {
        // Low confidence: retry with the crop rotated by 180 degrees and keep the better result.
        Core.rotate(croppedImage, croppedImage, Core.ROTATE_180)
        val rotatedResult = runModel(croppedImage)
        val rotatedConfidenceScore = rotatedResult.second
        if (rotatedConfidenceScore > confidenceScore) {
          result = rotatedResult
          confidenceScore = rotatedConfidenceScore
        }
      }

      val predIndex = result.first
      val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size)

      // Undo the padding offset and the resize so the box is reported with the rescaled coordinates.
      for (bBox in box.bBox) {
        bBox.x = (bBox.x - left) * resizeRatio
        bBox.y = (bBox.y - top) * resizeRatio
      }

      val resMap = Arguments.createMap()

      resMap.putString("text", decodedTexts[0])
      resMap.putArray("bbox", box.toWritableArray())
      resMap.putDouble("confidence", confidenceScore)

      res.pushMap(resMap)
    }

    return res
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package com.swmansion.rnexecutorch.models.ocr
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.swmansion.rnexecutorch.models.BaseModel
5+
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
6+
import com.swmansion.rnexecutorch.utils.ImageProcessor
7+
import org.opencv.core.Mat
8+
import org.opencv.core.Size
9+
import org.pytorch.executorch.EValue
10+
11+
/**
 * OCR recognition model: maps a cropped image [Mat] to a pair of predicted indices and a
 * confidence score.
 */
class Recognizer(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Pair<List<Int>, Double>>(reactApplicationContext) {

  /**
   * Returns a [Size] built from the last two dimensions of the module's first output shape.
   *
   * NOTE(review): the arguments are passed as `(height, width)` while OpenCV's constructor is
   * `Size(width, height)`, so the `height` field of the result holds the LAST output dimension.
   * [postprocess] relies on exactly that, so the order is kept as is.
   */
  private fun getModelOutputSize(): Size {
    val outputShape = module.getOutputShape(0)
    val width = outputShape[outputShape.lastIndex]
    val height = outputShape[outputShape.lastIndex - 1]

    return Size(height.toDouble(), width.toDouble())
  }

  override fun preprocess(input: Mat): EValue {
    return ImageProcessor.matToEValueGray(input)
  }

  /**
   * Reshapes the flat output tensor into a matrix with one row per `modelOutputHeight` values,
   * applies softmax, renormalises the rows, and returns the per-row arg-max indices together
   * with the confidence score computed by [RecognizerUtils.computeConfidenceScore].
   */
  override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
    // Number of values per row; this is the last dimension of the output shape (see above).
    val modelOutputHeight = getModelOutputSize().height.toInt()
    check(modelOutputHeight > 0) { "Recognizer output shape has a non-positive last dimension" }

    val tensor = output[0].toTensor().dataAsFloatArray
    val numRows = (tensor.size + modelOutputHeight - 1) / modelOutputHeight

    // Zero-initialise so a partially filled last row never exposes uninitialised memory, and
    // copy the row-major data with one bulk put instead of one JNI call per element.
    val resultMat = Mat.zeros(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F)
    if (tensor.isNotEmpty()) {
      resultMat.put(0, 0, tensor)
    }

    val softmaxed = RecognizerUtils.softmax(resultMat)
    val predsNorm = RecognizerUtils.sumProbabilityRows(softmaxed, modelOutputHeight)
    val probabilities = RecognizerUtils.divideMatrixByVector(softmaxed, predsNorm)
    val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities)

    val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices)
    return Pair(indices, confidenceScore)
  }

  /** Runs the full pipeline: preprocess, forward pass on the module, postprocess. */
  override fun runModel(input: Mat): Pair<List<Int>, Double> {
    return postprocess(module.forward(preprocess(input)))
  }
}

0 commit comments

Comments
 (0)