From b5ef9596f4e548d46e45fc64681c5094221973ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?=
Date: Mon, 19 Feb 2024 12:02:59 +0000
Subject: [PATCH 1/4] feature: yolov8 support

---
NOTE(review): issues to resolve in the working tree before merging. Fix the
code and regenerate the series; hand-editing hunks would invalidate the line
counts and the context that patches 2-4 rely on.

* unstructured_inference/models/yolov8.py executes
  `model = YOLO('/home/joao/yolov8n/weights/best.pt')` at module level. The
  path is an absolute path under one developer's home directory, the `model`
  name is not referenced anywhere else in this series, and base.py now imports
  this module, so the call runs whenever base.py is imported. Presumably this
  fails on any machine without that file (TODO confirm); drop the line.
* `image_processing` declares `image: Image = None` but immediately calls
  `image.resize(...)`, so the default value cannot actually be used. `Image`
  here is the `PIL.Image` module (`from PIL import Image`), not the image
  class.
* `from ultralytics import YOLO` is listed after the first-party
  `unstructured_inference` imports instead of with the other third-party
  imports, and none of the four files touched here is a dependency manifest;
  verify that `ultralytics` is declared as a requirement.
* The NMS IoU threshold (0.1), the confidence cut-off (0.3) and the 640x640
  input size are literals inside `image_processing`.
* The tests under "ONLY SHORT TESTS BELOW" carry no `slow` marker but still
  request `model_name="yolov8s"`, whose `model_path` entry is built from
  `hf_hub_download`; presumably they need network access or a cached
  download (TODO confirm this fits the short-test budget).
* The test comments about `keep_output = True` describe an argument that the
  `process_file_with_model(...)` calls in this file never pass.
* Docstrings in yolov8.py say "YoloX" (corrected in patch 3/4) and contain
  the typos "runing" and "asigned".

 .../models/test_yolov8.py               |  99 +++++++++++++++
 unstructured_inference/constants.py     |   1 +
 unstructured_inference/models/base.py   |   9 ++
 unstructured_inference/models/yolov8.py | 114 ++++++++++++++++++
 4 files changed, 223 insertions(+)
 create mode 100644 test_unstructured_inference/models/test_yolov8.py
 create mode 100644 unstructured_inference/models/yolov8.py

diff --git a/test_unstructured_inference/models/test_yolov8.py b/test_unstructured_inference/models/test_yolov8.py
new file mode 100644
index 00000000..7616b1ac
--- /dev/null
+++ b/test_unstructured_inference/models/test_yolov8.py
@@ -0,0 +1,99 @@
+import os
+
+import pytest
+
+from unstructured_inference.inference.layout import process_file_with_model
+
+
+@pytest.mark.slow()
+def test_layout_yolov8_local_parsing_image():
+    filename = os.path.join("sample-docs", "test-image.jpg")
+    # NOTE(benjamin) keep_output = True create a file for each image in
+    # localstorage for visualization of the result
+    document_layout = process_file_with_model(filename, model_name="yolov8s", is_image=True)
+    # NOTE(benjamin) The example image should result in one page result
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) The example sent to the test contains 13 detections
+    types_known = ["Text", "Section-header", "Page-header"]
+    known_regions = [e for e in document_layout.pages[0].elements if e.type in types_known]
+    assert len(known_regions) == 13
+    assert hasattr(
+        document_layout.pages[0].elements[0],
+        "prob",
+    )  # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
+    assert isinstance(
+        document_layout.pages[0].elements[0].prob,
+        float,
+    )  # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float
+
+
+@pytest.mark.slow()
+def test_layout_yolov8_local_parsing_pdf():
+    filename = os.path.join("sample-docs", "loremipsum.pdf")
+    document_layout = process_file_with_model(filename, model_name="yolov8s")
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) The example sent to the test contains 5 text detections
+    text_elements = [e for e in document_layout.pages[0].elements if e.type == "Text"]
+    assert len(text_elements) == 5
+    assert hasattr(
+        document_layout.pages[0].elements[0],
+        "prob",
+    )  # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
+    assert isinstance(
+        document_layout.pages[0].elements[0].prob,
+        float,
+    )  # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float
+
+
+@pytest.mark.slow()
+def test_layout_yolov8_local_parsing_empty_pdf():
+    filename = os.path.join("sample-docs", "empty-document.pdf")
+    document_layout = process_file_with_model(filename, model_name="yolov8s")
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) The example sent to the test contains 0 detections
+    assert len(document_layout.pages[0].elements) == 0
+
+
+########################
+# ONLY SHORT TESTS BELOW
+########################
+
+
+def test_layout_yolov8_local_parsing_image_soft():
+    filename = os.path.join("sample-docs", "example_table.jpg")
+    # NOTE(benjamin) keep_output = True create a file for each image in
+    # localstorage for visualization of the result
+    document_layout = process_file_with_model(filename, model_name="yolov8s", is_image=True)
+    # NOTE(benjamin) The example image should result in one page result
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) Soft version of the test, run make test-long in order to run with full model
+    assert len(document_layout.pages[0].elements) > 0
+    assert hasattr(
+        document_layout.pages[0].elements[0],
+        "prob",
+    )  # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
+    assert isinstance(
+        document_layout.pages[0].elements[0].prob,
+        float,
+    )  # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float
+
+
+def test_layout_yolov8_local_parsing_pdf_soft():
+    filename = os.path.join("sample-docs", "loremipsum.pdf")
+    document_layout = process_file_with_model(filename, model_name="yolov8s")
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) Soft version of the test, run make test-long in order to run with full model
+    assert len(document_layout.pages[0].elements) > 0
+    assert hasattr(
+        document_layout.pages[0].elements[0],
+        "prob",
+    )  # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
+
+
+def test_layout_yolov8_local_parsing_empty_pdf_soft():
+    filename = os.path.join("sample-docs", "empty-document.pdf")
+    document_layout = process_file_with_model(filename, model_name="yolov8s")
+    assert len(document_layout.pages) == 1
+    # NOTE(benjamin) The example sent to the test contains 0 detections
+    text_elements_page_1 = [el for el in document_layout.pages[0].elements if el.type != "Image"]
+    assert len(text_elements_page_1) == 0
diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py
index d8139ed2..56c9f43d 100644
--- a/unstructured_inference/constants.py
+++ b/unstructured_inference/constants.py
@@ -8,6 +8,7 @@ class AnnotationResult(Enum):
 
 class Source(Enum):
     YOLOX = "yolox"
+    YOLOv8 = "yolov8"
     DETECTRON2_ONNX = "detectron2_onnx"
     DETECTRON2_LP = "detectron2_lp"
     CHIPPER = "chipper"
diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py
index 26336e23..74752a4d 100644
--- a/unstructured_inference/models/base.py
+++ b/unstructured_inference/models/base.py
@@ -26,6 +26,12 @@
 from unstructured_inference.models.yolox import (
     UnstructuredYoloXModel,
 )
+from unstructured_inference.models.yolov8 import (
+    MODEL_TYPES as YOLOV8_MODEL_TYPES,
+)
+from unstructured_inference.models.yolov8 import (
+    UnstructuredYolov8Model,
+)
 
 DEFAULT_MODEL = "yolox"
 
@@ -35,6 +41,7 @@
     **{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
     **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
     **{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
+    **{name: UnstructuredYolov8Model for name in YOLOV8_MODEL_TYPES},
     **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
     "super_gradients": UnstructuredSuperGradients,
 }
@@ -65,6 +72,8 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
             initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name]
         elif model_name in YOLOX_MODEL_TYPES:
             initialize_params = YOLOX_MODEL_TYPES[model_name]
+        elif model_name in YOLOV8_MODEL_TYPES:
+            initialize_params = YOLOV8_MODEL_TYPES[model_name]
         elif model_name in CHIPPER_MODEL_TYPES:
             initialize_params = CHIPPER_MODEL_TYPES[model_name]
         else:
diff --git a/unstructured_inference/models/yolov8.py b/unstructured_inference/models/yolov8.py
new file mode 100644
index 00000000..d43c506d
--- /dev/null
+++ b/unstructured_inference/models/yolov8.py
@@ -0,0 +1,114 @@
+from typing import List, cast
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from torchvision.ops import nms
+
+from unstructured_inference.constants import ElementType, Source
+from unstructured_inference.inference.layoutelement import LayoutElement
+from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel
+from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
+from ultralytics import YOLO
+
+YOLOv8_LABEL_MAP = {
+    0: ElementType.CAPTION,
+    1: ElementType.FOOTNOTE,
+    2: ElementType.FORMULA,
+    3: ElementType.LIST_ITEM,
+    4: ElementType.PAGE_FOOTER,
+    5: ElementType.PAGE_HEADER,
+    6: ElementType.PICTURE,
+    7: ElementType.SECTION_HEADER,
+    8: ElementType.TABLE,
+    9: ElementType.TEXT,
+    10: ElementType.TITLE,
+}
+label_to_color = {
+    ElementType.CAPTION: "black",
+    ElementType.FOOTNOTE: "cyan",
+    ElementType.FORMULA: "black",
+    ElementType.LIST_ITEM: "green",
+    ElementType.PAGE_FOOTER: "blue",
+    ElementType.PAGE_HEADER: "yellow",
+    ElementType.PICTURE: "black",
+    ElementType.SECTION_HEADER: "purple",
+    ElementType.TABLE: "black",
+    ElementType.TEXT: "black",
+    ElementType.TITLE: "red",
+}
+
+model = YOLO('/home/joao/yolov8n/weights/best.pt')
+MODEL_TYPES = {
+    "yolov8n": LazyDict(
+        model_path=LazyEvaluateInfo(
+            hf_hub_download,
+            "neuralshift/doc-layout-yolov8n",
+            "weights/best.pt",
+        ),
+        label_map=YOLOv8_LABEL_MAP,
+    ),
+    "yolov8s": LazyDict(
+        model_path=LazyEvaluateInfo(
+            hf_hub_download,
+            "neuralshift/doc-layout-yolov8s",
+            "weights/best.pt",
+        ),
+        label_map=YOLOv8_LABEL_MAP,
+    ),
+}
+
+
+class UnstructuredYolov8Model(UnstructuredObjectDetectionModel):
+    def predict(self, x: Image):
+        """Predict using YoloX model."""
+        super().predict(x)
+        return self.image_processing(x)
+
+    def initialize(self, model_path: str, label_map: dict):
+        """Start inference session for YoloX model."""
+        self.model = YOLO(model=model_path)
+        self.layout_classes = label_map
+
+    def image_processing(
+        self,
+        image: Image = None,
+    ) -> List[LayoutElement]:
+        """Method runing YoloX for layout detection, returns a PageLayout
+        parameters
+        ----------
+        page
+            Path for image file with the image to process
+        origin_img
+            If specified, an Image object for process with YoloX model
+        page_number
+            Number asigned to the PageLayout returned
+        output_directory
+            Boolean indicating if result will be stored
+        """
+        input_shape = (640, 640)
+        processed_image = image.resize(input_shape, Image.BILINEAR)
+        ratio = np.array(input_shape) / np.array(image.size)
+
+        # NMS
+        boxes = self.model(processed_image)[0].boxes
+        valid_boxes = nms(boxes.xyxy, boxes.conf, 0.1)
+        boxes = boxes[valid_boxes]
+        boxes = boxes[boxes.conf > 0.3]
+
+        regions = sorted([
+            LayoutElement.from_coords(
+                box.xyxy[0][0] / ratio[0],
+                box.xyxy[0][1] / ratio[1],
+                box.xyxy[0][2] / ratio[0],
+                box.xyxy[0][3] / ratio[1],
+                text=None,
+                type=self.layout_classes[int(box.cls.item())],
+                prob=box.conf.item(),
+                source=Source.YOLOv8,
+            ) for box in boxes
+        ], key=lambda element: element.bbox.y1)
+
+        page_layout = cast(List[LayoutElement], regions)  # TODO(benjamin): encode image as base64?
+
+        return page_layout

From 3ed4d408ae02ea6be4d5a019c90c4ab38849e64b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?=
Date: Mon, 19 Feb 2024 15:02:07 +0000
Subject: [PATCH 2/4] fix: disable verbosity

---
 unstructured_inference/models/yolov8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unstructured_inference/models/yolov8.py b/unstructured_inference/models/yolov8.py
index d43c506d..5564ec0f 100644
--- a/unstructured_inference/models/yolov8.py
+++ b/unstructured_inference/models/yolov8.py
@@ -91,7 +91,7 @@ def image_processing(
         ratio = np.array(input_shape) / np.array(image.size)
 
         # NMS
-        boxes = self.model(processed_image)[0].boxes
+        boxes = self.model(processed_image, verbose=False)[0].boxes
         valid_boxes = nms(boxes.xyxy, boxes.conf, 0.1)
         boxes = boxes[valid_boxes]
         boxes = boxes[boxes.conf > 0.3]

From 946ddf356744edff5c9517671c8f9e73bd8f6ca3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?=
Date: Mon, 19 Feb 2024 15:03:33 +0000
Subject: [PATCH 3/4] chore: update documentation

---
NOTE(review): the `model = YOLO('/home/joao/yolov8n/weights/best.pt')` line
appears below only as unchanged context, i.e. it is still present after this
clean-up. The rewritten `image_processing` docstring drops the "parameters"
heading but keeps the "----------" underline, so the underline now follows
"LayoutElement"; the "runing" typo is carried over into the new text.

 unstructured_inference/models/yolov8.py | 31 +++++--------------------
 1 file changed, 6 insertions(+), 25 deletions(-)

diff --git a/unstructured_inference/models/yolov8.py b/unstructured_inference/models/yolov8.py
index 5564ec0f..13e23344 100644
--- a/unstructured_inference/models/yolov8.py
+++ b/unstructured_inference/models/yolov8.py
@@ -24,19 +24,6 @@
     9: ElementType.TEXT,
     10: ElementType.TITLE,
 }
-label_to_color = {
-    ElementType.CAPTION: "black",
-    ElementType.FOOTNOTE: "cyan",
-    ElementType.FORMULA: "black",
-    ElementType.LIST_ITEM: "green",
-    ElementType.PAGE_FOOTER: "blue",
-    ElementType.PAGE_HEADER: "yellow",
-    ElementType.PICTURE: "black",
-    ElementType.SECTION_HEADER: "purple",
-    ElementType.TABLE: "black",
-    ElementType.TEXT: "black",
-    ElementType.TITLE: "red",
-}
 
 model = YOLO('/home/joao/yolov8n/weights/best.pt')
 MODEL_TYPES = {
@@ -61,12 +48,12 @@
 
 class UnstructuredYolov8Model(UnstructuredObjectDetectionModel):
     def predict(self, x: Image):
-        """Predict using YoloX model."""
+        """Predict using Yolov8 model."""
         super().predict(x)
         return self.image_processing(x)
 
     def initialize(self, model_path: str, label_map: dict):
-        """Start inference session for YoloX model."""
+        """Start inference session for Yolov8 model."""
         self.model = YOLO(model=model_path)
         self.layout_classes = label_map
 
@@ -74,17 +61,11 @@ def image_processing(
         self,
         image: Image = None,
     ) -> List[LayoutElement]:
-        """Method runing YoloX for layout detection, returns a PageLayout
-        parameters
+        """Method runing Yolov8 for layout detection, returns a list of
+        LayoutElement
         ----------
-        page
-            Path for image file with the image to process
-        origin_img
-            If specified, an Image object for process with YoloX model
-        page_number
-            Number asigned to the PageLayout returned
-        output_directory
-            Boolean indicating if result will be stored
+        image
+            Image to process
         """
         input_shape = (640, 640)
         processed_image = image.resize(input_shape, Image.BILINEAR)

From 346d368e55cb00c56a51d37c5ec7a647af193b95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?=
Date: Mon, 19 Feb 2024 15:13:57 +0000
Subject: [PATCH 4/4] fix: yolov8 bbox data type

---
 unstructured_inference/models/yolov8.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/unstructured_inference/models/yolov8.py b/unstructured_inference/models/yolov8.py
index 13e23344..c1ab0601 100644
--- a/unstructured_inference/models/yolov8.py
+++ b/unstructured_inference/models/yolov8.py
@@ -79,10 +79,10 @@ def image_processing(
 
         regions = sorted([
             LayoutElement.from_coords(
-                box.xyxy[0][0] / ratio[0],
-                box.xyxy[0][1] / ratio[1],
-                box.xyxy[0][2] / ratio[0],
-                box.xyxy[0][3] / ratio[1],
+                box.xyxy[0][0].item() / ratio[0],
+                box.xyxy[0][1].item() / ratio[1],
+                box.xyxy[0][2].item() / ratio[0],
+                box.xyxy[0][3].item() / ratio[1],
                 text=None,
                 type=self.layout_classes[int(box.cls.item())],
                 prob=box.conf.item(),