From 76b47f15c86734442607a9f23dab8f1799cd5f10 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 26 Feb 2025 17:20:25 +0100 Subject: [PATCH 01/12] feat: implemented vertical ocr --- .../rnexecutorch/RnExecutorchPackage.kt | 38 ++-- .../com/swmansion/rnexecutorch/VerticalOCR.kt | 167 ++++++++++++++++++ .../rnexecutorch/models/ocr/Detector.kt | 54 ++++-- .../models/ocr/utils/Constants.kt | 3 + .../models/ocr/utils/DetectorUtils.kt | 94 ++++++++++ .../models/ocr/utils/RecognizerUtils.kt | 42 ++++- ios/RnExecutorch/VerticalOCR.h | 7 + ios/RnExecutorch/VerticalOCR.mm | 159 +++++++++++++++++ ios/RnExecutorch/models/ocr/Detector.h | 3 + ios/RnExecutorch/models/ocr/Detector.mm | 52 ++++-- .../models/ocr/utils/DetectorUtils.h | 7 +- .../models/ocr/utils/DetectorUtils.mm | 109 +++++++++++- .../models/ocr/utils/RecognizerUtils.h | 1 + .../models/ocr/utils/RecognizerUtils.mm | 36 +++- src/hooks/computer_vision/useVerticalOCR.ts | 132 ++++++++++++++ src/index.tsx | 1 + src/native/NativeVerticalOCR.ts | 16 ++ src/native/RnExecutorchModules.ts | 14 ++ 18 files changed, 885 insertions(+), 50 deletions(-) create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt create mode 100644 ios/RnExecutorch/VerticalOCR.h create mode 100644 ios/RnExecutorch/VerticalOCR.mm create mode 100644 src/hooks/computer_vision/useVerticalOCR.ts create mode 100644 src/native/NativeVerticalOCR.ts diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index 0ec2a51c..64f74661 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -25,10 +25,11 @@ class RnExecutorchPackage : TurboReactPackage() { ObjectDetection(reactContext) } else if (name == SpeechToText.NAME) { SpeechToText(reactContext) - } else if (name == OCR.NAME){ + } else if (name == OCR.NAME) { OCR(reactContext) - } - else { + } else if (name == VerticalOCR.NAME) { + VerticalOCR(reactContext) + } else { null } @@ -44,54 +45,49 @@ class RnExecutorchPackage : TurboReactPackage() { true, ) moduleInfos[ETModule.NAME] = ReactModuleInfo( - ETModule.NAME, - ETModule.NAME, - false, // canOverrideExistingModule + ETModule.NAME, ETModule.NAME, false, // canOverrideExistingModule false, // needsEagerInit false, // isCxxModule true ) moduleInfos[StyleTransfer.NAME] = ReactModuleInfo( - StyleTransfer.NAME, - StyleTransfer.NAME, - false, // canOverrideExistingModule + StyleTransfer.NAME, StyleTransfer.NAME, false, // canOverrideExistingModule false, // needsEagerInit false, // isCxxModule true ) moduleInfos[Classification.NAME] = ReactModuleInfo( - Classification.NAME, - Classification.NAME, - false, // canOverrideExistingModule + Classification.NAME, Classification.NAME, false, // canOverrideExistingModule false, // needsEagerInit false, // isCxxModule true ) moduleInfos[ObjectDetection.NAME] = ReactModuleInfo( - ObjectDetection.NAME, - ObjectDetection.NAME, - false, // canOverrideExistingModule + ObjectDetection.NAME, ObjectDetection.NAME, false, // canOverrideExistingModule false, // needsEagerInit false, // isCxxModule true ) moduleInfos[SpeechToText.NAME] = ReactModuleInfo( - SpeechToText.NAME, - SpeechToText.NAME, - false, // canOverrideExistingModule + SpeechToText.NAME, SpeechToText.NAME, false, // canOverrideExistingModule false, // needsEagerInit false, // isCxxModule true ) moduleInfos[OCR.NAME] = ReactModuleInfo( - OCR.NAME, - 
OCR.NAME,
-      false, // canOverrideExistingModule
+      OCR.NAME, OCR.NAME, false, // canOverrideExistingModule
+      false, // needsEagerInit
+      false, // isCxxModule
+      true
+    )
+
+    moduleInfos[VerticalOCR.NAME] = ReactModuleInfo(
+      VerticalOCR.NAME, VerticalOCR.NAME, false, // canOverrideExistingModule
       false, // needsEagerInit
       false, // isCxxModule
       true
     )
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
new file mode 100644
index 00000000..945a83c0
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
@@ -0,0 +1,167 @@
+package com.swmansion.rnexecutorch
+
+import android.media.Image
+import android.util.Log
+import com.facebook.react.bridge.Arguments
+import com.facebook.react.bridge.Promise
+import com.facebook.react.bridge.ReactApplicationContext
+import com.swmansion.rnexecutorch.utils.ETError
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.android.OpenCVLoader
+import com.swmansion.rnexecutorch.models.ocr.Detector
+import com.swmansion.rnexecutorch.models.ocr.Recognizer
+import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
+import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
+import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
+import org.opencv.core.Core
+import org.opencv.core.Mat
+import org.opencv.core.MatOfPoint
+import org.opencv.core.Point
+import org.opencv.imgproc.Imgproc
+
+class VerticalOCR(reactContext: ReactApplicationContext) :
+  NativeVerticalOCRSpec(reactContext) {
+
+  private lateinit var detectorLarge: Detector
+  private lateinit var detectorNarrow: Detector
+  private lateinit var recognizer: Recognizer
+  private lateinit var converter: CTCLabelConverter
+  private var independentCharacters = true
+
+  companion object {
+    const val NAME = "VerticalOCR"
+  }
+
+  init {
+    if (!OpenCVLoader.initLocal()) {
+      Log.d("rn_executorch", "OpenCV not loaded")
+    } else {
+      Log.d("rn_executorch", "OpenCV loaded")
+    }
+  }
+
+  override fun loadModule(
+    detectorLargeSource: String,
+    detectorNarrowSource: String,
+    recognizerSource: String,
+    symbols: String,
+    independentCharacters: Boolean,
+    promise: Promise
+  ) {
+    try {
+      this.independentCharacters = independentCharacters
+      detectorLarge = Detector(true, false, reactApplicationContext)
+      detectorLarge.loadModel(detectorLargeSource)
+      detectorNarrow = Detector(true, true, reactApplicationContext)
+      detectorNarrow.loadModel(detectorNarrowSource)
+      recognizer = Recognizer(reactApplicationContext)
+      recognizer.loadModel(recognizerSource)
+
+      converter = CTCLabelConverter(symbols)
+
+      promise.resolve(0)
+    } catch (e: Exception) {
+      promise.reject(e.message!!, ETError.InvalidModelSource.toString())
+    }
+  }
+
+  override fun forward(input: String, promise: Promise) {
+    try {
+      val inputImage = ImageProcessor.readImage(input)
+      val result = detectorLarge.runModel(inputImage)
+
+      val resizedImage = ImageProcessor.resizeWithPadding(inputImage, 1280, 1280)
+      val predictions = Arguments.createArray()
+      for (box in result) {
+        val coords = box.bBox
+        val boxWidth = coords[2].x - coords[0].x
+        val boxHeight = coords[2].y - coords[0].y
+        val points = arrayOfNulls<Point>(4)
+
+        for (i in 0 until 4) {
+          points.set(i, Point(coords[i].x, coords[i].y))
+        }
+
+        val boundingBox = Imgproc.boundingRect(MatOfPoint(*points))
+        val croppedImage = Mat(resizedImage, boundingBox)
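+
+        // The large detector ran on a 1280x1280 letterboxed copy of the input,
+        // so keep the resize ratio and padding offsets needed to map detected
+        // boxes back onto the original image later.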
+        val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
+          inputImage.width(),
+          inputImage.height(),
+          1280,
+          1280
+        )
+
+        var text = ""
+        var confidenceScore = 0.0
+        var detectionResult = detectorNarrow.runModel(croppedImage)
+
+        var croppedCharacters = mutableListOf<Mat>()
+
+        for (bbox in detectionResult) {
+          val coords2 = bbox.bBox
+          var paddingsSingle = RecognizerUtils.calculateResizeRatioAndPaddings(
+            boxWidth.toInt(), boxHeight.toInt(), 320, 1280
+          )
+
+          var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
+            inputImage,
+            coords2,
+            coords,
+            paddingsSingle,
+            ratioAndPadding
+          )
+          if (this.independentCharacters) {
+            croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0)
+            val recognitionResult = recognizer.runModel(croppedCharacter)
+            val predIndex = recognitionResult.first
+            val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
+            text += decodedText[0]
+            confidenceScore += recognitionResult.second
+          } else {
+            croppedCharacters.add(croppedCharacter)
+          }
+        }
+
+        if (this.independentCharacters) {
+          confidenceScore /= detectionResult.size
+        } else {
+          var mergedCharacters = Mat()
+          Core.hconcat(croppedCharacters, mergedCharacters)
+          mergedCharacters = ImageProcessor.resizeWithPadding(mergedCharacters, 512, 64)
+          mergedCharacters = RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0)
+
+          val recognitionResult = recognizer.runModel(mergedCharacters)
+          val predIndex = recognitionResult.first
+          val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
+
+          text = decodedText[0]
+          confidenceScore = recognitionResult.second
+        }
+
+        for (bBox in box.bBox) {
+          bBox.x =
+            (bBox.x - ratioAndPadding["left"] as Int) * ratioAndPadding["resizeRatio"] as Float
+          bBox.y =
+            (bBox.y - ratioAndPadding["top"] as Int) * ratioAndPadding["resizeRatio"] as Float
+        }
+
+        val resMap = Arguments.createMap()
+
+        resMap.putString("text", text)
+        resMap.putArray("bbox", box.toWritableArray())
+        resMap.putDouble("confidence", confidenceScore)
+
+        predictions.pushMap(resMap)
+      }
+      promise.resolve(predictions)
+    } catch (e: Exception) {
+      Log.d("rn_executorch", "Error running model: ${e.message}")
+      promise.reject(e.message!!, e.message)
+    }
+  }
+
+  override fun getName(): String {
+    return NAME
+  }
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
index 85976e22..87b58d22 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
@@ -12,14 +12,18 @@ import org.opencv.core.Scalar
 import org.opencv.core.Size
 import org.pytorch.executorch.EValue
 
-class Detector(reactApplicationContext: ReactApplicationContext) :
+class Detector(
+  val isVertical: Boolean,
+  val detectSingleCharacter: Boolean,
+  reactApplicationContext: ReactApplicationContext
+) :
   BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
   private lateinit var originalSize: Size
 
   fun getModelImageSize(): Size {
     val inputShape = module.getInputShape(0)
-    val width = inputShape[inputShape.lastIndex]
-    val height = inputShape[inputShape.lastIndex - 1]
+    val width = inputShape[inputShape.lastIndex - 1]
+    val height = inputShape[inputShape.lastIndex]
 
     val modelImageSize = Size(height.toDouble(), width.toDouble())
 
@@ -51,15 +55,41 @@
       outputArray,
       Size(modelImageSize.width / 2, modelImageSize.height / 2)
     )
-    var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
-      scoreText,
-      scoreLink,
-      Constants.TEXT_THRESHOLD,
-      Constants.LINK_THRESHOLD,
-      Constants.LOW_TEXT_THRESHOLD
-    )
-    bBoxesList =
-      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
+    var bBoxesList: MutableList<OCRbBox> = mutableListOf()
+    if (!isVertical) {
+      bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
+        scoreText,
+        scoreLink,
+        Constants.TEXT_THRESHOLD,
+        Constants.LINK_THRESHOLD,
+        Constants.LOW_TEXT_THRESHOLD
+      )
+
+      bBoxesList =
+        DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
+    } else {
+      var txtThreshold = Constants.TEXT_THRESHOLD
+
+      if (!detectSingleCharacter) {
+        txtThreshold = Constants.TEXT_THRESHOLD_VERTICAL
+      }
+
+      bBoxesList = DetectorUtils.getDetBoxesFromTextMapVertical(
+        scoreText,
+        scoreLink,
+        txtThreshold,
+        Constants.LINK_THRESHOLD,
+        detectSingleCharacter
+      )
+
+      bBoxesList =
+        DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RESTORE_RATIO_VERTICAL).toFloat())
+
+      if (detectSingleCharacter) {
+        return bBoxesList
+      }
+    }
+
     bBoxesList = DetectorUtils.groupTextBoxes(
       bBoxesList,
       Constants.CENTER_THRESHOLD,
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
index b49232f4..6e65bd64 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
@@ -5,13 +5,16 @@ import org.opencv.core.Scalar
 class Constants {
   companion object {
     const val RECOGNIZER_RATIO = 1.6
+    const val RESTORE_RATIO_VERTICAL = 2.0
     const val MODEL_HEIGHT = 64
     const val LARGE_MODEL_WIDTH = 512
     const val MEDIUM_MODEL_WIDTH = 256
     const val SMALL_MODEL_WIDTH = 128
+    const val VERTICAL_SMALL_MODEL_WIDTH = 64
     const val LOW_CONFIDENCE_THRESHOLD = 0.3
     const val ADJUST_CONTRAST = 0.2
     const val TEXT_THRESHOLD = 0.4
+    const val TEXT_THRESHOLD_VERTICAL = 0.3
     const val LINK_THRESHOLD = 0.4
     const val LOW_TEXT_THRESHOLD = 0.7
     const val CENTER_THRESHOLD = 0.5
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
index 4beb7ecf..992a4014 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
@@ -1,5 +1,6 @@
 package com.swmansion.rnexecutorch.models.ocr.utils
 
+import android.util.Log
 import com.facebook.react.bridge.Arguments
 import com.facebook.react.bridge.WritableArray
 import org.opencv.core.Core
@@ -288,6 +289,97 @@ class DetectorUtils {
       return Pair(mat1, mat2)
     }
 
+    fun getDetBoxesFromTextMapVertical(
+      textMap: Mat,
+      affinityMap: Mat,
+      textThreshold: Double,
+      linkThreshold: Double,
+      independentCharacters: Boolean
+    ): MutableList<OCRbBox> {
+      val imgH = textMap.rows()
+      val imgW = textMap.cols()
+
+      val textScore = Mat()
+      val affinityScore = Mat()
+      Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY)
+      Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY)
+      val textScoreComb = Mat()
+      val kernel = Imgproc.getStructuringElement(
+        Imgproc.MORPH_RECT,
+        Size(3.0, 3.0)
+      )
+      if (independentCharacters) {
+        Core.subtract(textScore, affinityScore, textScoreComb)
+        Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO)
+        Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC)
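+        // Subtracting the affinity map suppresses the links between characters;
+        // erode once to split touching characters, then dilate to grow each
+        // character back into a solid, well-separated blob.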
+        Imgproc.erode(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 1)
+        Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 4)
+      } else {
+        Core.add(textScore, affinityScore, textScoreComb)
+        Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO)
+        Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC)
+        Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 2)
+      }
+
+      val binaryMat = Mat()
+      textScoreComb.convertTo(binaryMat, CvType.CV_8UC1)
+
+      val labels = Mat()
+      val stats = Mat()
+      val centroids = Mat()
+      val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4)
+
+      val detectedBoxes = mutableListOf<OCRbBox>()
+      for (i in 1 until nLabels) {
+        val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt()
+        val height = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt()
+        val width = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt()
+        if (area < 20) continue
+
+        if (!independentCharacters && height < width) continue
+        val mask = createMaskFromLabels(labels, i)
+
+        val segMap = Mat.zeros(textMap.size(), CvType.CV_8U)
+        segMap.setTo(Scalar(255.0), mask)
+
+        val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt()
+        val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt()
+        val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt()
+        val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt()
+        val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt()
+        val sx = max(x - dilationRadius, 0)
+        val ex = min(x + w + dilationRadius + 1, imgW)
+        val sy = max(y - dilationRadius, 0)
+        val ey = min(y + h + dilationRadius + 1, imgH)
+        val roi = Rect(sx, sy, ex - sx, ey - sy)
+        val kernel = Imgproc.getStructuringElement(
+          Imgproc.MORPH_RECT,
+          Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble())
+        )
+        val roiSegMap = Mat(segMap, roi)
+        Imgproc.dilate(roiSegMap, roiSegMap, kernel, Point(-1.0, -1.0), 2)
+
+        val contours: List<MatOfPoint> = ArrayList()
+        Imgproc.findContours(
+          segMap,
+          contours,
+          Mat(),
+          Imgproc.RETR_EXTERNAL,
+          Imgproc.CHAIN_APPROX_SIMPLE
+        )
+        if (contours.isNotEmpty()) {
+          val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray()))
+          val points = Array(4) { Point() }
+          minRect.points(points)
+          val pointsList = points.map { point -> BBoxPoint(point.x, point.y) }
+          val boxInfo = OCRbBox(pointsList, minRect.angle)
+          detectedBoxes.add(boxInfo)
+        }
+      }
+
+      return detectedBoxes
+    }
+
     fun getDetBoxesFromTextMap(
       textMap: Mat,
       affinityMap: Mat,
@@ -435,6 +527,8 @@ class DetectorUtils {
       mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold)
       mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList()
 
+      mergedArray = mergedArray.map { box -> orderPointsClockwise(box) }.toMutableList()
+
       return mergedArray
     }
   }
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
index 99adcad9..91b99f4b 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
@@ -255,7 +255,8 @@ class RecognizerUtils {
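+      // Recognizer models come in fixed input widths; pick the smallest bucket
+      // that still fits this crop (vertical single characters use the 64 px model).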
     val desiredWidth = when {
       img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH
       img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH
-      else -> Constants.SMALL_MODEL_WIDTH
+      img.width() >= Constants.SMALL_MODEL_WIDTH -> Constants.SMALL_MODEL_WIDTH
+      else -> Constants.VERTICAL_SMALL_MODEL_WIDTH
     }
 
     img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT)
@@ -265,5 +266,44 @@
 
       return img
     }
+
+    fun cropImageWithBoundingBox(
+      image: Mat,
+      bbox: List<BBoxPoint>,
+      originalBbox: List<BBoxPoint>,
+      paddings: Map<String, Any>,
+      originalPaddings: Map<String, Any>
+    ): Mat {
+      var topLeft = originalBbox[0]
+      val points = arrayOfNulls<Point>(4)
+
+      for (i in 0 until 4) {
+        val coords = bbox[i]
+        coords.x -= paddings["left"]!! as Int
+        coords.y -= paddings["top"]!! as Int
+
+        coords.x *= paddings["resizeRatio"]!! as Float
+        coords.y *= paddings["resizeRatio"]!! as Float
+
+        coords.x += topLeft.x
+        coords.y += topLeft.y
+
+        coords.x -= originalPaddings["left"]!! as Int
+        coords.y -= (originalPaddings["top"]!! as Int)
+
+        coords.x *= originalPaddings["resizeRatio"]!! as Float
+        coords.y *= originalPaddings["resizeRatio"]!! as Float
+
+        points[i] = Point(coords.x, coords.y)
+      }
+
+      val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points))
+      val croppedImage = Mat(image, boundingBox)
+      Imgproc.cvtColor(croppedImage, croppedImage, Imgproc.COLOR_BGR2GRAY)
+      Imgproc.resize(croppedImage, croppedImage, Size(64.0, 64.0), 0.0, 0.0, Imgproc.INTER_LANCZOS4)
+      Imgproc.medianBlur(croppedImage, croppedImage, 1)
+
+      return croppedImage
+    }
   }
 }
diff --git a/ios/RnExecutorch/VerticalOCR.h b/ios/RnExecutorch/VerticalOCR.h
new file mode 100644
index 00000000..ee19e11e
--- /dev/null
+++ b/ios/RnExecutorch/VerticalOCR.h
@@ -0,0 +1,7 @@
+#import
+
+constexpr CGFloat recognizerRatio = 1.6;
+
+@interface VerticalOCR : NSObject
+
+@end
diff --git a/ios/RnExecutorch/VerticalOCR.mm b/ios/RnExecutorch/VerticalOCR.mm
new file mode 100644
index 00000000..14940ecb
--- /dev/null
+++ b/ios/RnExecutorch/VerticalOCR.mm
@@ -0,0 +1,159 @@
+#import "VerticalOCR.h"
+#import "models/ocr/Detector.h"
+#import "models/ocr/RecognitionHandler.h"
+#import "models/ocr/Recognizer.h"
+#import "models/ocr/utils/RecognizerUtils.h"
+#import "utils/ImageProcessor.h"
+#import
+#import
+#import "models/ocr/utils/OCRUtils.h"
+#import "models/ocr/utils/CTCLabelConverter.h"
+
+@implementation VerticalOCR {
+  Detector *detectorLarge;
+  Detector *detectorNarrow;
+  Recognizer *recognizer;
+  CTCLabelConverter *converter;
+  BOOL independentCharacters;
+}
+
+RCT_EXPORT_MODULE()
+
+- (void)loadModule:(NSString *)detectorLargeSource
+    detectorNarrowSource:(NSString *)detectorNarrowSource
+        recognizerSource:(NSString *)recognizerSource
+                 symbols:(NSString *)symbols
+   independentCharacters:(BOOL)independentCharacters
+                 resolve:(RCTPromiseResolveBlock)resolve
+                  reject:(RCTPromiseRejectBlock)reject {
+  NSLog(@"%@", recognizerSource);
+  detectorLarge = [[Detector alloc] initWithIsVertical:YES detectSingleCharacters:NO];
+  converter = [[CTCLabelConverter alloc] initWithCharacters:symbols separatorList:@{}];
+  self->independentCharacters = independentCharacters;
+  [detectorLarge
+      loadModel:[NSURL URLWithString:detectorLargeSource]
+     completion:^(BOOL success, NSNumber *errorCode) {
+       if (!success) {
+         reject(@"init_module_error", @"Failed to initialize detector module",
+                nil);
+         return;
+       }
+       self->detectorNarrow = [[Detector alloc] initWithIsVertical:YES detectSingleCharacters:YES];
+       [self->detectorNarrow
+           loadModel:[NSURL URLWithString:detectorNarrowSource]
+          completion:^(BOOL success, NSNumber *errorCode) {
+            if (!success) {
+              reject(@"init_module_error",
+                     @"Failed to initialize detector module", nil);
+              return;
+            }
+
+            self->recognizer = [[Recognizer alloc] init];
+            [self->recognizer
+                loadModel:[NSURL URLWithString:recognizerSource]
+               completion:^(BOOL success, NSNumber *errorCode) {
+                 if (!success) {
+                   reject(@"init_module_error",
+                          @"Failed to initialize recognizer module", nil);
+                 }
+
+                 resolve(@(YES));
+               }];
+          }];
+     }];
+}
+
+- (void)forward:(NSString *)input
+        resolve:(RCTPromiseResolveBlock)resolve
+         reject:(RCTPromiseRejectBlock)reject {
+  @try {
+    cv::Mat image = [ImageProcessor readImage:input];
+    NSArray *result = [detectorLarge runModel:image];
+    cv::Mat resizedImage = [OCRUtils resizeWithPadding:image desiredWidth:1280 desiredHeight:1280];
+    NSMutableArray *predictions = [NSMutableArray array];
+    for (NSDictionary *box in result) {
+      NSArray *coords = box[@"bbox"];
+      const int boxWidth = [[coords objectAtIndex:2] CGPointValue].x - [[coords objectAtIndex:0] CGPointValue].x;
+      const int boxHeight = [[coords objectAtIndex:2] CGPointValue].y - [[coords objectAtIndex:0] CGPointValue].y;
+      std::vector<cv::Point> points;
+      for (NSValue *value in coords) {
+        const CGPoint point = [value CGPointValue];
+        points.emplace_back(static_cast<int>(point.x),
+                            static_cast<int>(point.y));
+      }
+
+      cv::Rect boundingBox = cv::boundingRect(points);
+      cv::Mat croppedImage = resizedImage(boundingBox);
+      NSDictionary *ratioAndPadding =
+          [RecognizerUtils calculateResizeRatioAndPaddings:image.cols
+                                                    height:image.rows
+                                              desiredWidth:1280
+                                             desiredHeight:1280];
+
+      NSString *text = @"";
+      NSNumber *confidenceScore = @0.0;
+      NSArray *detectionResult = [detectorNarrow runModel:croppedImage];
+      std::vector<cv::Mat> croppedCharacters;
+      for (NSDictionary *bbox in detectionResult) {
+        NSArray *coords2 = bbox[@"bbox"];
+        NSDictionary *paddingsSingle = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:320 desiredHeight:1280];
+        cv::Mat croppedCharacter = [RecognizerUtils cropImageWithBoundingBox:image bbox:coords2 originalBbox:coords paddings:paddingsSingle originalPaddings:ratioAndPadding];
+        if (self->independentCharacters) {
+          croppedCharacter = [RecognizerUtils normalizeForRecognizer:croppedCharacter adjustContrast:0.0];
+          NSArray *recognitionResult = [recognizer runModel:croppedCharacter];
+          NSArray *predIndex = [recognitionResult objectAtIndex:0];
+          NSArray *decodedText = [converter decodeGreedy:predIndex length:(int)(predIndex.count)];
+          text = [text stringByAppendingString:decodedText[0]];
+          confidenceScore = @([confidenceScore floatValue] + [[recognitionResult objectAtIndex:1] floatValue]);
+        } else {
+          croppedCharacters.push_back(croppedCharacter);
+        }
+      }
+
+      if (self->independentCharacters) {
+        confidenceScore = @([confidenceScore floatValue] / detectionResult.count);
+      } else {
+        cv::Mat mergedCharacters;
+        cv::hconcat(croppedCharacters.data(), (int)croppedCharacters.size(), mergedCharacters);
+        mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters desiredWidth:512 desiredHeight:64];
+        mergedCharacters = [RecognizerUtils normalizeForRecognizer:mergedCharacters adjustContrast:0.0];
+        NSArray *recognitionResult = [recognizer runModel:mergedCharacters];
+        NSArray *predIndex = [recognitionResult objectAtIndex:0];
+        NSArray *decodedText = [converter decodeGreedy:predIndex length:(int)(predIndex.count)];
+        text = [text stringByAppendingString:decodedText[0]];
+        confidenceScore = @([confidenceScore floatValue] + [[recognitionResult objectAtIndex:1] floatValue]);
+      }
+
+      NSMutableArray *newCoords = [NSMutableArray arrayWithCapacity:4];
+      for (NSValue *coord in coords) {
+        const CGPoint point = [coord CGPointValue];
+
+        [newCoords addObject:@{
[ratioAndPadding[@"left"] intValue]) * [ratioAndPadding[@"resizeRatio"] floatValue]), + @"y" : @((point.y - [ratioAndPadding[@"top"] intValue]) * [ratioAndPadding[@"resizeRatio"] floatValue]) + }]; + } + + NSDictionary *res = @{ + @"text" : text, + @"bbox" : newCoords, + @"score" : confidenceScore + }; + [predictions addObject:res]; + } + + + resolve(predictions); + } @catch (NSException *exception) { + reject(@"forward_error", + [NSString stringWithFormat:@"%@", exception.reason], nil); + } +} + +- (std::shared_ptr)getTurboModule: +(const facebook::react::ObjCTurboModule::InitParams &)params { + return std::make_shared(params); +} + +@end diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 0f67e93b..562bf92b 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -3,12 +3,14 @@ #import "opencv2/opencv.hpp" constexpr CGFloat textThreshold = 0.4; +constexpr CGFloat textThresholdVertical = 0.3; constexpr CGFloat linkThreshold = 0.4; constexpr CGFloat lowTextThreshold = 0.7; constexpr CGFloat centerThreshold = 0.5; constexpr CGFloat distanceThreshold = 2.0; constexpr CGFloat heightThreshold = 2.0; constexpr CGFloat restoreRatio = 3.2; +constexpr CGFloat restoreRatioVertical = 2.0; constexpr int minSideThreshold = 15; constexpr int maxSideThreshold = 30; constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); @@ -19,6 +21,7 @@ const cv::Scalar variance(0.229, 0.224, 0.225); @interface Detector : BaseModel +- (instancetype)initWithIsVertical:(BOOL)isVertical detectSingleCharacters:(BOOL)detectSingleCharacters; - (cv::Size)getModelImageSize; - (NSArray *)runModel:(cv::Mat &)input; diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 20b82b5e..68cd72ea 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -11,6 +11,18 @@ The model used as detector is based on CRAFT (Character Region Awareness for @implementation Detector { cv::Size originalSize; cv::Size modelSize; + BOOL isVertical; + BOOL detectSingleCharacters; +} + +- (instancetype)initWithIsVertical:(BOOL)isVertical + detectSingleCharacters:(BOOL)detectSingleCharacters { + self = [super init]; + if (self) { + self->isVertical = isVertical; + self->detectSingleCharacters = detectSingleCharacters; + } + return self; } - (cv::Size)getModelImageSize { @@ -19,8 +31,8 @@ @implementation Detector { } NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape.lastObject; - NSNumber *heightNumber = inputShape[inputShape.count - 2]; + NSNumber *widthNumber = inputShape[inputShape.count - 2]; + NSNumber *heightNumber = inputShape.lastObject; const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; @@ -36,7 +48,6 @@ - (NSArray *)preprocess:(cv::Mat &)input { original aspect ratio and the missing parts are filled with padding. 
  */
  self->originalSize = cv::Size(input.cols, input.rows);
-  cv::Size modelImageSize = [self getModelImageSize];
  cv::Mat resizedImage;
  resizedImage = [OCRUtils resizeWithPadding:input
@@ -72,13 +83,34 @@ group each character into a single instance (sequence) Both matrices are
                             outputMat2:scoreAffinityCV
                               withSize:cv::Size(modelImageSize.width / 2,
                                                 modelImageSize.height / 2)];
 
-  NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV
-                                                  affinityMap:scoreAffinityCV
-                                           usingTextThreshold:textThreshold
-                                                linkThreshold:linkThreshold
-                                             lowTextThreshold:lowTextThreshold];
-  bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
-                             usingRestoreRatio:restoreRatio];
+  NSArray *bBoxesList;
+  if (!self->isVertical) {
+    bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV
+                                           affinityMap:scoreAffinityCV
+                                    usingTextThreshold:textThreshold
+                                         linkThreshold:linkThreshold
+                                      lowTextThreshold:lowTextThreshold];
+    bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+                               usingRestoreRatio:restoreRatio];
+  } else if (self->isVertical) {
+    CGFloat txtThreshold = textThreshold;
+    if (!self->detectSingleCharacters) {
+      txtThreshold = textThresholdVertical;
+    }
+    bBoxesList =
+        [DetectorUtils getDetBoxesFromTextMapVertical:scoreTextCV
+                                          affinityMap:scoreAffinityCV
+                                   usingTextThreshold:txtThreshold
+                                        linkThreshold:linkThreshold
+                                independentCharacters:self->detectSingleCharacters];
+    bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+                               usingRestoreRatio:restoreRatioVertical];
+
+    if (self->detectSingleCharacters) {
+      return bBoxesList;
+    }
+  }
+
  bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
                             centerThreshold:centerThreshold
                           distanceThreshold:distanceThreshold
diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h
index 3f205b8e..70467169 100644
--- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h
+++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h
@@ -13,9 +13,14 @@ constexpr int verticalLineThreshold = 20;
          usingTextThreshold:(CGFloat)textThreshold
               linkThreshold:(CGFloat)linkThreshold
            lowTextThreshold:(CGFloat)lowTextThreshold;
++ (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap
+                                affinityMap:(cv::Mat)affinityMap
+                         usingTextThreshold:(CGFloat)textThreshold
+                              linkThreshold:(CGFloat)linkThreshold
+                      independentCharacters:(BOOL)independentCharacters;
 + (NSArray *)restoreBboxRatio:(NSArray *)boxes
            usingRestoreRatio:(CGFloat)restoreRatio;
-+ (NSArray *)groupTextBoxes:(NSArray *)polys
++ (NSArray *)groupTextBoxes:(NSMutableArray *)polys
             centerThreshold:(CGFloat)centerThreshold
           distanceThreshold:(CGFloat)distanceThreshold
             heightThreshold:(CGFloat)heightThreshold
diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm
index 8ee7424d..62ed9fa1 100644
--- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm
+++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm
@@ -22,6 +22,98 @@ + (void)interleavedArrayToMats:(NSArray *)array
   }
 }
 
++ (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap
+                                affinityMap:(cv::Mat)affinityMap
+                         usingTextThreshold:(CGFloat)textThreshold
+                              linkThreshold:(CGFloat)linkThreshold
+                      independentCharacters:(BOOL)independentCharacters {
+  const int imgH = textMap.rows;
+  const int imgW = textMap.cols;
+  cv::Mat textScore;
+  cv::Mat affinityScore;
+  cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY);
+  cv::threshold(affinityMap, affinityScore, linkThreshold, 1,
+                cv::THRESH_BINARY);
+  cv::Mat textScoreComb;
+  if (independentCharacters) {
+    textScoreComb = textScore - affinityScore;
+    cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO);
+    cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC);
+    cv::erode(textScoreComb, textScoreComb,
+              cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)),
+              cv::Point(-1, -1), 1);
+    cv::dilate(textScoreComb, textScoreComb,
+               cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)),
+               cv::Point(-1, -1), 4);
+  } else {
+    textScoreComb = textScore + affinityScore;
+    cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO);
+    cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC);
+    cv::dilate(textScoreComb, textScoreComb,
+               cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)),
+               cv::Point(-1, -1), 2);
+  }
+
+  cv::Mat binaryMat;
+  textScoreComb.convertTo(binaryMat, CV_8UC1);
+
+  cv::Mat labels, stats, centroids;
+  const int nLabels =
+      cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4);
+
+  NSMutableArray *detectedBoxes = [NSMutableArray array];
+  for (int i = 1; i < nLabels; i++) {
+    const int area = stats.at<int>(i, cv::CC_STAT_AREA);
+    const int width = stats.at<int>(i, cv::CC_STAT_WIDTH);
+    const int height = stats.at<int>(i, cv::CC_STAT_HEIGHT);
+    if (area < 20)
+      continue;
+
+    if (!independentCharacters && height < width)
+      continue;
+
+    cv::Mat mask = (labels == i);
+
+    cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U);
+    segMap.setTo(255, mask);
+
+    const int x = stats.at<int>(i, cv::CC_STAT_LEFT);
+    const int y = stats.at<int>(i, cv::CC_STAT_TOP);
+    const int w = stats.at<int>(i, cv::CC_STAT_WIDTH);
+    const int h = stats.at<int>(i, cv::CC_STAT_HEIGHT);
+    const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h))) * 2.0);
+    const int sx = MAX(x - dilationRadius, 0);
+    const int ex = MIN(x + w + dilationRadius + 1, imgW);
+    const int sy = MAX(y - dilationRadius, 0);
+    const int ey = MIN(y + h + dilationRadius + 1, imgH);
+
+    cv::Rect roi(sx, sy, ex - sx, ey - sy);
+    cv::Mat kernel = cv::getStructuringElement(
+        cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius));
+    cv::Mat roiSegMap = segMap(roi);
+    cv::dilate(roiSegMap, roiSegMap, kernel, cv::Point(-1, -1), 2);
+
+    std::vector<std::vector<cv::Point>> contours;
+    cv::findContours(segMap, contours, cv::RETR_EXTERNAL,
+                     cv::CHAIN_APPROX_SIMPLE);
+    if (!contours.empty()) {
+      cv::RotatedRect minRect = cv::minAreaRect(contours[0]);
+      cv::Point2f vertices[4];
+      minRect.points(vertices);
+      NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4];
+      for (int j = 0; j < 4; j++) {
+        const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y);
+        [pointsArray addObject:[NSValue valueWithCGPoint:point]];
+      }
+      NSDictionary *dict =
+          @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)};
+      [detectedBoxes addObject:dict];
+    }
+  }
+
+  return detectedBoxes;
+}
+
 /**
  * This method applies a series of image processing operations to identify
  * likely areas of text in the textMap and return the bounding boxes for single
@@ -545,7 +637,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes
  * criteria.
  * 4. Sort the final array of boxes by their vertical positions.
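+ * 5. Order each box's corner points clockwise before returning.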
  */
-+ (NSArray *)groupTextBoxes:(NSArray *)polys
++ (NSArray *)groupTextBoxes:(NSMutableArray *)boxes
             centerThreshold:(CGFloat)centerThreshold
           distanceThreshold:(CGFloat)distanceThreshold
@@ -635,7 +727,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes
                 usingMinSideThreshold:minSideThreshold
                      maxSideThreshold:maxSideThreshold];
 
-  NSArray *sortedBoxes = [mergedArray
+  NSArray *sortedBoxes = [mergedArray
       sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1,
                                                      NSDictionary *obj2) {
         NSArray *coords1 = obj1[@"bbox"];
@@ -646,8 +738,17 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes
                    : (minY1 > minY2) ? NSOrderedDescending
                                      : NSOrderedSame;
       }];
-
-  return sortedBoxes;
+
+  NSMutableArray *orderedSortedBoxes = [[NSMutableArray alloc] initWithCapacity:[sortedBoxes count]];
+  for (NSDictionary *dict in sortedBoxes) {
+    NSMutableDictionary *mutableDict = [dict mutableCopy];
+    NSArray *originalBBox = mutableDict[@"bbox"];
+    NSArray *orderedBBox = [self orderPointsClockwise:originalBBox];
+    mutableDict[@"bbox"] = orderedBBox;
+    [orderedSortedBoxes addObject:mutableDict];
+  }
+
+  return orderedSortedBoxes;
 }
 
 @end
diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
index 7af748f5..d976b1ae 100644
--- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
+++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
@@ -24,5 +24,6 @@
 + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities;
 + (double)computeConfidenceScore:(NSArray *)valuesArray
                     indicesArray:(NSArray *)indicesArray;
++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img bbox:(NSArray *)bbox originalBbox:(NSArray *)originalBbox paddings:(NSDictionary *)paddings originalPaddings:(NSDictionary *)originalPaddings;
 
 @end
diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm
index 65c088b3..b1726e79 100644
--- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm
+++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm
@@ -61,11 +61,13 @@ + (CGFloat)calculateRatio:(int)width height:(int)height {
     image = [self adjustContrastGrey:image target:adjustContrast];
   }
 
-  int desiredWidth = 128;
+  int desiredWidth = 64;
 
   if (image.cols >= 512) {
     desiredWidth = 512;
   } else if (image.cols >= 256) {
     desiredWidth = 256;
+  } else if (image.cols >= 128) {
+    desiredWidth = 128;
   }
 
   image = [OCRUtils resizeWithPadding:image
@@ -220,4 +222,36 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray
   return pow(product, 2.0 / sqrt(predsMaxProb.count));
 }
 
++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img bbox:(NSArray *)bbox originalBbox:(NSArray *)originalBbox paddings:(NSDictionary *)paddings originalPaddings:(NSDictionary *)originalPaddings {
+  CGPoint topLeft = [originalBbox[0] CGPointValue];
+  std::vector<cv::Point2f> points;
+  for (NSValue *coords in bbox) {
+    CGPoint point = [coords CGPointValue];
+
+    point.x = point.x - [paddings[@"left"] intValue];
+    point.y = point.y - [paddings[@"top"] intValue];
+
+    point.x = point.x * [paddings[@"resizeRatio"] floatValue];
+    point.y = point.y * [paddings[@"resizeRatio"] floatValue];
+
+    point.x = point.x + topLeft.x;
+    point.y = point.y + topLeft.y;
+
+    point.x = point.x - [originalPaddings[@"left"] intValue];
+    point.y = point.y - [originalPaddings[@"top"] intValue];
+
+    point.x = point.x * [originalPaddings[@"resizeRatio"] floatValue];
+    point.y = point.y * [originalPaddings[@"resizeRatio"] floatValue];
+
+    points.push_back(cv::Point2f(point.x, point.y));
+  }
+
+  cv::Rect rect =
+      cv::boundingRect(points);
+  cv::Mat croppedImage = img(rect);
+  cv::cvtColor(croppedImage, croppedImage, cv::COLOR_BGR2GRAY);
+  cv::resize(croppedImage, croppedImage, cv::Size(64, 64), 0, 0, cv::INTER_AREA);
+  cv::medianBlur(croppedImage, croppedImage, 1);
+  return croppedImage;
+}
+
 @end
diff --git a/src/hooks/computer_vision/useVerticalOCR.ts b/src/hooks/computer_vision/useVerticalOCR.ts
new file mode 100644
index 00000000..3a5707d7
--- /dev/null
+++ b/src/hooks/computer_vision/useVerticalOCR.ts
@@ -0,0 +1,132 @@
+import { useEffect, useState } from 'react';
+import { fetchResource } from '../../utils/fetchResource';
+import { languageDicts } from '../../constants/ocr/languageDicts';
+import { symbols } from '../../constants/ocr/symbols';
+import { getError, ETError } from '../../Error';
+import { VerticalOCR } from '../../native/RnExecutorchModules';
+import { ResourceSource } from '../../types/common';
+import { OCRDetection } from '../../types/ocr';
+
+interface OCRModule {
+  error: string | null;
+  isReady: boolean;
+  isGenerating: boolean;
+  forward: (input: string) => Promise<OCRDetection[]>;
+  downloadProgress: number;
+}
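+
+/**
+ * Example usage — the model source paths below are placeholders; substitute
+ * the detector and recognizer assets you actually bundle or host:
+ *
+ *   const { isReady, forward } = useVerticalOCR({
+ *     detectorSources: {
+ *       detectorLarge: require('./detector_large.pte'),
+ *       detectorNarrow: require('./detector_narrow.pte'),
+ *     },
+ *     recognizerSources: {
+ *       recognizerLarge: require('./recognizer_large.pte'),
+ *       recognizerSmall: require('./recognizer_small.pte'),
+ *     },
+ *     language: 'en',
+ *     independentCharacters: true,
+ *   });
+ *
+ * Once isReady is true, forward(imageUri) resolves to an array of detections.
+ */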
+export const useVerticalOCR = ({
+  detectorSources,
+  recognizerSources,
+  language = 'en',
+  independentCharacters = false,
+}: {
+  detectorSources: {
+    detectorLarge: ResourceSource;
+    detectorNarrow: ResourceSource;
+  };
+  recognizerSources: {
+    recognizerLarge: ResourceSource;
+    recognizerSmall: ResourceSource;
+  };
+  language?: string;
+  independentCharacters?: boolean;
+}): OCRModule => {
+  const [error, setError] = useState<string | null>(null);
+  const [isReady, setIsReady] = useState(false);
+  const [isGenerating, setIsGenerating] = useState(false);
+  const [downloadProgress, setDownloadProgress] = useState(0);
+
+  useEffect(() => {
+    const loadModel = async () => {
+      try {
+        if (
+          Object.keys(detectorSources).length !== 2 ||
+          Object.keys(recognizerSources).length !== 2
+        )
+          return;
+
+        let recognizerPath;
+
+        const detectorPaths = {} as {
+          detectorLarge: string;
+          detectorNarrow: string;
+        };
+
+        if (!symbols[language] || !languageDicts[language]) {
+          setError(getError(ETError.LanguageNotSupported));
+          return;
+        }
+
+        await Promise.all([
+          fetchResource(detectorSources.detectorLarge),
+          fetchResource(detectorSources.detectorNarrow),
+        ]).then((values) => {
+          detectorPaths.detectorLarge = values[0];
+          detectorPaths.detectorNarrow = values[1];
+        });
+
+        if (independentCharacters) {
+          recognizerPath = await fetchResource(
+            recognizerSources.recognizerSmall,
+            setDownloadProgress
+          );
+        } else {
+          recognizerPath = await fetchResource(
+            recognizerSources.recognizerLarge,
+            setDownloadProgress
+          );
+        }
+
+        setIsReady(false);
+        await VerticalOCR.loadModule(
+          detectorPaths.detectorLarge,
+          detectorPaths.detectorNarrow,
+          recognizerPath,
+          symbols[language],
+          independentCharacters
+        );
+        setIsReady(true);
+      } catch (e) {
+        setError(getError(e));
+      }
+    };
+
+    loadModel();
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+    JSON.stringify(detectorSources),
+    language,
+    independentCharacters,
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+    JSON.stringify(recognizerSources),
+  ]);
+
+  const forward = async (input: string) => {
+    if (!isReady) {
+      throw new Error(getError(ETError.ModuleNotLoaded));
+    }
+    if (isGenerating) {
+      throw new Error(getError(ETError.ModelGenerating));
+    }
+
+    try {
+      setIsGenerating(true);
+      const output = await VerticalOCR.forward(input);
+      return output;
+    } catch (e) {
+      throw new Error(getError(e));
+    } finally {
+      setIsGenerating(false);
+    }
+  };
+
+  return {
+    error,
+    isReady,
+    isGenerating,
+    forward,
+    downloadProgress,
+  };
+};
diff --git a/src/index.tsx b/src/index.tsx
index f5bfa185..c9b6fca0 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -3,6 +3,7 @@ export * from './hooks/computer_vision/useClassification';
 export * from './hooks/computer_vision/useObjectDetection';
 export * from './hooks/computer_vision/useStyleTransfer';
 export * from './hooks/computer_vision/useOCR';
+export * from './hooks/computer_vision/useVerticalOCR';
 export * from './hooks/natural_language_processing/useLLM';
diff --git a/src/native/NativeVerticalOCR.ts b/src/native/NativeVerticalOCR.ts
new file mode 100644
index 00000000..2aca8cbe
--- /dev/null
+++ b/src/native/NativeVerticalOCR.ts
@@ -0,0 +1,16 @@
+import type { TurboModule } from 'react-native';
+import { TurboModuleRegistry } from 'react-native';
+import { OCRDetection } from '../types/ocr';
+
+export interface Spec extends TurboModule {
+  loadModule(
+    detectorLargeSource: string,
+    detectorNarrowSource: string,
+    recognizerSource: string,
+    symbols: string,
+    independentCharacters: boolean
+  ): Promise<number>;
+  forward(input: string): Promise<OCRDetection[]>;
+}
+
+export default TurboModuleRegistry.get<Spec>('VerticalOCR');
diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts
index c8044aa4..49e4b89e 100644
--- a/src/native/RnExecutorchModules.ts
+++ b/src/native/RnExecutorchModules.ts
@@ -101,6 +101,19 @@ const OCR = OCRSpec
       }
     );
 
+const VerticalOCRSpec = require('./NativeVerticalOCR').default;
+
+const VerticalOCR = VerticalOCRSpec
+  ? VerticalOCRSpec
+  : new Proxy(
+      {},
+      {
+        get() {
+          throw new Error(LINKING_ERROR);
+        },
+      }
+    );
+
 class _ObjectDetectionModule {
   async forward(
     input: string
@@ -182,6 +195,7 @@ export {
   StyleTransfer,
   SpeechToText,
   OCR,
+  VerticalOCR,
   _ETModule,
   _ClassificationModule,
   _StyleTransferModule,

From 4c118a0b89ee74c463058b136dde593a050ef457 Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Thu, 27 Feb 2025 14:28:01 +0100
Subject: [PATCH 02/12] refactor: refactor of vertical ocr code

---
 .../com/swmansion/rnexecutorch/VerticalOCR.kt |  87 +++++++------
 .../rnexecutorch/models/ocr/Detector.kt       |  61 ++-------
 .../models/ocr/VerticalDetector.kt            |  94 ++++++++++++++
 .../models/ocr/utils/DetectorUtils.kt         |   4 +-
 .../models/ocr/utils/RecognizerUtils.kt       |  61 +++++----
 examples/computer-vision/App.tsx              |   7 ++
 .../models/ocr/VerticalDetector.h             |  28 +++++
 .../models/ocr/VerticalDetector.mm            | 118 ++++++++++++++++++
 .../computer-vision/screens/OCRScreen.tsx     |   6 +-
 .../screens/VerticalOCRScreen.tsx             | 118 ++++++++++++++++++
 ios/RnExecutorch/VerticalOCR.mm               |  60 +++++----
 ios/RnExecutorch/models/ocr/Detector.h        |   2 +-
 ios/RnExecutorch/models/ocr/Detector.mm       |  46 ++-----
 .../models/ocr/RecognitionHandler.mm          |   2 +-
 .../models/ocr/VerticalDetector.h             |  28 +++++
 .../models/ocr/VerticalDetector.mm            | 118 ++++++++++++++++++
 ios/RnExecutorch/models/ocr/utils/OCRUtils.h  |   1 +
 ios/RnExecutorch/models/ocr/utils/OCRUtils.mm |  12 ++
 .../models/ocr/utils/RecognizerUtils.h        |   9 +-
 .../models/ocr/utils/RecognizerUtils.mm       |  37 +++---
 20 files changed, 697 insertions(+), 202 deletions(-)
 create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt
 create mode 100644 examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h
 create mode 100644
 examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm
 create mode 100644 examples/computer-vision/screens/VerticalOCRScreen.tsx
 create mode 100644 ios/RnExecutorch/models/ocr/VerticalDetector.h
 create mode 100644 ios/RnExecutorch/models/ocr/VerticalDetector.mm

diff --git a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
index 945a83c0..859ebaa7 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
@@ -1,6 +1,5 @@
 package com.swmansion.rnexecutorch
 
-import android.media.Image
 import android.util.Log
 import com.facebook.react.bridge.Arguments
 import com.facebook.react.bridge.Promise
@@ -8,22 +7,19 @@ import com.facebook.react.bridge.ReactApplicationContext
 import com.swmansion.rnexecutorch.utils.ETError
 import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.android.OpenCVLoader
-import com.swmansion.rnexecutorch.models.ocr.Detector
 import com.swmansion.rnexecutorch.models.ocr.Recognizer
+import com.swmansion.rnexecutorch.models.ocr.VerticalDetector
 import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
-import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
+import com.swmansion.rnexecutorch.models.ocr.utils.Constants
 import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
 import org.opencv.core.Core
 import org.opencv.core.Mat
-import org.opencv.core.MatOfPoint
-import org.opencv.core.Point
-import org.opencv.imgproc.Imgproc
 
 class VerticalOCR(reactContext: ReactApplicationContext) :
   NativeVerticalOCRSpec(reactContext) {
 
-  private lateinit var detectorLarge: Detector
-  private lateinit var detectorNarrow: Detector
+  private lateinit var detectorLarge: VerticalDetector
+  private lateinit var detectorNarrow: VerticalDetector
   private lateinit var recognizer: Recognizer
   private lateinit var converter: CTCLabelConverter
   private var independentCharacters = true
@@ -50,9 +46,9 @@ class VerticalOCR(reactContext: ReactApplicationContext) :
   ) {
     try {
       this.independentCharacters = independentCharacters
-      detectorLarge = Detector(true, false, reactApplicationContext)
+      detectorLarge = VerticalDetector(false, reactApplicationContext)
       detectorLarge.loadModel(detectorLargeSource)
-      detectorNarrow = Detector(true, true, reactApplicationContext)
+      detectorNarrow = VerticalDetector(true, reactApplicationContext)
       detectorNarrow.loadModel(detectorNarrowSource)
       recognizer = Recognizer(reactApplicationContext)
       recognizer.loadModel(recognizerSource)
@@ -69,50 +65,54 @@ class VerticalOCR(reactContext: ReactApplicationContext) :
     try {
       val inputImage = ImageProcessor.readImage(input)
       val result = detectorLarge.runModel(inputImage)
-
-      val resizedImage = ImageProcessor.resizeWithPadding(inputImage, 1280, 1280)
+      val largeDetectorSize = detectorLarge.getModelImageSize()
+      val resizedImage = ImageProcessor.resizeWithPadding(
+        inputImage,
+        largeDetectorSize.width.toInt(),
+        largeDetectorSize.height.toInt()
+      )
       val predictions = Arguments.createArray()
       for (box in result) {
-        val coords = box.bBox
-        val boxWidth = coords[2].x - coords[0].x
-        val boxHeight = coords[2].y - coords[0].y
-        val points = arrayOfNulls<Point>(4)
-
-        for (i in 0 until 4) {
-          points.set(i, Point(coords[i].x, coords[i].y))
-        }
+        val cords = box.bBox
+        val boxWidth = cords[2].x - cords[0].x
+        val boxHeight = cords[2].y - cords[0].y
 
-        val boundingBox = Imgproc.boundingRect(MatOfPoint(*points))
+        val boundingBox = RecognizerUtils.extractBoundingBox(cords)
         val croppedImage = Mat(resizedImage, boundingBox)
 
-        val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
+        val paddings = RecognizerUtils.calculateResizeRatioAndPaddings(
           inputImage.width(),
           inputImage.height(),
-          1280,
-          1280
+          largeDetectorSize.width.toInt(),
+          largeDetectorSize.height.toInt()
         )
 
         var text = ""
         var confidenceScore = 0.0
-        var detectionResult = detectorNarrow.runModel(croppedImage)
-
-        var croppedCharacters = mutableListOf<Mat>()
-
-        for (bbox in detectionResult) {
-          val coords2 = bbox.bBox
-          var paddingsSingle = RecognizerUtils.calculateResizeRatioAndPaddings(
-            boxWidth.toInt(), boxHeight.toInt(), 320, 1280
+        val boxResult = detectorNarrow.runModel(croppedImage)
+        val narrowDetectorSize = detectorNarrow.getModelImageSize()
+
+        val croppedCharacters = mutableListOf<Mat>()
+
+        for (characterBox in boxResult) {
+          val boxCords = characterBox.bBox
+          val paddingsBox = RecognizerUtils.calculateResizeRatioAndPaddings(
+            boxWidth.toInt(),
+            boxHeight.toInt(),
+            narrowDetectorSize.width.toInt(),
+            narrowDetectorSize.height.toInt()
           )
 
-          var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
+          var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
             inputImage,
-            coords2,
-            coords,
-            paddingsSingle,
-            ratioAndPadding
+            boxCords,
+            cords,
+            paddingsBox,
+            paddings
           )
+
           if (this.independentCharacters) {
-            croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0)
+            croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true)
             val recognitionResult = recognizer.runModel(croppedCharacter)
             val predIndex = recognitionResult.first
             val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
@@ -124,11 +124,15 @@ class VerticalOCR(reactContext: ReactApplicationContext) :
         }
 
         if (this.independentCharacters) {
-          confidenceScore /= detectionResult.size
+          confidenceScore /= boxResult.size
         } else {
           var mergedCharacters = Mat()
           Core.hconcat(croppedCharacters, mergedCharacters)
-          mergedCharacters = ImageProcessor.resizeWithPadding(mergedCharacters, 512, 64)
+          mergedCharacters = ImageProcessor.resizeWithPadding(
+            mergedCharacters,
+            Constants.LARGE_MODEL_WIDTH,
+            Constants.MODEL_HEIGHT
+          )
           mergedCharacters = RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0)
 
           val recognitionResult = recognizer.runModel(mergedCharacters)
@@ -141,9 +145,9 @@ class VerticalOCR(reactContext: ReactApplicationContext) :
 
         for (bBox in box.bBox) {
           bBox.x =
-            (bBox.x - ratioAndPadding["left"] as Int) * ratioAndPadding["resizeRatio"] as Float
+            (bBox.x - paddings["left"] as Int) * paddings["resizeRatio"] as Float
           bBox.y =
-            (bBox.y - ratioAndPadding["top"] as Int) * ratioAndPadding["resizeRatio"] as Float
+            (bBox.y - paddings["top"] as Int) * paddings["resizeRatio"] as Float
         }
 
         val resMap = Arguments.createMap()
@@ -154,6 +158,7 @@ class VerticalOCR(reactContext: ReactApplicationContext) :
 
         predictions.pushMap(resMap)
       }
+
       promise.resolve(predictions)
     } catch (e: Exception) {
       Log.d("rn_executorch", "Error running model: ${e.message}")
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
index 87b58d22..fb8e4329 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
@@ -1,6 +1,5 @@
 package com.swmansion.rnexecutorch.models.ocr
 
-import android.util.Log
 import com.facebook.react.bridge.ReactApplicationContext
 import com.swmansion.rnexecutorch.models.BaseModel
 import com.swmansion.rnexecutorch.models.ocr.utils.Constants
@@ -8,16 +7,12 @@ import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
 import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
 import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.core.Mat
-import org.opencv.core.Scalar
 import org.opencv.core.Size
 import org.pytorch.executorch.EValue
 
 class Detector(
-  val isVertical: Boolean,
-  val detectSingleCharacter: Boolean,
   reactApplicationContext: ReactApplicationContext
-) :
-  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
+) : BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
   private lateinit var originalSize: Size
 
   fun getModelImageSize(): Size {
@@ -33,16 +28,11 @@ class Detector(
   override fun preprocess(input: Mat): EValue {
     originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
     val resizedImage = ImageProcessor.resizeWithPadding(
-      input,
-      getModelImageSize().width.toInt(),
-      getModelImageSize().height.toInt()
+      input, getModelImageSize().width.toInt(), getModelImageSize().height.toInt()
     )
 
     return ImageProcessor.matToEValue(
-      resizedImage,
-      module.getInputShape(0),
-      Constants.MEAN,
-      Constants.VARIANCE
+      resizedImage, module.getInputShape(0), Constants.MEAN, Constants.VARIANCE
     )
   }
 
@@ -52,43 +42,18 @@ class Detector(
     val modelImageSize = getModelImageSize()
 
     val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
-      outputArray,
-      Size(modelImageSize.width / 2, modelImageSize.height / 2)
+      outputArray, Size(modelImageSize.width / 2, modelImageSize.height / 2)
+    )
+    var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
+      scoreText,
+      scoreLink,
+      Constants.TEXT_THRESHOLD,
+      Constants.LINK_THRESHOLD,
+      Constants.LOW_TEXT_THRESHOLD
     )
-    var bBoxesList: MutableList<OCRbBox> = mutableListOf()
-    if (!isVertical) {
-      bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
-        scoreText,
-        scoreLink,
-        Constants.TEXT_THRESHOLD,
-        Constants.LINK_THRESHOLD,
-        Constants.LOW_TEXT_THRESHOLD
-      )
-
-      bBoxesList =
-        DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
-    } else {
-      var txtThreshold = Constants.TEXT_THRESHOLD
-
-      if (!detectSingleCharacter) {
-        txtThreshold = Constants.TEXT_THRESHOLD_VERTICAL
-      }
-
-      bBoxesList = DetectorUtils.getDetBoxesFromTextMapVertical(
-        scoreText,
-        scoreLink,
-        txtThreshold,
-        Constants.LINK_THRESHOLD,
-        detectSingleCharacter
-      )
-
-      bBoxesList =
-        DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RESTORE_RATIO_VERTICAL).toFloat())
 
-      if (detectSingleCharacter) {
-        return bBoxesList
-      }
-    }
+    bBoxesList =
+      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
 
     bBoxesList = DetectorUtils.groupTextBoxes(
       bBoxesList,
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt
new file mode 100644
index 00000000..d3365274
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt
@@ -0,0 +1,94 @@
+package com.swmansion.rnexecutorch.models.ocr
+
+import com.facebook.react.bridge.ReactApplicationContext
+import com.swmansion.rnexecutorch.models.BaseModel
+import com.swmansion.rnexecutorch.models.ocr.utils.Constants
+import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
+import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.core.Mat
+import org.opencv.core.Size
+import org.pytorch.executorch.EValue
+
+class VerticalDetector(
+  private val detectSingleCharacter: Boolean,
+  reactApplicationContext: ReactApplicationContext
+) :
+  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
+  private lateinit var originalSize: Size
+
+  fun getModelImageSize(): Size {
+    val inputShape = module.getInputShape(0)
+    val width = inputShape[inputShape.lastIndex - 1]
+    val height = inputShape[inputShape.lastIndex]
+
+    val modelImageSize = Size(height.toDouble(), width.toDouble())
+
+    return modelImageSize
+  }
+
+  override fun preprocess(input: Mat): EValue {
+    originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
+    val resizedImage = ImageProcessor.resizeWithPadding(
+      input,
+      getModelImageSize().width.toInt(),
+      getModelImageSize().height.toInt()
+    )
+
+    return ImageProcessor.matToEValue(
+      resizedImage,
+      module.getInputShape(0),
+      Constants.MEAN,
+      Constants.VARIANCE
+    )
+  }
+
+  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
+    val outputTensor = output[0].toTensor()
+    val outputArray = outputTensor.dataAsFloatArray
+    val modelImageSize = getModelImageSize()
+
+    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
+      outputArray,
+      Size(modelImageSize.width / 2, modelImageSize.height / 2)
+    )
+
+    var txtThreshold = Constants.TEXT_THRESHOLD
+
+    if (!detectSingleCharacter) {
+      txtThreshold = Constants.TEXT_THRESHOLD_VERTICAL
+    }
+
+    var bBoxesList = DetectorUtils.getDetBoxesFromTextMapVertical(
+      scoreText,
+      scoreLink,
+      txtThreshold,
+      Constants.LINK_THRESHOLD,
+      detectSingleCharacter
+    )
+
+    bBoxesList =
+      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RESTORE_RATIO_VERTICAL).toFloat())
+
+    if (detectSingleCharacter) {
+      return bBoxesList
+    }
+
+    bBoxesList = DetectorUtils.groupTextBoxes(
+      bBoxesList,
+      Constants.CENTER_THRESHOLD,
+      Constants.DISTANCE_THRESHOLD,
+      Constants.HEIGHT_THRESHOLD,
+      Constants.MIN_SIDE_THRESHOLD,
+      Constants.MAX_SIDE_THRESHOLD,
+      Constants.MAX_WIDTH
+    )
+
+    return bBoxesList.toList()
+  }
+
+  override fun runModel(input: Mat): List<OCRbBox> {
+    return postprocess(forward(preprocess(input)))
+  }
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
index 992a4014..c1c90774 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
@@ -304,7 +304,7 @@ class DetectorUtils {
       Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY)
       Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY)
       val textScoreComb = Mat()
-      val kernel = Imgproc.getStructuringElement(
+      var kernel = Imgproc.getStructuringElement(
         Imgproc.MORPH_RECT,
         Size(3.0, 3.0)
       )
@@ -352,7 +352,7 @@ class DetectorUtils {
       val sy = max(y - dilationRadius, 0)
       val ey = min(y + h + dilationRadius + 1, imgH)
       val roi = Rect(sx, sy, ex - sx, ey - sy)
-      val kernel = Imgproc.getStructuringElement(
+      kernel = Imgproc.getStructuringElement(
         Imgproc.MORPH_RECT,
         Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble())
       )
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
index 91b99f4b..bbd0dd6d 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
index 91b99f4b..bbd0dd6d 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
@@ -245,20 +245,27 @@ class RecognizerUtils {
       return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight)
     }
 
-    fun normalizeForRecognizer(image: Mat, adjustContrast: Double): Mat {
+    fun normalizeForRecognizer(
+      image: Mat,
+      adjustContrast: Double,
+      isVertical: Boolean = false
+    ): Mat {
       var img = image.clone()
 
       if (adjustContrast > 0) {
         img = adjustContrastGrey(img, adjustContrast)
       }
 
-      val desiredWidth = when {
-        img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH
-        img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH
-        img.width() >= Constants.SMALL_MODEL_WIDTH -> Constants.SMALL_MODEL_WIDTH
-        else -> Constants.VERTICAL_SMALL_MODEL_WIDTH
+      var desiredWidth =
+        if (isVertical) Constants.VERTICAL_SMALL_MODEL_WIDTH else Constants.SMALL_MODEL_WIDTH
+
+      if (img.width() >= Constants.LARGE_MODEL_WIDTH) {
+        desiredWidth = Constants.LARGE_MODEL_WIDTH
+      } else if (img.width() >= Constants.MEDIUM_MODEL_WIDTH) {
+        desiredWidth = Constants.MEDIUM_MODEL_WIDTH
       }
+
       img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT)
       img.convertTo(img, CvType.CV_32F, 1.0 / 255.0)
       Core.subtract(img, Scalar(0.5), img)
@@ -269,32 +276,32 @@ class RecognizerUtils {
 
     fun cropImageWithBoundingBox(
       image: Mat,
-      bbox: List<Point>,
-      originalBbox: List<Point>,
+      box: List<Point>,
+      originalBox: List<Point>,
       paddings: Map<String, Any>,
       originalPaddings: Map<String, Any>
     ): Mat {
-      var topLeft = originalBbox[0]
+      val topLeft = originalBox[0]
       val points = arrayOfNulls<Point>(4)
 
       for (i in 0 until 4) {
-        val coords = bbox[i]
-        coords.x -= paddings["left"]!! as Int
-        coords.y -= paddings["top"]!! as Int
+        val cords = box[i]
+        cords.x -= paddings["left"]!! as Int
+        cords.y -= paddings["top"]!! as Int
 
-        coords.x *= paddings["resizeRatio"]!! as Float
-        coords.y *= paddings["resizeRatio"]!! as Float
+        cords.x *= paddings["resizeRatio"]!! as Float
+        cords.y *= paddings["resizeRatio"]!! as Float
 
-        coords.x += topLeft.x
-        coords.y += topLeft.y
+        cords.x += topLeft.x
+        cords.y += topLeft.y
 
-        coords.x -= originalPaddings["left"]!! as Int
-        coords.y -= (originalPaddings["top"]!! as Int)
+        cords.x -= originalPaddings["left"]!! as Int
+        cords.y -= (originalPaddings["top"]!! as Int)
 
-        coords.x *= originalPaddings["resizeRatio"]!! as Float
-        coords.y *= originalPaddings["resizeRatio"]!! as Float
+        cords.x *= originalPaddings["resizeRatio"]!! as Float
+        cords.y *= originalPaddings["resizeRatio"]!! as Float
-        points[i] = Point(coords.x, coords.y)
+        points[i] = Point(cords.x, cords.y)
       }
 
       val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points))
@@ -305,5 +312,17 @@ class RecognizerUtils {
       return croppedImage
     }
+
+    fun extractBoundingBox(cords: List<Point>): Rect {
+      val points = arrayOfNulls<Point>(4)
+
+      for (i in 0 until 4) {
+        points[i] = Point(cords[i].x, cords[i].y)
+      }
+
+      val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points))
+
+      return boundingBox
+    }
   }
 }
diff --git a/examples/computer-vision/App.tsx b/examples/computer-vision/App.tsx
index 488c61cd..c79519ca 100644
--- a/examples/computer-vision/App.tsx
+++ b/examples/computer-vision/App.tsx
@@ -9,12 +9,14 @@ import { View, StyleSheet } from 'react-native';
 import { ClassificationScreen } from './screens/ClassificationScreen';
 import { ObjectDetectionScreen } from './screens/ObjectDetectionScreen';
 import { OCRScreen } from './screens/OCRScreen';
+import { VerticalOCRScreen } from './screens/VerticalOCRScreen';
 
 enum ModelType {
   STYLE_TRANSFER,
   OBJECT_DETECTION,
   CLASSIFICATION,
   OCR,
+  VERTICAL_OCR,
 }
 
 export default function App() {
@@ -50,6 +52,10 @@ export default function App() {
         );
       case ModelType.OCR:
         return <OCRScreen imageUri={imageUri} setImageUri={setImageUri} />;
+      case ModelType.VERTICAL_OCR:
+        return (
+          <VerticalOCRScreen imageUri={imageUri} setImageUri={setImageUri} />
+        );
       default:
         return (
@@ -69,6 +75,7 @@ export default function App() {
             'Object Detection',
             'Classification',
             'OCR',
+            'Vertical OCR',
           ]}
           onValueChange={(_, selectedIndex) => {
             handleModeChange(selectedIndex);
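The cropImageWithBoundingBox rework earlier in this patch chains four coordinate transforms to map a character box from the narrow detector's padded crop back into the original image. A small TypeScript sketch of that arithmetic (type names are hypothetical; the math mirrors the Kotlin above):

    type Point = { x: number; y: number };
    type Paddings = { left: number; top: number; resizeRatio: number };

    // Maps a point from the cropped, padded detector space back to original
    // image coordinates: undo the crop's padding and resize, translate by the
    // outer box's top-left corner, then undo the outer image's padding/resize.
    function mapToOriginal(
      p: Point,
      cropPaddings: Paddings,
      originalPaddings: Paddings,
      topLeft: Point
    ): Point {
      let x = (p.x - cropPaddings.left) * cropPaddings.resizeRatio + topLeft.x;
      let y = (p.y - cropPaddings.top) * cropPaddings.resizeRatio + topLeft.y;
      x = (x - originalPaddings.left) * originalPaddings.resizeRatio;
      y = (y - originalPaddings.top) * originalPaddings.resizeRatio;
      return { x, y };
    }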
diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx
index 9d17118a..3869c419 100644
--- a/examples/computer-vision/screens/OCRScreen.tsx
+++ b/examples/computer-vision/screens/OCRScreen.tsx
@@ -19,6 +19,7 @@ export const OCRScreen = ({
     height: number;
   }>();
   const [detectedText, setDetectedText] = useState('');
+
   const model = useOCR({
     detectorSource:
       'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_800.pte',
@@ -63,7 +64,10 @@ export const OCRScreen = ({
   if (!model.isReady) {
     return (
-
+
+
+
     );
   }
diff --git a/examples/computer-vision/screens/VerticalOCRScreen.tsx b/examples/computer-vision/screens/VerticalOCRScreen.tsx
new file mode 100644
index 00000000..05cda224
--- /dev/null
+++ b/examples/computer-vision/screens/VerticalOCRScreen.tsx
@@ -0,0 +1,118 @@
+import Spinner from 'react-native-loading-spinner-overlay';
+import { BottomBar } from '../components/BottomBar';
+import { getImage } from '../utils';
+import { useVerticalOCR } from 'react-native-executorch';
+import { View, StyleSheet, Image, Text } from 'react-native';
+import { useState } from 'react';
+import ImageWithBboxes2 from '../components/ImageWithOCRBboxes';
+
+export const VerticalOCRScreen = ({
+  imageUri,
+  setImageUri,
+}: {
+  imageUri: string;
+  setImageUri: (imageUri: string) => void;
+}) => {
+  const [results, setResults] = useState([]);
+  const [imageDimensions, setImageDimensions] = useState<{
+    width: number;
+    height: number;
+  }>();
+  const [detectedText, setDetectedText] = useState('');
+  const model = useVerticalOCR({
+    detectorSources: {
+      detectorLarge:
+        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft.pte',
+      detectorNarrow:
+        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_narrow.pte',
+    },
+    recognizerSources: {
+      recognizerLarge:
+        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte',
+      recognizerSmall:
+        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_64.pte',
+    },
+    language: 'en',
+    independentCharacters: true,
+  });
+
+  const handleCameraPress = async (isCamera: boolean) => {
+    const image = await getImage(isCamera);
+    const width = image?.width;
+    const height = image?.height;
+    setImageDimensions({ width: width as number, height: height as number });
+ const uri = image?.uri; + if (typeof uri === 'string') { + setImageUri(uri as string); + setResults([]); + setDetectedText(''); + } + }; + + const runForward = async () => { + try { + const output = await model.forward(imageUri); + setResults(output); + console.log(output); + let txt = ''; + output.forEach((detection: any) => { + txt += detection.text + ' '; + }); + setDetectedText(txt); + } catch (e) { + console.error(e); + } + }; + + if (!model.isReady) { + return ( + + ); + } + + return ( + <> + + + {imageUri && imageDimensions?.width && imageDimensions?.height ? ( + + ) : ( + + )} + + {detectedText} + + + + ); +}; + +const styles = StyleSheet.create({ + image: { + flex: 2, + borderRadius: 8, + width: '100%', + }, + imageContainer: { + flex: 6, + width: '100%', + padding: 16, + }, +}); diff --git a/ios/RnExecutorch/VerticalOCR.mm b/ios/RnExecutorch/VerticalOCR.mm index 14940ecb..d195fb11 100644 --- a/ios/RnExecutorch/VerticalOCR.mm +++ b/ios/RnExecutorch/VerticalOCR.mm @@ -1,5 +1,5 @@ #import "VerticalOCR.h" -#import "models/ocr/Detector.h" +#import "models/ocr/VerticalDetector.h" #import "models/ocr/RecognitionHandler.h" #import "models/ocr/Recognizer.h" #import "models/ocr/utils/RecognizerUtils.h" @@ -10,8 +10,8 @@ #import "models/ocr/utils/CTCLabelConverter.h" @implementation VerticalOCR { - Detector *detectorLarge; - Detector *detectorNarrow; + VerticalDetector *detectorLarge; + VerticalDetector *detectorNarrow; Recognizer *recognizer; CTCLabelConverter *converter; BOOL independentCharacters; @@ -26,8 +26,7 @@ - (void)loadModule:(NSString *)detectorLargeSource independentCharacters:(BOOL)independentCharacters resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { - NSLog(@"%@", recognizerSource); - detectorLarge = [[Detector alloc] initWithIsVertical:YES detectSingleCharacters: NO]; + detectorLarge = [[VerticalDetector alloc] initWithDetectSingleCharacters:NO]; converter = [[CTCLabelConverter alloc] initWithCharacters:symbols separatorList:@{}]; self->independentCharacters = independentCharacters; [detectorLarge @@ -38,7 +37,7 @@ - (void)loadModule:(NSString *)detectorLargeSource nil); return; } - self->detectorNarrow = [[Detector alloc] initWithIsVertical:YES detectSingleCharacters:YES]; + self->detectorNarrow = [[VerticalDetector alloc] initWithDetectSingleCharacters:YES]; [self->detectorNarrow loadModel:[NSURL URLWithString:detectorNarrowSource] completion:^(BOOL success, NSNumber *errorCode) { @@ -70,37 +69,34 @@ - (void)forward:(NSString *)input @try { cv::Mat image = [ImageProcessor readImage:input]; NSArray *result = [detectorLarge runModel:image]; - cv::Mat resizedImage = [OCRUtils resizeWithPadding:image desiredWidth:1280 desiredHeight:1280]; + cv::Size largeDetectorSize = [detectorLarge getModelImageSize]; + cv::Mat resizedImage = [OCRUtils resizeWithPadding:image desiredWidth:largeDetectorSize.width desiredHeight:largeDetectorSize.height]; NSMutableArray *predictions = [NSMutableArray array]; + for (NSDictionary *box in result){ - NSArray *coords = box[@"bbox"]; - const int boxWidth = [[coords objectAtIndex:2] CGPointValue].x - [[coords objectAtIndex:0] CGPointValue].x; - const int boxHeight = [[coords objectAtIndex:2] CGPointValue].y - [[coords objectAtIndex:0] CGPointValue].y; - std::vector points; - for (NSValue *value in coords) { - const CGPoint point = [value CGPointValue]; - points.emplace_back(static_cast(point.x), - static_cast(point.y)); - } + NSArray *cords = box[@"bbox"]; + const int boxWidth = [[cords objectAtIndex:2] CGPointValue].x 
- [[cords objectAtIndex:0] CGPointValue].x; + const int boxHeight = [[cords objectAtIndex:2] CGPointValue].y - [[cords objectAtIndex:0] CGPointValue].y; - cv::Rect boundingBox = cv::boundingRect(points); + cv::Rect boundingBox = [OCRUtils extractBoundingBox:cords]; cv::Mat croppedImage = resizedImage(boundingBox); - NSDictionary *ratioAndPadding = + NSDictionary *paddings = [RecognizerUtils calculateResizeRatioAndPaddings:image.cols height:image.rows - desiredWidth:1280 - desiredHeight:1280]; + desiredWidth:largeDetectorSize.width + desiredHeight:largeDetectorSize.height]; NSString *text = @""; NSNumber *confidenceScore = @0.0; - NSArray *detectionResult = [detectorNarrow runModel:croppedImage]; + NSArray *boxResult = [detectorNarrow runModel:croppedImage]; std::vector croppedCharacters; - for(NSDictionary *bbox in detectionResult){ - NSArray *coords2 = bbox[@"bbox"]; - NSDictionary *paddingsSingle = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:320 desiredHeight:1280]; - cv::Mat croppedCharacter = [RecognizerUtils cropImageWithBoundingBox:image bbox:coords2 originalBbox:coords paddings:paddingsSingle originalPaddings:ratioAndPadding]; + + for(NSDictionary *characterBox in boxResult){ + NSArray *boxCords = characterBox[@"bbox"]; + NSDictionary *paddingsBox = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:320 desiredHeight:1280]; + cv::Mat croppedCharacter = [RecognizerUtils cropImageWithBoundingBox:image bbox:boxCords originalBbox:cords paddings:paddingsBox originalPaddings:paddings]; if(self->independentCharacters){ - croppedCharacter = [RecognizerUtils normalizeForRecognizer:croppedCharacter adjustContrast:0.0]; + croppedCharacter = [RecognizerUtils normalizeForRecognizer:croppedCharacter adjustContrast:0.0 isVertical: YES]; NSArray *recognitionResult = [recognizer runModel:croppedCharacter]; NSArray *predIndex = [recognitionResult objectAtIndex:0]; NSArray *decodedText = [converter decodeGreedy: predIndex length:(int)(predIndex.count)]; @@ -112,12 +108,12 @@ - (void)forward:(NSString *)input } if(self->independentCharacters){ - confidenceScore = @([confidenceScore floatValue] / detectionResult.count); + confidenceScore = @([confidenceScore floatValue] / boxResult.count); }else{ cv::Mat mergedCharacters; cv::hconcat(croppedCharacters.data(), (int)croppedCharacters.size(), mergedCharacters); mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters desiredWidth:512 desiredHeight:64]; - mergedCharacters = [RecognizerUtils normalizeForRecognizer:mergedCharacters adjustContrast:0.0]; + mergedCharacters = [RecognizerUtils normalizeForRecognizer:mergedCharacters adjustContrast:0.0 isVertical: NO]; NSArray *recognitionResult = [recognizer runModel:mergedCharacters]; NSArray *predIndex = [recognitionResult objectAtIndex:0]; NSArray *decodedText = [converter decodeGreedy: predIndex length:(int)(predIndex.count)]; @@ -126,12 +122,12 @@ - (void)forward:(NSString *)input } NSMutableArray *newCoords = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *coord in coords) { - const CGPoint point = [coord CGPointValue]; + for (NSValue *cord in cords){ + const CGPoint point = [cord CGPointValue]; [newCoords addObject:@{ - @"x" : @((point.x - [ratioAndPadding[@"left"] intValue]) * [ratioAndPadding[@"resizeRatio"] floatValue]), - @"y" : @((point.y - [ratioAndPadding[@"top"] intValue]) * [ratioAndPadding[@"resizeRatio"] floatValue]) + @"x" : @((point.x - [paddings[@"left"] intValue]) * [paddings[@"resizeRatio"] 
floatValue]), + @"y" : @((point.y - [paddings[@"top"] intValue]) * [paddings[@"resizeRatio"] floatValue]) }]; } diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 562bf92b..e21c92cb 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -21,7 +21,7 @@ const cv::Scalar variance(0.229, 0.224, 0.225); @interface Detector : BaseModel -- (instancetype)initWithIsVertical:(BOOL)isVertical detectSingleCharacters:(BOOL)detectSingleCharacters; +- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters; - (cv::Size)getModelImageSize; - (NSArray *)runModel:(cv::Mat &)input; diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 68cd72ea..5bec8836 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -11,18 +11,6 @@ The model used as detector is based on CRAFT (Character Region Awareness for @implementation Detector { cv::Size originalSize; cv::Size modelSize; - BOOL isVertical; - BOOL detectSingleCharacters; -} - -- (instancetype)initWithIsVertical:(BOOL)isVertical - detectSingleCharacters:(BOOL)detectSingleCharacters { - self = [super init]; - if (self) { - self->isVertical = isVertical; - self->detectSingleCharacters = detectSingleCharacters; - } - return self; } - (cv::Size)getModelImageSize { @@ -83,33 +71,13 @@ group each character into a single instance (sequence) Both matrices are outputMat2:scoreAffinityCV withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; - NSArray *bBoxesList; - if (!self->isVertical) { - bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV - affinityMap:scoreAffinityCV - usingTextThreshold:textThreshold - linkThreshold:linkThreshold - lowTextThreshold:lowTextThreshold]; - bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList - usingRestoreRatio:restoreRatio]; - } else if (self->isVertical) { - CGFloat txtThreshold = textThreshold; - if (!self->detectSingleCharacters) { - txtThreshold = textThresholdVertical; - } - bBoxesList = - [DetectorUtils getDetBoxesFromTextMapVertical:scoreTextCV - affinityMap:scoreAffinityCV - usingTextThreshold:txtThreshold - linkThreshold:linkThreshold - independentCharacters:self->detectSingleCharacters]; - bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList - usingRestoreRatio:restoreRatioVertical]; - - if (self->detectSingleCharacters){ - return bBoxesList; - } - } + NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV + affinityMap:scoreAffinityCV + usingTextThreshold:textThreshold + linkThreshold:linkThreshold + lowTextThreshold:lowTextThreshold]; + bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList + usingRestoreRatio:restoreRatio]; bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList centerThreshold:centerThreshold diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 60616b90..5793c646 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -108,7 +108,7 @@ - (NSArray *)recognize:(NSArray *)bBoxesList continue; } croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage - adjustContrast:adjustContrast]; + adjustContrast:adjustContrast isVertical:NO]; NSArray *result = [self runModel:croppedImage]; NSNumber *confidenceScore = [result objectAtIndex:1]; diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.h 
b/ios/RnExecutorch/models/ocr/VerticalDetector.h
new file mode 100644
index 00000000..8263ddd4
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/VerticalDetector.h
@@ -0,0 +1,28 @@
+#import "BaseModel.h"
+#import "RecognitionHandler.h"
+#import "opencv2/opencv.hpp"
+
+constexpr CGFloat textThreshold = 0.4;
+constexpr CGFloat textThresholdVertical = 0.3;
+constexpr CGFloat linkThreshold = 0.4;
+constexpr CGFloat lowTextThreshold = 0.7;
+constexpr CGFloat centerThreshold = 0.5;
+constexpr CGFloat distanceThreshold = 2.0;
+constexpr CGFloat heightThreshold = 2.0;
+constexpr CGFloat restoreRatio = 3.2;
+constexpr CGFloat restoreRatioVertical = 2.0;
+constexpr int minSideThreshold = 15;
+constexpr int maxSideThreshold = 30;
+constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15);
+constexpr int minSize = 20;
+
+const cv::Scalar mean(0.485, 0.456, 0.406);
+const cv::Scalar variance(0.229, 0.224, 0.225);
+
+@interface VerticalDetector : BaseModel
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters;
+- (cv::Size)getModelImageSize;
+- (NSArray *)runModel:(cv::Mat &)input;
+
+@end
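The implementation below follows the usual CRAFT-style recipe: threshold the two heat maps, collect connected components as candidate boxes, rescale them, and, unless single characters were requested, group them into sequences. A compact TypeScript outline of that control flow, with every helper injected because only the flow (not the helpers) is defined by this file; the threshold values come from the header above:

    type Box = { x: number; y: number }[];

    interface DetectorOps {
      boxesFromTextMapVertical(
        textMap: Float32Array,
        affinityMap: Float32Array,
        textThreshold: number,
        linkThreshold: number,
        independentCharacters: boolean
      ): Box[];
      restoreRatio(boxes: Box[], ratio: number): Box[];
      groupTextBoxes(boxes: Box[]): Box[];
    }

    const TEXT_THRESHOLD = 0.4; // textThreshold in VerticalDetector.h
    const TEXT_THRESHOLD_VERTICAL = 0.3; // textThresholdVertical
    const LINK_THRESHOLD = 0.4; // linkThreshold
    const RESTORE_RATIO_VERTICAL = 2.0; // restoreRatioVertical

    function verticalPostprocess(
      ops: DetectorOps,
      textMap: Float32Array,
      affinityMap: Float32Array,
      detectSingleCharacters: boolean
    ): Box[] {
      // Single characters use the stricter base threshold; multi-character
      // vertical text uses the lower vertical threshold.
      const threshold = detectSingleCharacters
        ? TEXT_THRESHOLD
        : TEXT_THRESHOLD_VERTICAL;
      let boxes = ops.boxesFromTextMapVertical(
        textMap,
        affinityMap,
        threshold,
        LINK_THRESHOLD,
        detectSingleCharacters
      );
      boxes = ops.restoreRatio(boxes, RESTORE_RATIO_VERTICAL);
      // Independent characters are returned as-is; otherwise merge boxes
      // into text lines.
      return detectSingleCharacters ? boxes : ops.groupTextBoxes(boxes);
    }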
diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.mm b/ios/RnExecutorch/models/ocr/VerticalDetector.mm
new file mode 100644
index 00000000..a2657a00
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/VerticalDetector.mm
@@ -0,0 +1,118 @@
+#import "VerticalDetector.h"
+#import "../../utils/ImageProcessor.h"
+#import "utils/DetectorUtils.h"
+#import "utils/OCRUtils.h"
+
+/*
+ The model used as the detector is based on the CRAFT (Character Region
+ Awareness for Text Detection) paper: https://arxiv.org/pdf/1904.01941
+ */
+
+@implementation VerticalDetector {
+  cv::Size originalSize;
+  cv::Size modelSize;
+  BOOL detectSingleCharacters;
+}
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters {
+  self = [super init];
+  if (self) {
+    self->detectSingleCharacters = detectSingleCharacters;
+  }
+  return self;
+}
+
+- (cv::Size)getModelImageSize {
+  if (!modelSize.empty()) {
+    return modelSize;
+  }
+
+  NSArray *inputShape = [module getInputShape:@0];
+  NSNumber *widthNumber = inputShape[inputShape.count - 2];
+  NSNumber *heightNumber = inputShape.lastObject;
+
+  const int height = [heightNumber intValue];
+  const int width = [widthNumber intValue];
+  modelSize = cv::Size(height, width);
+
+  return cv::Size(height, width);
+}
+
+- (NSArray *)preprocess:(cv::Mat &)input {
+  /*
+   The detector accepts an input tensor with a shape of [1, 3, 800, 800].
+   Because resizing strongly affects recognition quality, the image keeps its
+   original aspect ratio and the missing parts are filled with padding.
+   */
+  self->originalSize = cv::Size(input.cols, input.rows);
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat resizedImage;
+  resizedImage = [OCRUtils resizeWithPadding:input
+                                desiredWidth:modelImageSize.width
+                               desiredHeight:modelImageSize.height];
+  NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage
+                                                mean:mean
+                                            variance:variance];
+  return modelInput;
+}
+
+- (NSArray *)postprocess:(NSArray *)output {
+  /*
+   The output of the model consists of two matrices (heat maps):
+   1. ScoreText (score map) - the probability that a region contains a
+   character.
+   2. ScoreAffinity (affinity map) - the affinity between characters, used to
+   group characters into a single instance (sequence). Both matrices are
+   400x400.
+
+   The result of this step is a list of bounding boxes that contain text.
+   */
+  NSArray *predictions = [output objectAtIndex:0];
+
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat scoreTextCV, scoreAffinityCV;
+  /*
+   The model output contains two heat maps, each half the size of the input
+   image; that is why the width and height are divided by 2.
+   */
+  [DetectorUtils interleavedArrayToMats:predictions
+                             outputMat1:scoreTextCV
+                             outputMat2:scoreAffinityCV
+                               withSize:cv::Size(modelImageSize.width / 2,
+                                                 modelImageSize.height / 2)];
+  CGFloat txtThreshold = textThreshold;
+  if (!self->detectSingleCharacters) {
+    txtThreshold = textThresholdVertical;
+  }
+  NSArray *bBoxesList = [DetectorUtils
+      getDetBoxesFromTextMapVertical:scoreTextCV
+                         affinityMap:scoreAffinityCV
+                  usingTextThreshold:txtThreshold
+                       linkThreshold:linkThreshold
+               independentCharacters:self->detectSingleCharacters];
+  bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+                             usingRestoreRatio:restoreRatioVertical];
+
+  if (self->detectSingleCharacters) {
+    return bBoxesList;
+  }
+
+  bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
+                             centerThreshold:centerThreshold
+                           distanceThreshold:distanceThreshold
+                             heightThreshold:heightThreshold
+                            minSideThreshold:minSideThreshold
+                            maxSideThreshold:maxSideThreshold
+                                    maxWidth:maxWidth];
+
+  return bBoxesList;
+}
+
+- (NSArray *)runModel:(cv::Mat &)input {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *modelResult = [self forward:modelInput];
+  NSArray *result = [self postprocess:modelResult];
+  return result;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h
index dca8b9bb..90a8fa7a 100644
--- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h
+++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h
@@ -5,5 +5,6 @@
 + (cv::Mat)resizeWithPadding:(cv::Mat)img
                 desiredWidth:(int)desiredWidth
               desiredHeight:(int)desiredHeight;
++ (cv::Rect)extractBoundingBox:(NSArray *)coords;
 
 @end
diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm
index f530dac2..db47ba5b 100644
--- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm
+++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm
@@ -1,4 +1,5 @@
 #import "OCRUtils.h"
+#import "RecognizerUtils.h"
 
 @implementation OCRUtils
 
@@ -52,4 +53,15 @@ @implementation OCRUtils
   return centeredImg;
 }
 
++ (cv::Rect)extractBoundingBox:(NSArray *)coords {
+  std::vector points;
+  for (NSValue *value in coords) {
+    const CGPoint point = [value CGPointValue];
+
+    points.emplace_back(point.x, point.y);
+  }
+
+  return cv::boundingRect(points);
+}
+
 @end
diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
index d976b1ae..eb47a0b4 100644
--- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
+++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
@@ -8,7 +8,8 @@
                        height:(int)height
                   modelHeight:(int)modelHeight;
 + (cv::Mat)normalizeForRecognizer:(cv::Mat)image
-                   adjustContrast:(double)adjustContrast;
+                   adjustContrast:(double)adjustContrast
+                       isVertical:(BOOL)isVertical;
 + (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target;
 + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector;
 + (cv::Mat)softmax:(cv::Mat)inputs;
@@ -24,6 +25,10 @@
 + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities;
 + (double)computeConfidenceScore:(NSArray *)valuesArray
                     indicesArray:(NSArray *)indicesArray;
-+ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img
bbox:(NSArray *)bbox originalBbox:(NSArray *)originalBbox paddings:(NSDictionary *)paddings originalPaddings:(NSDictionary *)originalPaddings; ++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img + bbox:(NSArray *)bbox + originalBbox:(NSArray *)originalBbox + paddings:(NSDictionary *)paddings + originalPaddings:(NSDictionary *)originalPaddings; @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index b1726e79..f3c9ec25 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -56,18 +56,22 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { } + (cv::Mat)normalizeForRecognizer:(cv::Mat)image - adjustContrast:(double)adjustContrast { + adjustContrast:(double)adjustContrast + isVertical:(BOOL)isVertical { if (adjustContrast > 0) { image = [self adjustContrastGrey:image target:adjustContrast]; } - int desiredWidth = 64; + int desiredWidth; + if (isVertical){ + desiredWidth = 64; + }else{ + desiredWidth = 128; + } if (image.cols >= 512) { desiredWidth = 512; } else if (image.cols >= 256) { desiredWidth = 256; - } else if (image.cols >= 128){ - desiredWidth = 128; } image = [OCRUtils resizeWithPadding:image @@ -222,34 +226,39 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray return pow(product, 2.0 / sqrt(predsMaxProb.count)); } -+ (cv::Mat)cropImageWithBoundingBox:(cv::Mat&)img bbox:(NSArray *)bbox originalBbox:(NSArray *)originalBbox paddings:(NSDictionary *)paddings originalPaddings:(NSDictionary *)originalPaddings { ++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img + bbox:(NSArray *)bbox + originalBbox:(NSArray *)originalBbox + paddings:(NSDictionary *)paddings + originalPaddings:(NSDictionary *)originalPaddings { CGPoint topLeft = [originalBbox[0] CGPointValue]; std::vector points; - for(NSValue* coords in bbox) { + for (NSValue *coords in bbox) { CGPoint point = [coords CGPointValue]; - + point.x = point.x - [paddings[@"left"] intValue]; point.y = point.y - [paddings[@"top"] intValue]; - + point.x = point.x * [paddings[@"resizeRatio"] floatValue]; point.y = point.y * [paddings[@"resizeRatio"] floatValue]; - + point.x = point.x + topLeft.x; point.y = point.y + topLeft.y; - + point.x = point.x - [originalPaddings[@"left"] intValue]; point.y = point.y - [originalPaddings[@"top"] intValue]; - + point.x = point.x * [originalPaddings[@"resizeRatio"] floatValue]; point.y = point.y * [originalPaddings[@"resizeRatio"] floatValue]; - + points.push_back(cv::Point2f(point.x, point.y)); } - + cv::Rect rect = cv::boundingRect(points); cv::Mat croppedImage = img(rect); cv::cvtColor(croppedImage, croppedImage, cv::COLOR_BGR2GRAY); - cv::resize(croppedImage, croppedImage, cv::Size(64, 64), 0, 0, cv::INTER_AREA); + cv::resize(croppedImage, croppedImage, cv::Size(64, 64), 0, 0, + cv::INTER_AREA); cv::medianBlur(croppedImage, croppedImage, 1); return croppedImage; } From 75033ca1630fd1ebedc39bf7facc7481823b5126 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 27 Feb 2025 15:53:57 +0100 Subject: [PATCH 03/12] feat: added urls to models on hf repo --- .../computer-vision/screens/OCRScreen.tsx | 20 ++++++++++--------- .../screens/VerticalOCRScreen.tsx | 20 ++++++++++--------- ios/RnExecutorch/models/ocr/Detector.h | 1 - src/constants/modelUrls.ts | 18 +++++++++++++++++ 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index 
3869c419..4bfc0bee 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -1,7 +1,13 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../components/BottomBar'; import { getImage } from '../utils'; -import { useOCR } from 'react-native-executorch'; +import { + DETECTOR_CRAFT_800, + RECOGNIZER_EN_CRNN_128, + RECOGNIZER_EN_CRNN_256, + RECOGNIZER_EN_CRNN_512, + useOCR, +} from 'react-native-executorch'; import { View, StyleSheet, Image, Text } from 'react-native'; import { useState } from 'react'; import ImageWithBboxes2 from '../components/ImageWithOCRBboxes'; @@ -21,15 +27,11 @@ export const OCRScreen = ({ const [detectedText, setDetectedText] = useState(''); const model = useOCR({ - detectorSource: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_800.pte', + detectorSource: DETECTOR_CRAFT_800, recognizerSources: { - recognizerLarge: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', - recognizerMedium: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_256.pte', - recognizerSmall: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_128.pte', + recognizerLarge: RECOGNIZER_EN_CRNN_512, + recognizerMedium: RECOGNIZER_EN_CRNN_256, + recognizerSmall: RECOGNIZER_EN_CRNN_128, }, language: 'en', }); diff --git a/examples/computer-vision/screens/VerticalOCRScreen.tsx b/examples/computer-vision/screens/VerticalOCRScreen.tsx index 05cda224..e242fb11 100644 --- a/examples/computer-vision/screens/VerticalOCRScreen.tsx +++ b/examples/computer-vision/screens/VerticalOCRScreen.tsx @@ -1,7 +1,13 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../components/BottomBar'; import { getImage } from '../utils'; -import { useVerticalOCR } from 'react-native-executorch'; +import { + DETECTOR_CRAFT_1280, + DETECTOR_CRAFT_320, + RECOGNIZER_EN_CRNN_512, + RECOGNIZER_EN_CRNN_64, + useVerticalOCR, +} from 'react-native-executorch'; import { View, StyleSheet, Image, Text } from 'react-native'; import { useState } from 'react'; import ImageWithBboxes2 from '../components/ImageWithOCRBboxes'; @@ -21,16 +27,12 @@ export const VerticalOCRScreen = ({ const [detectedText, setDetectedText] = useState(''); const model = useVerticalOCR({ detectorSources: { - detectorLarge: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft.pte', - detectorNarrow: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_narrow.pte', + detectorLarge: DETECTOR_CRAFT_1280, + detectorNarrow: DETECTOR_CRAFT_320, }, recognizerSources: { - recognizerLarge: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', - recognizerSmall: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_64.pte', + recognizerLarge: RECOGNIZER_EN_CRNN_512, + recognizerSmall: RECOGNIZER_EN_CRNN_64, }, language: 'en', independentCharacters: true, diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index e21c92cb..a37720be 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -21,7 +21,6 @@ const cv::Scalar variance(0.229, 0.224, 0.225); @interface Detector : BaseModel -- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters; - (cv::Size)getModelImageSize; - (NSArray *)runModel:(cv::Mat &)input; diff --git a/src/constants/modelUrls.ts b/src/constants/modelUrls.ts index 
2f57331c..30e38479 100644 --- a/src/constants/modelUrls.ts +++ b/src/constants/modelUrls.ts @@ -46,6 +46,24 @@ export const STYLE_TRANSFER_UDNIE = ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/coreml/style_transfer_udnie_coreml.pte' : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/xnnpack/style_transfer_udnie_xnnpack.pte'; +// OCR + +export const DETECTOR_CRAFT_1280 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_1280.pte'; +export const DETECTOR_CRAFT_800 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_800.pte'; +export const DETECTOR_CRAFT_320 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_320.pte'; + +export const RECOGNIZER_EN_CRNN_512 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_512.pte'; +export const RECOGNIZER_EN_CRNN_256 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_256.pte'; +export const RECOGNIZER_EN_CRNN_128 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_128.pte'; +export const RECOGNIZER_EN_CRNN_64 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_64.pte'; + // Backward compatibility export const LLAMA3_2_3B_URL = LLAMA3_2_3B; export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA; From 706e75392e3507bcf94a0899d647184eb3364134 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 28 Feb 2025 22:27:46 +0100 Subject: [PATCH 04/12] refactor: implement requested changes --- .../models/ocr/VerticalDetector.kt | 7 +--- .../models/ocr/utils/DetectorUtils.kt | 3 +- .../models/ocr/utils/RecognizerUtils.kt | 12 ++---- ios/RnExecutorch/VerticalOCR.h | 2 - ios/RnExecutorch/models/ocr/Detector.h | 18 +------- .../models/ocr/VerticalDetector.h | 18 +------- .../models/ocr/VerticalDetector.mm | 7 ++-- ios/RnExecutorch/models/ocr/utils/Constants.h | 17 ++++++++ .../models/ocr/utils/DetectorUtils.mm | 4 +- ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 1 + .../models/ocr/utils/RecognizerUtils.mm | 11 ++--- src/constants/ocr/languageDicts.ts | 4 -- src/hooks/computer_vision/useOCR.ts | 3 +- src/hooks/computer_vision/useVerticalOCR.ts | 41 +++++++------------ src/modules/computer_vision/OCRModule.ts | 3 +- 15 files changed, 52 insertions(+), 99 deletions(-) create mode 100644 ios/RnExecutorch/models/ocr/utils/Constants.h delete mode 100644 src/constants/ocr/languageDicts.ts diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt index d3365274..15665256 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt @@ -53,12 +53,7 @@ class VerticalDetector( Size(modelImageSize.width / 2, modelImageSize.height / 2) ) - var txtThreshold = Constants.TEXT_THRESHOLD - - if (!detectSingleCharacter) { - txtThreshold = Constants.TEXT_THRESHOLD_VERTICAL - } - + val txtThreshold = if (detectSingleCharacter) 
Constants.TEXT_THRESHOLD else Constants.TEXT_THRESHOLD_VERTICAL
     var bBoxesList = DetectorUtils.getDetBoxesFromTextMapVertical(
       scoreText,
       scoreLink,
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
index c1c90774..b07123ac 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
@@ -332,9 +332,10 @@ class DetectorUtils {
       val detectedBoxes = mutableListOf<OCRbBox>()
       for (i in 1 until nLabels) {
         val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt()
+        if (area < 20) continue
+
         val height = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt()
         val width = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt()
-        if (area < 20) continue
         if (!independentCharacters && height < width) continue
 
         val mask = createMaskFromLabels(labels, i)
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
index bbd0dd6d..654aaa41 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
@@ -256,16 +256,12 @@ class RecognizerUtils {
         img = adjustContrastGrey(img, adjustContrast)
       }
 
-      var desiredWidth =
-        if (isVertical) Constants.VERTICAL_SMALL_MODEL_WIDTH else Constants.SMALL_MODEL_WIDTH
-
-      if (img.width() >= Constants.LARGE_MODEL_WIDTH) {
-        desiredWidth = Constants.LARGE_MODEL_WIDTH
-      } else if (img.width() >= Constants.MEDIUM_MODEL_WIDTH) {
-        desiredWidth = Constants.MEDIUM_MODEL_WIDTH
+      val desiredWidth = when {
+        img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH
+        img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH
+        else -> if (isVertical) Constants.VERTICAL_SMALL_MODEL_WIDTH else Constants.SMALL_MODEL_WIDTH
       }
-
       img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT)
       img.convertTo(img, CvType.CV_32F, 1.0 / 255.0)
       Core.subtract(img, Scalar(0.5), img)
diff --git a/ios/RnExecutorch/VerticalOCR.h b/ios/RnExecutorch/VerticalOCR.h
index ee19e11e..5692d378 100644
--- a/ios/RnExecutorch/VerticalOCR.h
+++ b/ios/RnExecutorch/VerticalOCR.h
@@ -1,7 +1,5 @@
 #import
 
-constexpr CGFloat recognizerRatio = 1.6;
-
 @interface VerticalOCR : NSObject
 
 @end
diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h
index a37720be..16441359 100644
--- a/ios/RnExecutorch/models/ocr/Detector.h
+++ b/ios/RnExecutorch/models/ocr/Detector.h
@@ -1,23 +1,7 @@
 #import "BaseModel.h"
 #import "RecognitionHandler.h"
 #import "opencv2/opencv.hpp"
-
-constexpr CGFloat textThreshold = 0.4;
-constexpr CGFloat textThresholdVertical = 0.3;
-constexpr CGFloat linkThreshold = 0.4;
-constexpr CGFloat lowTextThreshold = 0.7;
-constexpr CGFloat centerThreshold = 0.5;
-constexpr CGFloat distanceThreshold = 2.0;
-constexpr CGFloat heightThreshold = 2.0;
-constexpr CGFloat restoreRatio = 3.2;
-constexpr CGFloat restoreRatioVertical = 2.0;
-constexpr int minSideThreshold = 15;
-constexpr int maxSideThreshold = 30;
-constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15);
-constexpr int minSize = 20;
-
-const cv::Scalar mean(0.485, 0.456, 0.406);
-const cv::Scalar variance(0.229, 0.224, 0.225);
+#import "utils/Constants.h"
 
 @interface Detector : BaseModel
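The normalizeForRecognizer change above collapses the width selection into a single when expression: crops are snapped to the nearest supported recognizer input width, with the floor depending on orientation. The same logic as a hedged TypeScript sketch (the 512/256/128/64 values come from the Objective-C counterpart in this patch series; the Kotlin constants are assumed to match):

    // Picks the recognizer input width for a text crop. Values assumed from
    // the iOS implementation: large 512, medium 256, small 128, vertical 64.
    function pickRecognizerWidth(cropWidth: number, isVertical: boolean): number {
      if (cropWidth >= 512) return 512; // large recognizer
      if (cropWidth >= 256) return 256; // medium recognizer
      return isVertical ? 64 : 128;     // vertical-small / small recognizer
    }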
diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.h b/ios/RnExecutorch/models/ocr/VerticalDetector.h index 8263ddd4..1c1fcd2e 100644 --- a/ios/RnExecutorch/models/ocr/VerticalDetector.h +++ b/ios/RnExecutorch/models/ocr/VerticalDetector.h @@ -1,23 +1,7 @@ #import "BaseModel.h" #import "RecognitionHandler.h" #import "opencv2/opencv.hpp" - -constexpr CGFloat textThreshold = 0.4; -constexpr CGFloat textThresholdVertical = 0.3; -constexpr CGFloat linkThreshold = 0.4; -constexpr CGFloat lowTextThreshold = 0.7; -constexpr CGFloat centerThreshold = 0.5; -constexpr CGFloat distanceThreshold = 2.0; -constexpr CGFloat heightThreshold = 2.0; -constexpr CGFloat restoreRatio = 3.2; -constexpr CGFloat restoreRatioVertical = 2.0; -constexpr int minSideThreshold = 15; -constexpr int maxSideThreshold = 30; -constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); -constexpr int minSize = 20; - -const cv::Scalar mean(0.485, 0.456, 0.406); -const cv::Scalar variance(0.229, 0.224, 0.225); +#import "utils/Constants.h" @interface VerticalDetector : BaseModel diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.mm b/ios/RnExecutorch/models/ocr/VerticalDetector.mm index a2657a00..087604dd 100644 --- a/ios/RnExecutorch/models/ocr/VerticalDetector.mm +++ b/ios/RnExecutorch/models/ocr/VerticalDetector.mm @@ -80,10 +80,9 @@ group each character into a single instance (sequence) Both matrices are outputMat2:scoreAffinityCV withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; - CGFloat txtThreshold = textThreshold; - if (!self->detectSingleCharacters) { - txtThreshold = textThresholdVertical; - } + CGFloat txtThreshold = (self->detectSingleCharacters) ? textThreshold + : textThresholdVertical; + NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMapVertical:scoreTextCV affinityMap:scoreAffinityCV diff --git a/ios/RnExecutorch/models/ocr/utils/Constants.h b/ios/RnExecutorch/models/ocr/utils/Constants.h new file mode 100644 index 00000000..92470511 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/Constants.h @@ -0,0 +1,17 @@ +constexpr CGFloat textThreshold = 0.4; +constexpr CGFloat textThresholdVertical = 0.3; +constexpr CGFloat linkThreshold = 0.4; +constexpr CGFloat lowTextThreshold = 0.7; +constexpr CGFloat centerThreshold = 0.5; +constexpr CGFloat distanceThreshold = 2.0; +constexpr CGFloat heightThreshold = 2.0; +constexpr CGFloat restoreRatio = 3.2; +constexpr CGFloat restoreRatioVertical = 2.0; +constexpr int minSideThreshold = 15; +constexpr int maxSideThreshold = 30; +constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); +constexpr int minSize = 20; + +const cv::Scalar mean(0.485, 0.456, 0.406); +const cv::Scalar variance(0.229, 0.224, 0.225); + diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 62ed9fa1..0bdd6a76 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -64,10 +64,10 @@ + (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap NSMutableArray *detectedBoxes = [NSMutableArray array]; for (int i = 1; i < nLabels; i++) { const int area = stats.at(i, cv::CC_STAT_AREA); - const int width = stats.at(i, cv::CC_STAT_WIDTH); - const int height = stats.at(i, cv::CC_STAT_HEIGHT); if (area < 20) continue; + const int width = stats.at(i, cv::CC_STAT_WIDTH); + const int height = stats.at(i, cv::CC_STAT_HEIGHT); if (!independentCharacters && height < width) continue; diff --git 
a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index db47ba5b..eed17a15 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -55,6 +55,7 @@ @implementation OCRUtils + (cv::Rect)extractBoundingBox:(NSArray *)coords { std::vector points; + points.reserve(coords.count); for (NSValue *value in coords) { const CGPoint point = [value CGPointValue]; diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index f3c9ec25..28e419f2 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -62,12 +62,8 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { image = [self adjustContrastGrey:image target:adjustContrast]; } - int desiredWidth; - if (isVertical){ - desiredWidth = 64; - }else{ - desiredWidth = 128; - } + int desiredWidth = (isVertical) ? 64 : 128; + if (image.cols >= 512) { desiredWidth = 512; } else if (image.cols >= 256) { @@ -233,6 +229,7 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray originalPaddings:(NSDictionary *)originalPaddings { CGPoint topLeft = [originalBbox[0] CGPointValue]; std::vector points; + points.reserve(bbox.count); for (NSValue *coords in bbox) { CGPoint point = [coords CGPointValue]; @@ -251,7 +248,7 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray point.x = point.x * [originalPaddings[@"resizeRatio"] floatValue]; point.y = point.y * [originalPaddings[@"resizeRatio"] floatValue]; - points.push_back(cv::Point2f(point.x, point.y)); + points.emplace_back(cv::Point2f(point.x, point.y)); } cv::Rect rect = cv::boundingRect(points); diff --git a/src/constants/ocr/languageDicts.ts b/src/constants/ocr/languageDicts.ts deleted file mode 100644 index fcd189b5..00000000 --- a/src/constants/ocr/languageDicts.ts +++ /dev/null @@ -1,4 +0,0 @@ -export const languageDicts: { [key: string]: string } = { - en: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/en.txt', - pl: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/pl.txt', -}; diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts index 56ee04e4..72ec85a1 100644 --- a/src/hooks/computer_vision/useOCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -1,6 +1,5 @@ import { useEffect, useState } from 'react'; import { fetchResource } from '../../utils/fetchResource'; -import { languageDicts } from '../../constants/ocr/languageDicts'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { OCR } from '../../native/RnExecutorchModules'; @@ -45,7 +44,7 @@ export const useOCR = ({ recognizerSmall: string; }; - if (!symbols[language] || !languageDicts[language]) { + if (!symbols[language]) { setError(getError(ETError.LanguageNotSupported)); return; } diff --git a/src/hooks/computer_vision/useVerticalOCR.ts b/src/hooks/computer_vision/useVerticalOCR.ts index 3a5707d7..c2eef299 100644 --- a/src/hooks/computer_vision/useVerticalOCR.ts +++ b/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,6 +1,5 @@ import { useEffect, useState } from 'react'; import { fetchResource } from '../../utils/fetchResource'; -import { languageDicts } from '../../constants/ocr/languageDicts'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { VerticalOCR } from '../../native/RnExecutorchModules'; @@ -46,37 +45,25 @@ export const 
useVerticalOCR = ({ ) return; - let recognizerPath; - - const detectorPaths = {} as { - detectorLarge: string; - detectorNarrow: string; - }; - - if (!symbols[language] || !languageDicts[language]) { + if (!symbols[language]) { setError(getError(ETError.LanguageNotSupported)); return; } - await Promise.all([ - fetchResource(detectorSources.detectorLarge), - fetchResource(detectorSources.detectorNarrow), - ]).then((values) => { - detectorPaths.detectorLarge = values[0]; - detectorPaths.detectorNarrow = values[1]; - }); + const recognizerPath = independentCharacters + ? await fetchResource( + recognizerSources.recognizerSmall, + setDownloadProgress + ) + : await fetchResource( + recognizerSources.recognizerLarge, + setDownloadProgress + ); - if (independentCharacters) { - recognizerPath = await fetchResource( - recognizerSources.recognizerSmall, - setDownloadProgress - ); - } else { - recognizerPath = await fetchResource( - recognizerSources.recognizerLarge, - setDownloadProgress - ); - } + const detectorPaths = { + detectorLarge: await fetchResource(detectorSources.detectorLarge), + detectorNarrow: await fetchResource(detectorSources.detectorNarrow), + }; setIsReady(false); await VerticalOCR.loadModule( diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts index 26ea6f4e..50f5f9a2 100644 --- a/src/modules/computer_vision/OCRModule.ts +++ b/src/modules/computer_vision/OCRModule.ts @@ -1,4 +1,3 @@ -import { languageDicts } from '../../constants/ocr/languageDicts'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { OCR } from '../../native/RnExecutorchModules'; @@ -27,7 +26,7 @@ export class OCRModule { recognizerSmall: string; }; - if (!symbols[language] || !languageDicts[language]) { + if (!symbols[language]) { throw new Error(getError(ETError.LanguageNotSupported)); } From e04945d65bb0786a426bbda109e5bd555f815c2d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 3 Mar 2025 13:42:31 +0100 Subject: [PATCH 05/12] feat: added function to calculate download progress with multiple models, hookless for vertical ocr --- src/hooks/computer_vision/useOCR.ts | 48 ++++++----- src/hooks/computer_vision/useVerticalOCR.ts | 26 ++++-- src/modules/computer_vision/OCRModule.ts | 43 +++++----- .../computer_vision/VerticalOCRModule.ts | 81 +++++++++++++++++++ src/types/ocr.ts | 2 + src/utils/fetchResource.ts | 14 ++++ 6 files changed, 165 insertions(+), 49 deletions(-) create mode 100644 src/modules/computer_vision/VerticalOCRModule.ts diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts index 72ec85a1..a2473ccb 100644 --- a/src/hooks/computer_vision/useOCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -1,10 +1,13 @@ import { useEffect, useState } from 'react'; -import { fetchResource } from '../../utils/fetchResource'; +import { + calculateDownloadProgres, + fetchResource, +} from '../../utils/fetchResource'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; -import { OCRDetection } from '../../types/ocr'; +import { OCRDetection, OCRLanguage } from '../../types/ocr'; interface OCRModule { error: string | null; @@ -25,7 +28,7 @@ export const useOCR = ({ recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; }; - language?: string; + language?: OCRLanguage; }): OCRModule => { const [error, 
setError] = useState(null); const [isReady, setIsReady] = useState(false); @@ -35,33 +38,36 @@ export const useOCR = ({ useEffect(() => { const loadModel = async () => { try { - if (!detectorSource || Object.keys(recognizerSources).length === 0) + if (!detectorSource || Object.keys(recognizerSources).length !== 3) return; - const recognizerPaths = {} as { - recognizerLarge: string; - recognizerMedium: string; - recognizerSmall: string; - }; - if (!symbols[language]) { setError(getError(ETError.LanguageNotSupported)); return; } - const detectorPath = await fetchResource(detectorSource); + setIsReady(false); - await Promise.all([ - fetchResource(recognizerSources.recognizerLarge, setDownloadProgress), - fetchResource(recognizerSources.recognizerMedium), - fetchResource(recognizerSources.recognizerSmall), - ]).then((values) => { - recognizerPaths.recognizerLarge = values[0]; - recognizerPaths.recognizerMedium = values[1]; - recognizerPaths.recognizerSmall = values[2]; - }); + const detectorPath = await fetchResource( + detectorSource, + calculateDownloadProgres(4, 0, setDownloadProgress) + ); + + const recognizerPaths = { + recognizerLarge: await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(4, 1, setDownloadProgress) + ), + recognizerMedium: await fetchResource( + recognizerSources.recognizerMedium, + calculateDownloadProgres(4, 2, setDownloadProgress) + ), + recognizerSmall: await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(4, 3, setDownloadProgress) + ), + }; - setIsReady(false); await OCR.loadModule( detectorPath, recognizerPaths.recognizerLarge, diff --git a/src/hooks/computer_vision/useVerticalOCR.ts b/src/hooks/computer_vision/useVerticalOCR.ts index c2eef299..8c039ef4 100644 --- a/src/hooks/computer_vision/useVerticalOCR.ts +++ b/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,10 +1,13 @@ import { useEffect, useState } from 'react'; -import { fetchResource } from '../../utils/fetchResource'; +import { + calculateDownloadProgres, + fetchResource, +} from '../../utils/fetchResource'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { VerticalOCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; -import { OCRDetection } from '../../types/ocr'; +import { OCRDetection, OCRLanguage } from '../../types/ocr'; interface OCRModule { error: string | null; @@ -28,7 +31,7 @@ export const useVerticalOCR = ({ recognizerLarge: ResourceSource; recognizerSmall: ResourceSource; }; - language?: string; + language?: OCRLanguage; independentCharacters?: boolean; }): OCRModule => { const [error, setError] = useState(null); @@ -50,22 +53,29 @@ export const useVerticalOCR = ({ return; } + setIsReady(false); + const recognizerPath = independentCharacters ? 
await fetchResource( recognizerSources.recognizerSmall, - setDownloadProgress + calculateDownloadProgres(3, 0, setDownloadProgress) ) : await fetchResource( recognizerSources.recognizerLarge, - setDownloadProgress + calculateDownloadProgres(3, 0, setDownloadProgress) ); const detectorPaths = { - detectorLarge: await fetchResource(detectorSources.detectorLarge), - detectorNarrow: await fetchResource(detectorSources.detectorNarrow), + detectorLarge: await fetchResource( + detectorSources.detectorLarge, + calculateDownloadProgres(3, 1, setDownloadProgress) + ), + detectorNarrow: await fetchResource( + detectorSources.detectorNarrow, + calculateDownloadProgres(3, 2, setDownloadProgress) + ), }; - setIsReady(false); await VerticalOCR.loadModule( detectorPaths.detectorLarge, detectorPaths.detectorNarrow, diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts index 50f5f9a2..f62d32a7 100644 --- a/src/modules/computer_vision/OCRModule.ts +++ b/src/modules/computer_vision/OCRModule.ts @@ -2,7 +2,11 @@ import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; -import { fetchResource } from '../../utils/fetchResource'; +import { OCRLanguage } from '../../types/ocr'; +import { + calculateDownloadProgres, + fetchResource, +} from '../../utils/fetchResource'; export class OCRModule { static onDownloadProgressCallback = (_downloadProgress: number) => {}; @@ -14,36 +18,35 @@ export class OCRModule { recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; }, - language = 'en' + language: OCRLanguage = 'en' ) { try { - if (!detectorSource || Object.keys(recognizerSources).length === 0) + if (!detectorSource || Object.keys(recognizerSources).length !== 3) return; - const recognizerPaths = {} as { - recognizerLarge: string; - recognizerMedium: string; - recognizerSmall: string; - }; - if (!symbols[language]) { throw new Error(getError(ETError.LanguageNotSupported)); } - const detectorPath = await fetchResource(detectorSource); + const detectorPath = await fetchResource( + detectorSource, + calculateDownloadProgres(4, 0, this.onDownloadProgressCallback) + ); - await Promise.all([ - fetchResource( + const recognizerPaths = { + recognizerLarge: await fetchResource( recognizerSources.recognizerLarge, - this.onDownloadProgressCallback + calculateDownloadProgres(4, 1, this.onDownloadProgressCallback) ), - fetchResource(recognizerSources.recognizerMedium), - fetchResource(recognizerSources.recognizerSmall), - ]).then((values) => { - recognizerPaths.recognizerLarge = values[0]; - recognizerPaths.recognizerMedium = values[1]; - recognizerPaths.recognizerSmall = values[2]; - }); + recognizerMedium: await fetchResource( + recognizerSources.recognizerMedium, + calculateDownloadProgres(4, 2, this.onDownloadProgressCallback) + ), + recognizerSmall: await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(4, 3, this.onDownloadProgressCallback) + ), + }; await OCR.loadModule( detectorPath, diff --git a/src/modules/computer_vision/VerticalOCRModule.ts b/src/modules/computer_vision/VerticalOCRModule.ts new file mode 100644 index 00000000..fb85ef6b --- /dev/null +++ b/src/modules/computer_vision/VerticalOCRModule.ts @@ -0,0 +1,81 @@ +import { symbols } from '../../constants/ocr/symbols'; +import { getError, ETError } from '../../Error'; +import { VerticalOCR } from 
'../../native/RnExecutorchModules'; +import { ResourceSource } from '../../types/common'; +import { OCRLanguage } from '../../types/ocr'; +import { + calculateDownloadProgres, + fetchResource, +} from '../../utils/fetchResource'; + +export class VerticalOCRModule { + static onDownloadProgressCallback = (_downloadProgress: number) => {}; + + static async load( + detectorSources: { + detectorLarge: ResourceSource; + detectorNarrow: ResourceSource; + }, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage = 'en', + independentCharacters: boolean = false + ) { + try { + if ( + Object.keys(detectorSources).length !== 2 || + Object.keys(recognizerSources).length !== 2 + ) + return; + + if (!symbols[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + + const recognizerPath = independentCharacters + ? await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(3, 0, this.onDownloadProgressCallback) + ) + : await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(3, 0, this.onDownloadProgressCallback) + ); + + const detectorPaths = { + detectorLarge: await fetchResource( + detectorSources.detectorLarge, + calculateDownloadProgres(3, 1, this.onDownloadProgressCallback) + ), + detectorNarrow: await fetchResource( + detectorSources.detectorNarrow, + calculateDownloadProgres(3, 2, this.onDownloadProgressCallback) + ), + }; + + await VerticalOCR.loadModule( + detectorPaths.detectorLarge, + detectorPaths.detectorNarrow, + recognizerPath, + symbols[language], + independentCharacters + ); + } catch (e) { + throw new Error(getError(e)); + } + } + + static async forward(input: string) { + try { + return await VerticalOCR.forward(input); + } catch (e) { + throw new Error(getError(e)); + } + } + + static onDownloadProgress(callback: (downloadProgress: number) => void) { + this.onDownloadProgressCallback = callback; + } +} diff --git a/src/types/ocr.ts b/src/types/ocr.ts index f5f2e6d3..f633265f 100644 --- a/src/types/ocr.ts +++ b/src/types/ocr.ts @@ -8,3 +8,5 @@ export interface OCRBbox { x: number; y: number; } + +export type OCRLanguage = 'en'; diff --git a/src/utils/fetchResource.ts b/src/utils/fetchResource.ts index 9885758e..8164d8fe 100644 --- a/src/utils/fetchResource.ts +++ b/src/utils/fetchResource.ts @@ -80,3 +80,17 @@ export const fetchResource = async ( return fileUri; }; + +export const calculateDownloadProgres = + ( + numberOfFiles: number, + currentFileIndex: number, + setProgress: (downloadProgress: number) => void + ) => + (progress: number) => { + const contributionPerFile = 1 / numberOfFiles; + const baseProgress = contributionPerFile * currentFileIndex; + const scaledProgress = progress * contributionPerFile; + const updatedProgress = baseProgress + scaledProgress; + setProgress(updatedProgress); + }; From 7bc527c2462a963a8a640ecbc3b064e3af7aaa6a Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 3 Mar 2025 15:22:53 +0100 Subject: [PATCH 06/12] feat: added controllers for ocrs to avoid duplicated code --- src/controllers/OCRController.ts | 111 ++++++++++++++++ src/controllers/VerticalOCRController.ts | 119 ++++++++++++++++++ src/hooks/computer_vision/useOCR.ts | 82 ++---------- src/hooks/computer_vision/useVerticalOCR.ts | 91 +++----------- src/index.tsx | 1 + src/modules/computer_vision/OCRModule.ts | 57 ++------- .../computer_vision/VerticalOCRModule.ts | 68 +++------- src/native/RnExecutorchModules.ts | 48 +++++++ 8 files changed, 
332 insertions(+), 245 deletions(-) create mode 100644 src/controllers/OCRController.ts create mode 100644 src/controllers/VerticalOCRController.ts diff --git a/src/controllers/OCRController.ts b/src/controllers/OCRController.ts new file mode 100644 index 00000000..a6cf1a5d --- /dev/null +++ b/src/controllers/OCRController.ts @@ -0,0 +1,111 @@ +import { symbols } from '../constants/ocr/symbols'; +import { ETError, getError } from '../Error'; +import { _OCRModule } from '../native/RnExecutorchModules'; +import { ResourceSource } from '../types/common'; +import { OCRLanguage } from '../types/ocr'; +import { + fetchResource, + calculateDownloadProgres, +} from '../utils/fetchResource'; + +export class OCRController { + private nativeModule: _OCRModule; + public isReady: boolean = false; + public isGenerating: boolean = false; + public error: string | null = null; + private modelDownloadProgressCallback: (downloadProgress: number) => void; + private isReadyCallback: (isReady: boolean) => void; + private isGeneratingCallback: (isGenerating: boolean) => void; + private errorCallback: (error: string) => void; + + constructor({ + modelDownloadProgressCallback = (_downloadProgress: number) => {}, + isReadyCallback = (_isReady: boolean) => {}, + isGeneratingCallback = (_isGenerating: boolean) => {}, + errorCallback = (_error: string) => {}, + }) { + this.nativeModule = new _OCRModule(); + this.modelDownloadProgressCallback = modelDownloadProgressCallback; + this.isReadyCallback = isReadyCallback; + this.isGeneratingCallback = isGeneratingCallback; + this.errorCallback = errorCallback; + } + + public loadModel = async ( + detectorSource: ResourceSource, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerMedium: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage + ) => { + try { + if (!detectorSource || Object.keys(recognizerSources).length !== 3) + return; + + if (!symbols[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + this.isReady = false; + this.isReadyCallback(false); + + const detectorPath = await fetchResource( + detectorSource, + calculateDownloadProgres(4, 0, this.modelDownloadProgressCallback) + ); + + const recognizerPaths = { + recognizerLarge: await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(4, 1, this.modelDownloadProgressCallback) + ), + recognizerMedium: await fetchResource( + recognizerSources.recognizerMedium, + calculateDownloadProgres(4, 2, this.modelDownloadProgressCallback) + ), + recognizerSmall: await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(4, 3, this.modelDownloadProgressCallback) + ), + }; + + await this.nativeModule.loadModule( + detectorPath, + recognizerPaths.recognizerLarge, + recognizerPaths.recognizerMedium, + recognizerPaths.recognizerSmall, + symbols[language] + ); + + this.isReady = true; + this.isReadyCallback(this.isReady); + } catch (e) { + if (this.errorCallback) { + this.errorCallback(getError(e)); + } else { + throw new Error(getError(e)); + } + } + }; + + public forward = async (input: string) => { + if (!this.isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (this.isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + this.isGenerating = true; + this.isGeneratingCallback(this.isGenerating); + return await this.nativeModule.forward(input); + } catch (e) { + throw new Error(getError(e)); + } finally { + this.isGenerating = false; + 
this.isGeneratingCallback(this.isGenerating); + } + }; +} diff --git a/src/controllers/VerticalOCRController.ts b/src/controllers/VerticalOCRController.ts new file mode 100644 index 00000000..f09e70a7 --- /dev/null +++ b/src/controllers/VerticalOCRController.ts @@ -0,0 +1,119 @@ +import { symbols } from '../constants/ocr/symbols'; +import { ETError, getError } from '../Error'; +import { _VerticalOCRModule } from '../native/RnExecutorchModules'; +import { ResourceSource } from '../types/common'; +import { OCRLanguage } from '../types/ocr'; +import { + fetchResource, + calculateDownloadProgres, +} from '../utils/fetchResource'; + +export class VerticalOCRController { + private nativeModule: _VerticalOCRModule; + public isReady: boolean = false; + public isGenerating: boolean = false; + public error: string | null = null; + private modelDownloadProgressCallback: (downloadProgress: number) => void; + private isReadyCallback: (isReady: boolean) => void; + private isGeneratingCallback: (isGenerating: boolean) => void; + private errorCallback: (error: string) => void; + + constructor({ + modelDownloadProgressCallback = (_downloadProgress: number) => {}, + isReadyCallback = (_isReady: boolean) => {}, + isGeneratingCallback = (_isGenerating: boolean) => {}, + errorCallback = (_error: string) => {}, + }) { + this.nativeModule = new _VerticalOCRModule(); + this.modelDownloadProgressCallback = modelDownloadProgressCallback; + this.isReadyCallback = isReadyCallback; + this.isGeneratingCallback = isGeneratingCallback; + this.errorCallback = errorCallback; + } + + public loadModel = async ( + detectorSources: { + detectorLarge: ResourceSource; + detectorNarrow: ResourceSource; + }, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage, + independentCharacters: boolean + ) => { + try { + if ( + Object.keys(detectorSources).length !== 2 || + Object.keys(recognizerSources).length !== 2 + ) + return; + + if (!symbols[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + + this.isReady = false; + this.isReadyCallback(this.isReady); + + const recognizerPath = independentCharacters + ? 
await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(3, 0, this.modelDownloadProgressCallback) + ) + : await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(3, 0, this.modelDownloadProgressCallback) + ); + + const detectorPaths = { + detectorLarge: await fetchResource( + detectorSources.detectorLarge, + calculateDownloadProgres(3, 1, this.modelDownloadProgressCallback) + ), + detectorNarrow: await fetchResource( + detectorSources.detectorNarrow, + calculateDownloadProgres(3, 2, this.modelDownloadProgressCallback) + ), + }; + + await this.nativeModule.loadModule( + detectorPaths.detectorLarge, + detectorPaths.detectorNarrow, + recognizerPath, + symbols[language], + independentCharacters + ); + + this.isReady = true; + this.isReadyCallback(this.isReady); + } catch (e) { + if (this.errorCallback) { + this.errorCallback(getError(e)); + } else { + throw new Error(getError(e)); + } + } + }; + + public forward = async (input: string) => { + if (!this.isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (this.isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + this.isGenerating = true; + this.isGeneratingCallback(this.isGenerating); + return await this.nativeModule.forward(input); + } catch (e) { + throw new Error(getError(e)); + } finally { + this.isGenerating = false; + this.isGeneratingCallback(this.isGenerating); + } + }; +} diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts index a2473ccb..faa52a9c 100644 --- a/src/hooks/computer_vision/useOCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -1,13 +1,7 @@ import { useEffect, useState } from 'react'; -import { - calculateDownloadProgres, - fetchResource, -} from '../../utils/fetchResource'; -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { OCRController } from '../../controllers/OCRController'; interface OCRModule { error: string | null; @@ -35,80 +29,30 @@ export const useOCR = ({ const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); + const [model, _] = useState( + () => + new OCRController({ + modelDownloadProgressCallback: setDownloadProgress, + isReadyCallback: setIsReady, + isGeneratingCallback: setIsGenerating, + errorCallback: setError, + }) + ); + useEffect(() => { const loadModel = async () => { - try { - if (!detectorSource || Object.keys(recognizerSources).length !== 3) - return; - - if (!symbols[language]) { - setError(getError(ETError.LanguageNotSupported)); - return; - } - - setIsReady(false); - - const detectorPath = await fetchResource( - detectorSource, - calculateDownloadProgres(4, 0, setDownloadProgress) - ); - - const recognizerPaths = { - recognizerLarge: await fetchResource( - recognizerSources.recognizerLarge, - calculateDownloadProgres(4, 1, setDownloadProgress) - ), - recognizerMedium: await fetchResource( - recognizerSources.recognizerMedium, - calculateDownloadProgres(4, 2, setDownloadProgress) - ), - recognizerSmall: await fetchResource( - recognizerSources.recognizerSmall, - calculateDownloadProgres(4, 3, setDownloadProgress) - ), - }; - - await OCR.loadModule( - detectorPath, - recognizerPaths.recognizerLarge, - recognizerPaths.recognizerMedium, - 
recognizerPaths.recognizerSmall, - symbols[language] - ); - setIsReady(true); - } catch (e) { - setError(getError(e)); - } + await model.loadModel(detectorSource, recognizerSources, language); }; loadModel(); // eslint-disable-next-line react-hooks/exhaustive-deps }, [detectorSource, language, JSON.stringify(recognizerSources)]); - const forward = async (input: string) => { - if (!isReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - if (isGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - - try { - setIsGenerating(true); - const output = await OCR.forward(input); - return output; - } catch (e) { - throw new Error(getError(e)); - } finally { - setIsGenerating(false); - } - }; - return { error, isReady, isGenerating, - forward, + forward: model.forward, downloadProgress, }; }; diff --git a/src/hooks/computer_vision/useVerticalOCR.ts b/src/hooks/computer_vision/useVerticalOCR.ts index 8c039ef4..65e9ed73 100644 --- a/src/hooks/computer_vision/useVerticalOCR.ts +++ b/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,13 +1,7 @@ import { useEffect, useState } from 'react'; -import { - calculateDownloadProgres, - fetchResource, -} from '../../utils/fetchResource'; -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { VerticalOCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { VerticalOCRController } from '../../controllers/VerticalOCRController'; interface OCRModule { error: string | null; @@ -39,54 +33,24 @@ export const useVerticalOCR = ({ const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); + const [model, _] = useState( + () => + new VerticalOCRController({ + modelDownloadProgressCallback: setDownloadProgress, + isReadyCallback: setIsReady, + isGeneratingCallback: setIsGenerating, + errorCallback: setError, + }) + ); + useEffect(() => { const loadModel = async () => { - try { - if ( - Object.keys(detectorSources).length !== 2 || - Object.keys(recognizerSources).length !== 2 - ) - return; - - if (!symbols[language]) { - setError(getError(ETError.LanguageNotSupported)); - return; - } - - setIsReady(false); - - const recognizerPath = independentCharacters - ? 
await fetchResource( - recognizerSources.recognizerSmall, - calculateDownloadProgres(3, 0, setDownloadProgress) - ) - : await fetchResource( - recognizerSources.recognizerLarge, - calculateDownloadProgres(3, 0, setDownloadProgress) - ); - - const detectorPaths = { - detectorLarge: await fetchResource( - detectorSources.detectorLarge, - calculateDownloadProgres(3, 1, setDownloadProgress) - ), - detectorNarrow: await fetchResource( - detectorSources.detectorNarrow, - calculateDownloadProgres(3, 2, setDownloadProgress) - ), - }; - - await VerticalOCR.loadModule( - detectorPaths.detectorLarge, - detectorPaths.detectorNarrow, - recognizerPath, - symbols[language], - independentCharacters - ); - setIsReady(true); - } catch (e) { - setError(getError(e)); - } + await model.loadModel( + detectorSources, + recognizerSources, + language, + independentCharacters + ); }; loadModel(); @@ -100,30 +64,11 @@ export const useVerticalOCR = ({ JSON.stringify(recognizerSources), ]); - const forward = async (input: string) => { - if (!isReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - if (isGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - - try { - setIsGenerating(true); - const output = await VerticalOCR.forward(input); - return output; - } catch (e) { - throw new Error(getError(e)); - } finally { - setIsGenerating(false); - } - }; - return { error, isReady, isGenerating, - forward, + forward: model.forward, downloadProgress, }; }; diff --git a/src/index.tsx b/src/index.tsx index c9b6fca0..9d50e776 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -14,6 +14,7 @@ export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; export * from './modules/computer_vision/OCRModule'; +export * from './modules/computer_vision/VerticalOCRModule'; export * from './modules/natural_language_processing/LLMModule'; diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts index f62d32a7..c7a28ef6 100644 --- a/src/modules/computer_vision/OCRModule.ts +++ b/src/modules/computer_vision/OCRModule.ts @@ -1,14 +1,10 @@ -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { OCR } from '../../native/RnExecutorchModules'; +import { OCRController } from '../../controllers/OCRController'; import { ResourceSource } from '../../types/common'; import { OCRLanguage } from '../../types/ocr'; -import { - calculateDownloadProgres, - fetchResource, -} from '../../utils/fetchResource'; export class OCRModule { + static module: OCRController; + static onDownloadProgressCallback = (_downloadProgress: number) => {}; static async load( @@ -20,52 +16,15 @@ export class OCRModule { }, language: OCRLanguage = 'en' ) { - try { - if (!detectorSource || Object.keys(recognizerSources).length !== 3) - return; - - if (!symbols[language]) { - throw new Error(getError(ETError.LanguageNotSupported)); - } - - const detectorPath = await fetchResource( - detectorSource, - calculateDownloadProgres(4, 0, this.onDownloadProgressCallback) - ); - - const recognizerPaths = { - recognizerLarge: await fetchResource( - recognizerSources.recognizerLarge, - calculateDownloadProgres(4, 1, this.onDownloadProgressCallback) - ), - recognizerMedium: await fetchResource( - recognizerSources.recognizerMedium, - calculateDownloadProgres(4, 2, this.onDownloadProgressCallback) - ), - recognizerSmall: await 
fetchResource( - recognizerSources.recognizerSmall, - calculateDownloadProgres(4, 3, this.onDownloadProgressCallback) - ), - }; + this.module = new OCRController({ + modelDownloadProgressCallback: this.onDownloadProgressCallback, + }); - await OCR.loadModule( - detectorPath, - recognizerPaths.recognizerLarge, - recognizerPaths.recognizerMedium, - recognizerPaths.recognizerSmall, - symbols[language] - ); - } catch (e) { - throw new Error(getError(e)); - } + await this.module.loadModel(detectorSource, recognizerSources, language); } static async forward(input: string) { - try { - return await OCR.forward(input); - } catch (e) { - throw new Error(getError(e)); - } + return await this.module.forward(input); } static onDownloadProgress(callback: (downloadProgress: number) => void) { diff --git a/src/modules/computer_vision/VerticalOCRModule.ts b/src/modules/computer_vision/VerticalOCRModule.ts index fb85ef6b..4c8b1120 100644 --- a/src/modules/computer_vision/VerticalOCRModule.ts +++ b/src/modules/computer_vision/VerticalOCRModule.ts @@ -1,14 +1,10 @@ -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { VerticalOCR } from '../../native/RnExecutorchModules'; +import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { ResourceSource } from '../../types/common'; import { OCRLanguage } from '../../types/ocr'; -import { - calculateDownloadProgres, - fetchResource, -} from '../../utils/fetchResource'; export class VerticalOCRModule { + static module: VerticalOCRController; + static onDownloadProgressCallback = (_downloadProgress: number) => {}; static async load( @@ -23,56 +19,20 @@ export class VerticalOCRModule { language: OCRLanguage = 'en', independentCharacters: boolean = false ) { - try { - if ( - Object.keys(detectorSources).length !== 2 || - Object.keys(recognizerSources).length !== 2 - ) - return; - - if (!symbols[language]) { - throw new Error(getError(ETError.LanguageNotSupported)); - } - - const recognizerPath = independentCharacters - ? 
await fetchResource( - recognizerSources.recognizerSmall, - calculateDownloadProgres(3, 0, this.onDownloadProgressCallback) - ) - : await fetchResource( - recognizerSources.recognizerLarge, - calculateDownloadProgres(3, 0, this.onDownloadProgressCallback) - ); - - const detectorPaths = { - detectorLarge: await fetchResource( - detectorSources.detectorLarge, - calculateDownloadProgres(3, 1, this.onDownloadProgressCallback) - ), - detectorNarrow: await fetchResource( - detectorSources.detectorNarrow, - calculateDownloadProgres(3, 2, this.onDownloadProgressCallback) - ), - }; - - await VerticalOCR.loadModule( - detectorPaths.detectorLarge, - detectorPaths.detectorNarrow, - recognizerPath, - symbols[language], - independentCharacters - ); - } catch (e) { - throw new Error(getError(e)); - } + this.module = new VerticalOCRController({ + modelDownloadProgressCallback: this.onDownloadProgressCallback, + }); + + await this.module.loadModel( + detectorSources, + recognizerSources, + language, + independentCharacters + ); } static async forward(input: string) { - try { - return await VerticalOCR.forward(input); - } catch (e) { - throw new Error(getError(e)); - } + return await this.module.forward(input); } static onDownloadProgress(callback: (downloadProgress: number) => void) { diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 49e4b89e..49ac1e52 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -3,6 +3,8 @@ import { Spec as ClassificationInterface } from './NativeClassification'; import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; import { Spec as StyleTransferInterface } from './NativeStyleTransfer'; import { Spec as ETModuleInterface } from './NativeETModule'; +import { Spec as OCRInterface } from './NativeOCR'; +import { Spec as VerticalOCRInterface } from './NativeVerticalOCR'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. 
Make sure: \n\n` + @@ -167,6 +169,50 @@ class _ClassificationModule { } } +class _OCRModule { + async forward(input: string): ReturnType<OCRInterface['forward']> { + return await OCR.forward(input); + } + + async loadModule( + detectorSource: string, + recognizerSourceLarge: string, + recognizerSourceMedium: string, + recognizerSourceSmall: string, + symbols: string + ) { + return await OCR.loadModule( + detectorSource, + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall, + symbols + ); + } +} + +class _VerticalOCRModule { + async forward(input: string): ReturnType<VerticalOCRInterface['forward']> { + return await VerticalOCR.forward(input); + } + + async loadModule( + detectorLargeSource: string, + detectorMediumSource: string, + recognizerSource: string, + symbols: string, + independentCharacters: boolean + ): ReturnType<VerticalOCRInterface['loadModule']> { + return await VerticalOCR.loadModule( + detectorLargeSource, + detectorMediumSource, + recognizerSource, + symbols, + independentCharacters + ); + } +} + class _ETModule { async forward( inputs: number[][], @@ -201,4 +247,6 @@ export { _StyleTransferModule, _ObjectDetectionModule, _SpeechToTextModule, + _OCRModule, + _VerticalOCRModule, }; From dd6fd567feb81b70247f7c73e78caa242c5af38b Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 4 Mar 2025 11:23:48 +0100 Subject: [PATCH 07/12] feat: added thresholding to single character processing --- ios/RnExecutorch/VerticalOCR.mm | 1 + .../models/ocr/utils/RecognizerUtils.h | 1 + .../models/ocr/utils/RecognizerUtils.mm | 37 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/ios/RnExecutorch/VerticalOCR.mm b/ios/RnExecutorch/VerticalOCR.mm index d195fb11..2d9ed579 100644 --- a/ios/RnExecutorch/VerticalOCR.mm +++ b/ios/RnExecutorch/VerticalOCR.mm @@ -96,6 +96,7 @@ - (void)forward:(NSString *)input NSDictionary *paddingsBox = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:320 desiredHeight:1280]; cv::Mat croppedCharacter = [RecognizerUtils cropImageWithBoundingBox:image bbox:boxCords originalBbox:cords paddings:paddingsBox originalPaddings:paddings]; if(self->independentCharacters){ + croppedCharacter = [RecognizerUtils cropSingleCharacter:croppedCharacter]; croppedCharacter = [RecognizerUtils normalizeForRecognizer:croppedCharacter adjustContrast:0.0 isVertical: YES]; NSArray *recognitionResult = [recognizer runModel:croppedCharacter]; NSArray *predIndex = [recognitionResult objectAtIndex:0]; diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h index eb47a0b4..51d93638 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -30,5 +30,6 @@ originalBbox:(NSArray *)originalBbox paddings:(NSDictionary *)paddings originalPaddings:(NSDictionary *)originalPaddings; ++ (cv::Mat)cropSingleCharacter:(cv::Mat)img; @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 28e419f2..253f84c1 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -260,4 +260,41 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray return croppedImage; } ++ (cv::Mat)cropSingleCharacter:(cv::Mat)img { + cv::Mat thresh; + cv::threshold(img, thresh, 0, 255, cv::THRESH_BINARY + cv::THRESH_OTSU); + + cv::Mat labels, stats, centroids; + const int numLabels = connectedComponentsWithStats(thresh, labels, stats, centroids, 8); + const CGFloat
centralThreshold = 0.3; + const int height = thresh.rows; + const int width = thresh.cols; + + const int minX = centralThreshold * width; + const int maxX = (1 - centralThreshold) * width; + const int minY = centralThreshold * height; + const int maxY = (1 - centralThreshold) * height; + + int selectedComponent = -1; + + for (int i = 1; i < numLabels; i++) { + const int area = stats.at<int>(i, cv::CC_STAT_AREA); + const double cx = centroids.at<double>(i, 0); + const double cy = centroids.at<double>(i, 1); + + if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > 70) { + if (selectedComponent == -1 || area > stats.at<int>(selectedComponent, cv::CC_STAT_AREA)) { + selectedComponent = i; + } + } + } + cv::Mat mask = cv::Mat::zeros(img.size(), CV_8UC1); + if (selectedComponent != -1) { + mask = (labels == selectedComponent) / 255; + } + cv::Mat resultImage = cv::Mat::zeros(img.size(), img.type()); + img.copyTo(resultImage, mask); + + return resultImage; +} @end From 68b6aae67b5020343eeffb35ee1964ffc3ce92bf Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 4 Mar 2025 12:55:32 +0100 Subject: [PATCH 08/12] feat: improved pipeline for single character processing --- .../com/swmansion/rnexecutorch/VerticalOCR.kt | 1 + .../models/ocr/utils/RecognizerUtils.kt | 67 +++++++++++++++++++ .../models/ocr/utils/RecognizerUtils.mm | 62 ++++++++++++----- 3 files changed, 114 insertions(+), 16 deletions(-) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt index 859ebaa7..2d800677 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt @@ -112,6 +112,7 @@ class VerticalOCR(reactContext: ReactApplicationContext) : ) if (this.independentCharacters) { + croppedCharacter = RecognizerUtils.cropSingleCharacter(croppedCharacter) croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true) val recognitionResult = recognizer.runModel(croppedCharacter) val predIndex = recognitionResult.first diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt index 654aaa41..c21dd90f 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -4,6 +4,8 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor import org.opencv.core.Core import org.opencv.core.CvType import org.opencv.core.Mat +import org.opencv.core.MatOfFloat +import org.opencv.core.MatOfInt import org.opencv.core.MatOfPoint2f import org.opencv.core.Point import org.opencv.core.Rect @@ -320,5 +322,70 @@ class RecognizerUtils { return boundingBox } + + fun cropSingleCharacter(img: Mat): Mat { + val histogram = Mat() + val histSize = MatOfInt(256) + val range = MatOfFloat(0f, 256f) + Imgproc.calcHist( + listOf(img), + MatOfInt(0), + Mat(), + histogram, + histSize, + range + ) + + val midPoint = 256 / 2 + var sumLeft = 0.0 + var sumRight = 0.0 + for (i in 0 until midPoint) { + sumLeft += histogram.get(i, 0)[0] + } + for (i in midPoint until 256) { + sumRight += histogram.get(i, 0)[0] + } + + val thresholdType = if (sumLeft < sumRight) Imgproc.THRESH_BINARY_INV else Imgproc.THRESH_BINARY + + val thresh = Mat() + Imgproc.threshold(img, thresh, 0.0, 255.0, thresholdType + Imgproc.THRESH_OTSU) + + val labels =
Mat() + val stats = Mat() + val centroids = Mat() + val numLabels = Imgproc.connectedComponentsWithStats(thresh, labels, stats, centroids, 8) + + val centralThreshold = 0.3 + val height = thresh.rows() + val width = thresh.cols() + val minX = centralThreshold * width + val maxX = (1 - centralThreshold) * width + val minY = centralThreshold * height + val maxY = (1 - centralThreshold) * height + + var selectedComponent = -1 + for (i in 1 until numLabels) { + val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() + val cx = centroids.get(i, 0)[0] + val cy = centroids.get(i, 1)[0] + if (cx > minX && cx < maxX && cy > minY && cy < maxY && area > 70) { + if (selectedComponent == -1 || area > stats.get(selectedComponent, Imgproc.CC_STAT_AREA)[0]) { + selectedComponent = i + } + } + } + + val mask = Mat.zeros(img.size(), CvType.CV_8UC1) + if (selectedComponent != -1) { + Core.compare(labels, Scalar(selectedComponent.toDouble()), mask, Core.CMP_EQ) + } + + val resultImage = Mat.zeros(img.size(), img.type()) + img.copyTo(resultImage, mask) + + Core.bitwise_not(resultImage, resultImage) + return resultImage + } } } diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 253f84c1..f589e79d 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -261,40 +261,70 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray } + (cv::Mat)cropSingleCharacter:(cv::Mat)img { + cv::Mat histogram; + + int histSize = 256; + float range[] = {0, 256}; + const float *histRange = {range}; + bool uniform = true, accumulate = false; + + cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, uniform, + accumulate); + + int midPoint = histSize / 2; + + double sumLeft = 0.0, sumRight = 0.0; + for (int i = 0; i < midPoint; i++) { + sumLeft += histogram.at<float>(i); + } + for (int i = midPoint; i < histSize; i++) { + sumRight += histogram.at<float>(i); + } + + int thresholdType; + if (sumLeft < sumRight) { + thresholdType = cv::THRESH_BINARY_INV; + } else { + thresholdType = cv::THRESH_BINARY; + } + cv::Mat thresh; - cv::threshold(img, thresh, 0, 255, cv::THRESH_BINARY + cv::THRESH_OTSU); - + cv::threshold(img, thresh, 0, 255, thresholdType + cv::THRESH_OTSU); + cv::Mat labels, stats, centroids; - const int numLabels = connectedComponentsWithStats(thresh, labels, stats, centroids, 8); + const int numLabels = + connectedComponentsWithStats(thresh, labels, stats, centroids, 8); const CGFloat centralThreshold = 0.3; const int height = thresh.rows; const int width = thresh.cols; - + const int minX = centralThreshold * width; const int maxX = (1 - centralThreshold) * width; const int minY = centralThreshold * height; const int maxY = (1 - centralThreshold) * height; - + int selectedComponent = -1; - + for (int i = 1; i < numLabels; i++) { - const int area = stats.at<int>(i, cv::CC_STAT_AREA); - const double cx = centroids.at<double>(i, 0); - const double cy = centroids.at<double>(i, 1); - - if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > 70) { - if (selectedComponent == -1 || area > stats.at<int>(selectedComponent, cv::CC_STAT_AREA)) { - selectedComponent = i; - } + const int area = stats.at<int>(i, cv::CC_STAT_AREA); + const double cx = centroids.at<double>(i, 0); + const double cy = centroids.at<double>(i, 1); + + if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > 70) { + if (selectedComponent == -1 || + area > stats.at<int>(selectedComponent, cv::CC_STAT_AREA)) { + selectedComponent = i; } + } } cv::Mat
mask = cv::Mat::zeros(img.size(), CV_8UC1); if (selectedComponent != -1) { - mask = (labels == selectedComponent) / 255; + mask = (labels == selectedComponent) / 255; } cv::Mat resultImage = cv::Mat::zeros(img.size(), img.type()); img.copyTo(resultImage, mask); - + cv::bitwise_not(resultImage, resultImage); return resultImage; } + @end From 70e9139f48bb993a06f61866ec6972fb8b5f7bd3 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 4 Mar 2025 14:05:59 +0100 Subject: [PATCH 09/12] feat: create constants for magic numbers(ios) --- ios/RnExecutorch/OCR.h | 2 -- ios/RnExecutorch/OCR.mm | 2 ++ ios/RnExecutorch/VerticalOCR.mm | 6 ++--- .../models/ocr/RecognitionHandler.h | 7 ----- .../models/ocr/RecognitionHandler.mm | 7 ++--- ios/RnExecutorch/models/ocr/Recognizer.mm | 8 +++--- ios/RnExecutorch/models/ocr/utils/Constants.h | 15 ++++++++--- .../models/ocr/utils/RecognizerUtils.mm | 26 ++++++++++--------- 8 files changed, 39 insertions(+), 34 deletions(-) diff --git a/ios/RnExecutorch/OCR.h b/ios/RnExecutorch/OCR.h index 68c08785..4994108b 100644 --- a/ios/RnExecutorch/OCR.h +++ b/ios/RnExecutorch/OCR.h @@ -1,7 +1,5 @@ #import -constexpr CGFloat recognizerRatio = 1.6; - @interface OCR : NSObject @end diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index 59740c90..509e3876 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -2,6 +2,7 @@ #import "models/ocr/Detector.h" #import "models/ocr/RecognitionHandler.h" #import "utils/ImageProcessor.h" +#import "models/ocr/utils/Constants.h" #import #import @@ -80,6 +81,7 @@ of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). cv::Mat image = [ImageProcessor readImage:input]; NSArray *result = [detector runModel:image]; cv::Size detectorSize = [detector getModelImageSize]; + const CGFloat recognizerRatio = recognizerImageSize / detectorSize.width; cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); result = [self->recognitionHandler recognize:result diff --git a/ios/RnExecutorch/VerticalOCR.mm b/ios/RnExecutorch/VerticalOCR.mm index 2d9ed579..ef5e58a2 100644 --- a/ios/RnExecutorch/VerticalOCR.mm +++ b/ios/RnExecutorch/VerticalOCR.mm @@ -90,10 +90,10 @@ - (void)forward:(NSString *)input NSNumber *confidenceScore = @0.0; NSArray *boxResult = [detectorNarrow runModel:croppedImage]; std::vector croppedCharacters; - + cv::Size narrowRecognizerSize = [detectorNarrow getModelImageSize]; for(NSDictionary *characterBox in boxResult){ NSArray *boxCords = characterBox[@"bbox"]; - NSDictionary *paddingsBox = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:320 desiredHeight:1280]; + NSDictionary *paddingsBox = [RecognizerUtils calculateResizeRatioAndPaddings:boxWidth height:boxHeight desiredWidth:narrowRecognizerSize.width desiredHeight:narrowRecognizerSize.height]; cv::Mat croppedCharacter = [RecognizerUtils cropImageWithBoundingBox:image bbox:boxCords originalBbox:cords paddings:paddingsBox originalPaddings:paddings]; if(self->independentCharacters){ croppedCharacter = [RecognizerUtils cropSingleCharacter:croppedCharacter]; @@ -113,7 +113,7 @@ - (void)forward:(NSString *)input }else{ cv::Mat mergedCharacters; cv::hconcat(croppedCharacters.data(), (int)croppedCharacters.size(), mergedCharacters); - mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters desiredWidth:512 desiredHeight:64]; + mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters desiredWidth:largeRecognizerWidth desiredHeight:recognizerHeight]; mergedCharacters = [RecognizerUtils 
normalizeForRecognizer:mergedCharacters adjustContrast:0.0 isVertical: NO]; NSArray *recognitionResult = [recognizer runModel:mergedCharacters]; NSArray *predIndex = [recognitionResult objectAtIndex:0]; diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index 41250437..7f674d98 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -1,12 +1,5 @@ #import "opencv2/opencv.hpp" -constexpr int modelHeight = 64; -constexpr int largeModelWidth = 512; -constexpr int mediumModelWidth = 256; -constexpr int smallModelWidth = 128; -constexpr CGFloat lowConfidenceThreshold = 0.3; -constexpr CGFloat adjustContrast = 0.2; - @interface RecognitionHandler : NSObject - (instancetype)initWithSymbols:(NSString *)symbols; diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 5793c646..a2632977 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -3,6 +3,7 @@ #import "./utils/CTCLabelConverter.h" #import "./utils/OCRUtils.h" #import "./utils/RecognizerUtils.h" +#import "./utils/Constants.h" #import "ExecutorchLib/ETModel.h" #import "Recognizer.h" #import @@ -72,9 +73,9 @@ - (void)loadRecognizers:(NSString *)largeRecognizerPath - (NSArray *)runModel:(cv::Mat)croppedImage { NSArray *result; - if (croppedImage.cols >= largeModelWidth) { + if (croppedImage.cols >= largeRecognizerWidth) { result = [recognizerLarge runModel:croppedImage]; - } else if (croppedImage.cols >= mediumModelWidth) { + } else if (croppedImage.cols >= mediumRecognizerWidth) { result = [recognizerMedium runModel:croppedImage]; } else { result = [recognizerSmall runModel:croppedImage]; @@ -103,7 +104,7 @@ - (NSArray *)recognize:(NSArray *)bBoxesList for (NSDictionary *box in bBoxesList) { cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box image:imgGray - modelHeight:modelHeight]; + modelHeight:recognizerHeight]; if (croppedImage.empty()) { continue; } diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm index 8b339bc2..e3ee9089 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.mm +++ b/ios/RnExecutorch/models/ocr/Recognizer.mm @@ -14,8 +14,8 @@ @implementation Recognizer { - (cv::Size)getModelImageSize { NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape.lastObject; - NSNumber *heightNumber = inputShape[inputShape.count - 2]; + NSNumber *widthNumber = inputShape[inputShape.count - 2]; + NSNumber *heightNumber = inputShape.lastObject; const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; @@ -24,8 +24,8 @@ @implementation Recognizer { - (cv::Size)getModelOutputSize { NSArray *outputShape = [module getOutputShape:@0]; - NSNumber *widthNumber = outputShape.lastObject; - NSNumber *heightNumber = outputShape[outputShape.count - 2]; + NSNumber *widthNumber = outputShape[outputShape.count - 2]; + NSNumber *heightNumber = outputShape.lastObject; const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; diff --git a/ios/RnExecutorch/models/ocr/utils/Constants.h b/ios/RnExecutorch/models/ocr/utils/Constants.h index 92470511..ba1e1622 100644 --- a/ios/RnExecutorch/models/ocr/utils/Constants.h +++ b/ios/RnExecutorch/models/ocr/utils/Constants.h @@ -7,11 +7,20 @@ constexpr CGFloat distanceThreshold = 2.0; constexpr CGFloat heightThreshold = 2.0; 
constexpr CGFloat restoreRatio = 3.2; constexpr CGFloat restoreRatioVertical = 2.0; +constexpr CGFloat singleCharacterCenterThreshold = 0.3; +constexpr CGFloat lowConfidenceThreshold = 0.3; +constexpr CGFloat adjustContrast = 0.2; constexpr int minSideThreshold = 15; constexpr int maxSideThreshold = 30; -constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); +constexpr int recognizerHeight = 64; +constexpr int largeRecognizerWidth = 512; +constexpr int mediumRecognizerWidth = 256; +constexpr int smallRecognizerWidth = 128; +constexpr int smallVerticalRecognizerWidth = 64; +constexpr int maxWidth = largeRecognizerWidth + (largeRecognizerWidth * 0.15); constexpr int minSize = 20; +constexpr int singleCharacterMinSize = 70; +constexpr int recognizerImageSize = 1280; const cv::Scalar mean(0.485, 0.456, 0.406); -const cv::Scalar variance(0.229, 0.224, 0.225); - +const cv::Scalar variance(0.229, 0.224, 0.225); \ No newline at end of file diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index f589e79d..fe52f3b6 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -1,5 +1,6 @@ #import "RecognizerUtils.h" #import "OCRUtils.h" +#import "Constants.h" @implementation RecognizerUtils @@ -62,17 +63,17 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { image = [self adjustContrastGrey:image target:adjustContrast]; } - int desiredWidth = (isVertical) ? 64 : 128; + int desiredWidth = (isVertical) ? smallVerticalRecognizerWidth : smallRecognizerWidth; - if (image.cols >= 512) { - desiredWidth = 512; - } else if (image.cols >= 256) { - desiredWidth = 256; + if (image.cols >= largeRecognizerWidth) { + desiredWidth = largeRecognizerWidth; + } else if (image.cols >= mediumRecognizerWidth) { + desiredWidth = mediumRecognizerWidth; } image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth - desiredHeight:64]; + desiredHeight:recognizerHeight]; image.convertTo(image, CV_32F, 1.0 / 255.0); image = (image - 0.5) * 2.0; @@ -253,14 +254,15 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray cv::Rect rect = cv::boundingRect(points); cv::Mat croppedImage = img(rect); - cv::cvtColor(croppedImage, croppedImage, cv::COLOR_BGR2GRAY); - cv::resize(croppedImage, croppedImage, cv::Size(64, 64), 0, 0, - cv::INTER_AREA); - cv::medianBlur(croppedImage, croppedImage, 1); return croppedImage; } + (cv::Mat)cropSingleCharacter:(cv::Mat)img { + cv::cvtColor(img, img, cv::COLOR_BGR2GRAY); + cv::resize(img, img, cv::Size(smallVerticalRecognizerWidth, recognizerHeight), 0, 0, + cv::INTER_AREA); + cv::medianBlur(img, img, 1); + cv::Mat histogram; int histSize = 256; @@ -294,7 +296,7 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray cv::Mat labels, stats, centroids; const int numLabels = connectedComponentsWithStats(thresh, labels, stats, centroids, 8); - const CGFloat centralThreshold = 0.3; + const CGFloat centralThreshold = singleCharacterCenterThreshold; const int height = thresh.rows; const int width = thresh.cols; @@ -310,7 +312,7 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray const double cx = centroids.at<double>(i, 0); const double cy = centroids.at<double>(i, 1); - if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > 70) { + if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > singleCharacterMinSize) { if (selectedComponent == -1 || area > stats.at<int>(selectedComponent, cv::CC_STAT_AREA)) { selectedComponent = i; From
b8514b226dea7ee9458bab95c086cc2096e52b12 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 4 Mar 2025 15:01:59 +0100 Subject: [PATCH 10/12] feat: add const for min size magic number(android) --- .../com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt | 1 + .../swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt | 2 +- .../swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt index 6e65bd64..5dc25cd7 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt @@ -24,6 +24,7 @@ class Constants { const val MAX_SIDE_THRESHOLD = 30 const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt() const val MIN_SIZE = 20 + const val SINGLE_CHARACTER_MIN_SIZE = 70 val MEAN = Scalar(0.485, 0.456, 0.406) val VARIANCE = Scalar(0.229, 0.224, 0.225) } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt index b07123ac..c6b5789a 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt @@ -332,7 +332,7 @@ class DetectorUtils { val detectedBoxes = mutableListOf() for (i in 1 until nLabels) { val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() - if (area < 20) continue + if (area < Constants.MIN_SIZE) continue val height = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() val width = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt index c21dd90f..1847e8ee 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -369,7 +369,7 @@ class RecognizerUtils { val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() val cx = centroids.get(i, 0)[0] val cy = centroids.get(i, 1)[0] - if (cx > minX && cx < maxX && cy > minY && cy < maxY && area > 70) { + if (cx > minX && cx < maxX && cy > minY && cy < maxY && area > Constants.SINGLE_CHARACTER_MIN_SIZE) { if (selectedComponent == -1 || area > stats.get(selectedComponent, Imgproc.CC_STAT_AREA)[0]) { selectedComponent = i } From 3f6a13ab7cfe04e9452badd947f79d6788c02fe1 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Tue, 4 Mar 2025 15:07:31 +0100 Subject: [PATCH 11/12] fix: set progress to 1 after every file is downloaded --- src/utils/fetchResource.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utils/fetchResource.ts b/src/utils/fetchResource.ts index 8164d8fe..18dd959e 100644 --- a/src/utils/fetchResource.ts +++ b/src/utils/fetchResource.ts @@ -93,4 +93,7 @@ export const calculateDownloadProgres = const scaledProgress = progress * contributionPerFile; const updatedProgress = baseProgress + scaledProgress; setProgress(updatedProgress); + if (progress === 1 && currentFileIndex === numberOfFiles - 1) { + setProgress(1); + } }; From 42aa0f23e6862dd9a145924fe75e3ea196aebb84 Mon Sep 17 00:00:00 2001 From: 
Norbert Klockiewicz Date: Tue, 4 Mar 2025 16:14:24 +0100 Subject: [PATCH 12/12] fix: suggested changes --- ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm | 7 +------ src/utils/fetchResource.ts | 7 ++++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index fe52f3b6..1908ad6f 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -283,12 +283,7 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray sumRight += histogram.at<float>(i); } - int thresholdType; - if (sumLeft < sumRight) { - thresholdType = cv::THRESH_BINARY_INV; - } else { - thresholdType = cv::THRESH_BINARY; - } + const int thresholdType = (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; cv::Mat thresh; cv::threshold(img, thresh, 0, 255, thresholdType + cv::THRESH_OTSU); diff --git a/src/utils/fetchResource.ts b/src/utils/fetchResource.ts index 18dd959e..ecaec034 100644 --- a/src/utils/fetchResource.ts +++ b/src/utils/fetchResource.ts @@ -88,12 +88,13 @@ export const calculateDownloadProgres = setProgress: (downloadProgress: number) => void ) => (progress: number) => { + if (progress === 1 && currentFileIndex === numberOfFiles - 1) { + setProgress(1); + return; + } const contributionPerFile = 1 / numberOfFiles; const baseProgress = contributionPerFile * currentFileIndex; const scaledProgress = progress * contributionPerFile; const updatedProgress = baseProgress + scaledProgress; setProgress(updatedProgress); - if (progress === 1 && currentFileIndex === numberOfFiles - 1) { - setProgress(1); - } };
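A note on the progress math from patches 05, 11, and 12 above: each of the N files being fetched owns an equal 1/N slice of the overall bar, and a file's local progress in [0, 1] is scaled into its slice, so with three files a report of 0.5 from file index 1 yields 1/3 + 0.5 * 1/3 = 0.5 overall. The TypeScript sketch below replays that arithmetic with a console logger standing in for the React state setter; the helper body mirrors src/utils/fetchResource.ts as of patch 12 (including the library's calculateDownloadProgres spelling), while simulateDownloads is a hypothetical driver, not part of the library.

// Replay of the multi-file progress weighting (src/utils/fetchResource.ts, patch 12).
const calculateDownloadProgres =
  (
    numberOfFiles: number,
    currentFileIndex: number,
    setProgress: (downloadProgress: number) => void
  ) =>
  (progress: number) => {
    // Patch 12 pins the bar to exactly 1 when the last file finishes, rather
    // than trusting floating-point accumulation to land on 1.
    if (progress === 1 && currentFileIndex === numberOfFiles - 1) {
      setProgress(1);
      return;
    }
    const contributionPerFile = 1 / numberOfFiles; // equal slice per file
    const baseProgress = contributionPerFile * currentFileIndex; // slices already finished
    const scaledProgress = progress * contributionPerFile; // progress within this slice
    setProgress(baseProgress + scaledProgress);
  };

// Hypothetical driver: three sequential downloads, each reporting 0, 0.5, 1.
const simulateDownloads = () => {
  for (let file = 0; file < 3; file++) {
    const report = calculateDownloadProgres(3, file, (p) =>
      console.log(`overall: ${p.toFixed(2)}`)
    );
    [0, 0.5, 1].forEach(report); // file 1 at 0.5 logs "overall: 0.50"
  }
};
simulateDownloads();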
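Patch 06 above deduplicates useOCR/useVerticalOCR and the hookless modules by moving all load/forward state handling into controller classes that accept plain callbacks. Below is a condensed TypeScript sketch of that wiring, with simplified names and a stub NativeOCRLike standing in for the real _OCRModule wrapper, so it illustrates the shape of the pattern rather than the library's actual API surface.

// Condensed sketch of the controller pattern introduced in patch 06.
interface NativeOCRLike {
  loadModule(...paths: string[]): Promise<void>;
  forward(input: string): Promise<unknown>;
}

class MiniOCRController {
  isReady = false;
  isGenerating = false;

  constructor(
    private native: NativeOCRLike,
    // Consumers inject what should happen on state changes: the hook passes
    // React setters, the hookless module passes only a progress callback.
    private callbacks: {
      isReadyCallback?: (isReady: boolean) => void;
      isGeneratingCallback?: (isGenerating: boolean) => void;
      errorCallback?: (error: string) => void;
    } = {}
  ) {}

  loadModel = async (...paths: string[]) => {
    try {
      this.isReady = false;
      this.callbacks.isReadyCallback?.(false);
      await this.native.loadModule(...paths);
      this.isReady = true;
      this.callbacks.isReadyCallback?.(true);
    } catch (e) {
      this.callbacks.errorCallback?.(String(e));
    }
  };

  forward = async (input: string) => {
    if (!this.isReady) throw new Error('ModuleNotLoaded');
    if (this.isGenerating) throw new Error('ModelGenerating');
    this.isGenerating = true;
    this.callbacks.isGeneratingCallback?.(true);
    try {
      return await this.native.forward(input);
    } finally {
      this.isGenerating = false;
      this.callbacks.isGeneratingCallback?.(false);
    }
  };
}

In useOCR the React setters are injected directly (isReadyCallback: setIsReady, and so on), which is why the hook body collapses to a single model.loadModel call, while OCRModule builds the same controller with just a download-progress callback.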
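The cropSingleCharacter pipeline from patches 07 and 08 first decides binarization polarity from the gray histogram: if most mass sits in the bright half (sumLeft < sumRight), the crop is dark text on a light background and THRESH_BINARY_INV flips it before Otsu, so the glyph always ends up as white foreground for connectedComponentsWithStats. It then keeps only the largest component whose centroid lies in the central window (centralThreshold = 0.3 trims 30% from each side, leaving the middle 40% per axis) and whose area exceeds 70 px. A pure-logic TypeScript sketch of that selection rule, with OpenCV left out; Component is a stand-in for one non-background row of the stats/centroids output.

interface Component {
  area: number; // pixel count of the component
  cx: number; // centroid x
  cy: number; // centroid y
}

const centralThreshold = 0.3; // keep centroids in the middle 40% of each axis
const minArea = 70; // named singleCharacterMinSize in patch 10

// Returns the index of the chosen component, or -1 when nothing central and
// large enough exists (the native code then masks the whole crop out).
const selectCharacterComponent = (
  components: Component[], // background label already excluded here
  width: number,
  height: number
): number => {
  const minX = centralThreshold * width;
  const maxX = (1 - centralThreshold) * width;
  const minY = centralThreshold * height;
  const maxY = (1 - centralThreshold) * height;

  let selected = -1;
  components.forEach((c, i) => {
    const central = c.cx > minX && c.cx < maxX && c.cy > minY && c.cy < maxY;
    if (
      central &&
      c.area > minArea &&
      (selected === -1 || c.area > components[selected].area)
    ) {
      selected = i;
    }
  });
  return selected;
};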
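Patch 09 replaces the hardcoded recognizerRatio = 1.6 in OCR.h with a value derived at runtime, recognizerImageSize / detectorSize.width. A quick TypeScript sanity check of that arithmetic, assuming an 800 px detector input (the actual size is read from the detector model at runtime, so 800 is an illustrative assumption): 1280 / 800 = 1.6, reproducing the removed constant.

// Sanity check for the derived ratio in OCR.mm (patch 09).
const recognizerImageSize = 1280; // from Constants.h
const detectorWidth = 800; // assumption for illustration only
const recognizerRatio = recognizerImageSize / detectorWidth; // 1.6, as before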
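RecognitionHandler.runModel, renamed to the new constants in patch 09, routes each cropped word to one of three recognizers by width: crops at least largeRecognizerWidth wide go to the large model, at least mediumRecognizerWidth to the medium one, and everything narrower to the small one. A minimal TypeScript sketch of that routing rule; the string kinds are illustrative stand-ins for the native recognizer instances.

const largeRecognizerWidth = 512;
const mediumRecognizerWidth = 256;

type RecognizerKind = 'large' | 'medium' | 'small';

// Mirrors the threshold cascade in RecognitionHandler.runModel.
const pickRecognizer = (croppedWidth: number): RecognizerKind => {
  if (croppedWidth >= largeRecognizerWidth) return 'large';
  if (croppedWidth >= mediumRecognizerWidth) return 'medium';
  return 'small';
};

// e.g. pickRecognizer(300) === 'medium'; crops are then padded to the chosen
// width (and recognizerHeight = 64) before inference.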