diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ce2add..69ca948 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,11 +14,11 @@ repos: pass_filenames: false language: system files: '\.py$' - - id: poetry - name: Poetry check - entry: poetry lock --check - pass_filenames: false - language: system + # - id: poetry + # name: Poetry check + # entry: poetry lock --check + # pass_filenames: false + # language: system - id: system name: MyPy entry: poetry run mypy docling_ibm_models diff --git a/demo/demo_layout_predictor.py b/demo/demo_layout_predictor.py index 56c030f..75bac2a 100644 --- a/demo/demo_layout_predictor.py +++ b/demo/demo_layout_predictor.py @@ -118,10 +118,7 @@ def main(args): Path(viz_dir).mkdir(parents=True, exist_ok=True) # Download models from HF - download_path = snapshot_download( - repo_id="ds4sd/docling-models", revision="v2.1.0" - ) - artifact_path = os.path.join(download_path, "model_artifacts/layout") + artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-heron", revision="main") # Test the LayoutPredictor demo(logger, artifact_path, device, num_threads, img_dir, viz_dir) diff --git a/docling_ibm_models/layoutmodel/layout_predictor.py b/docling_ibm_models/layoutmodel/layout_predictor.py index 60ab1a5..4ae8648 100644 --- a/docling_ibm_models/layoutmodel/layout_predictor.py +++ b/docling_ibm_models/layoutmodel/layout_predictor.py @@ -11,7 +11,7 @@ import torch import torchvision.transforms as T from PIL import Image -from transformers import RTDetrForObjectDetection, RTDetrImageProcessor +from transformers import RTDetrImageProcessor, RTDetrV2ForObjectDetection _log = logging.getLogger(__name__) @@ -44,24 +44,23 @@ def __init__( """ # Initialize classes map: self._classes_map = { - 0: "background", - 1: "Caption", - 2: "Footnote", - 3: "Formula", - 4: "List-item", - 5: "Page-footer", - 6: "Page-header", - 7: "Picture", - 8: "Section-header", - 9: "Table", - 10: "Text", - 11: "Title", - 12: "Document Index", - 13: "Code", - 14: "Checkbox-Selected", - 15: "Checkbox-Unselected", - 16: "Form", - 17: "Key-Value Region", + 0: "Caption", + 1: "Footnote", + 2: "Formula", + 3: "List-item", + 4: "Page-footer", + 5: "Page-header", + 6: "Picture", + 7: "Section-header", + 8: "Table", + 9: "Text", + 10: "Title", + 11: "Document Index", + 12: "Code", + 13: "Checkbox-Selected", + 14: "Checkbox-Unselected", + 15: "Form", + 16: "Key-Value Region", } # Blacklisted classes @@ -87,7 +86,7 @@ def __init__( processor_config = os.path.join(artifact_path, "preprocessor_config.json") model_config = os.path.join(artifact_path, "config.json") self._image_processor = RTDetrImageProcessor.from_json_file(processor_config) - self._model = RTDetrForObjectDetection.from_pretrained( + self._model = RTDetrV2ForObjectDetection.from_pretrained( artifact_path, config=model_config ).to(self._device) self._model.eval() @@ -154,8 +153,7 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]: result["scores"], result["labels"], result["boxes"] ): score = float(score.item()) - - label_id = int(label_id.item()) + 1 # Advance the label_id + label_id = int(label_id.item()) label_str = self._classes_map[label_id] # Filter out blacklisted classes diff --git a/tests/test_layout_predictor.py b/tests/test_layout_predictor.py index 109ba42..9501826 100644 --- a/tests/test_layout_predictor.py +++ b/tests/test_layout_predictor.py @@ -31,12 +31,12 @@ def init() -> dict: "image_size": 640, "threshold": 0.6, }, - "pred_bboxes": 9, + # "pred_bboxes": 9, + "pred_bboxes": 12, } # Download models from HF - download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0") - artifact_path = os.path.join(download_path, "model_artifacts/layout") + artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-heron", revision="main") # Add the missing config keys init["artifact_path"] = artifact_path