Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ repos:
pass_filenames: false
language: system
files: '\.py$'
- id: poetry
name: Poetry check
entry: poetry lock --check
pass_filenames: false
language: system
# - id: poetry
# name: Poetry check
# entry: poetry lock --check
# pass_filenames: false
# language: system
- id: system
name: MyPy
entry: poetry run mypy docling_ibm_models
Expand Down
5 changes: 1 addition & 4 deletions demo/demo_layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,7 @@ def main(args):
Path(viz_dir).mkdir(parents=True, exist_ok=True)

# Download models from HF
download_path = snapshot_download(
repo_id="ds4sd/docling-models", revision="v2.1.0"
)
artifact_path = os.path.join(download_path, "model_artifacts/layout")
artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-heron", revision="main")

# Test the LayoutPredictor
demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)
Expand Down
42 changes: 20 additions & 22 deletions docling_ibm_models/layoutmodel/layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from transformers import RTDetrImageProcessor, RTDetrV2ForObjectDetection

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,24 +44,23 @@ def __init__(
"""
# Initialize classes map:
self._classes_map = {
0: "background",
1: "Caption",
2: "Footnote",
3: "Formula",
4: "List-item",
5: "Page-footer",
6: "Page-header",
7: "Picture",
8: "Section-header",
9: "Table",
10: "Text",
11: "Title",
12: "Document Index",
13: "Code",
14: "Checkbox-Selected",
15: "Checkbox-Unselected",
16: "Form",
17: "Key-Value Region",
0: "Caption",
1: "Footnote",
2: "Formula",
3: "List-item",
4: "Page-footer",
5: "Page-header",
6: "Picture",
7: "Section-header",
8: "Table",
9: "Text",
10: "Title",
11: "Document Index",
12: "Code",
13: "Checkbox-Selected",
14: "Checkbox-Unselected",
15: "Form",
16: "Key-Value Region",
}

# Blacklisted classes
Expand All @@ -87,7 +86,7 @@ def __init__(
processor_config = os.path.join(artifact_path, "preprocessor_config.json")
model_config = os.path.join(artifact_path, "config.json")
self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
self._model = RTDetrForObjectDetection.from_pretrained(
self._model = RTDetrV2ForObjectDetection.from_pretrained(
artifact_path, config=model_config
).to(self._device)
self._model.eval()
Expand Down Expand Up @@ -154,8 +153,7 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
result["scores"], result["labels"], result["boxes"]
):
score = float(score.item())

label_id = int(label_id.item()) + 1 # Advance the label_id
label_id = int(label_id.item())
label_str = self._classes_map[label_id]

# Filter out blacklisted classes
Expand Down
6 changes: 3 additions & 3 deletions tests/test_layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def init() -> dict:
"image_size": 640,
"threshold": 0.6,
},
"pred_bboxes": 9,
# "pred_bboxes": 9,
"pred_bboxes": 12,
}

# Download models from HF
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
artifact_path = os.path.join(download_path, "model_artifacts/layout")
artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-heron", revision="main")

# Add the missing config keys
init["artifact_path"] = artifact_path
Expand Down
Loading