Skip to content

Commit e8caa4e

Browse files
authored
Chore: allow table model to accept optional OCR data (#256)
## Summary

Change `run_prediction` in the table model to accept optional OCR data. This supports the table OCR refactor by letting callers pass OCR tokens in directly, as an alternative to obtaining them through `get_tokens`.

## TODO

Please see [CORE-2259](https://unstructured-ai.atlassian.net/browse/CORE-2259) to update `ocr_token` from a dict to a data class.

[CORE-2259]: https://unstructured-ai.atlassian.net/browse/CORE-2259?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
1 parent 05f9b61 commit e8caa4e

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.7.9
2+
3+
* Allow table model to accept optional OCR tokens
4+
15
## 0.7.8
26

37
* Fix: include onnx as base dependency.

test_unstructured_inference/models/test_tables.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,21 @@ def test_table_prediction_tesseract(table_transformer, example_image):
361361
) in prediction
362362

363363

364+
def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image):
365+
ocr_tokens = [
366+
{
367+
# bounding box should match table structure
368+
"bbox": [70.0, 245.0, 127.0, 266.0],
369+
"block_num": 0,
370+
"line_num": 0,
371+
"span_num": 0,
372+
"text": "Blind",
373+
},
374+
]
375+
prediction = table_transformer.predict(example_image, ocr_tokens=ocr_tokens)
376+
assert prediction == "<table><tr><td>Blind</td></tr></table>"
377+
378+
364379
@pytest.mark.skipif(skip_outside_ci, reason="Skipping paddle test run outside of CI")
365380
def test_table_prediction_paddle(monkeypatch, example_image):
366381
monkeypatch.setenv("TABLE_OCR", "paddle")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.8" # pragma: no cover
1+
__version__ = "0.7.9" # pragma: no cover

unstructured_inference/models/tables.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xml.etree.ElementTree as ET
66
from collections import defaultdict
77
from pathlib import Path
8-
from typing import List, Optional, Union
8+
from typing import Dict, List, Optional, Union
99

1010
import cv2
1111
import numpy as np
@@ -33,10 +33,24 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
3333
def __init__(self):
3434
pass
3535

36-
def predict(self, x: Image):
37-
"""Predict table structure deferring to run_prediction"""
36+
def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
37+
"""Predict table structure deferring to run_prediction with ocr tokens
38+
39+
Note:
40+
`ocr_tokens` is a list of dictionaries representing OCR tokens,
41+
where each dictionary has the following format:
42+
{
43+
"bbox": [int, int, int, int], # Bounding box coordinates of the token
44+
"block_num": int, # Block number
45+
"line_num": int, # Line number
46+
"span_num": int, # Span number
47+
"text": str, # Text content of the token
48+
}
49+
The bounding box coordinates should match the table structure.
50+
FIXME: refactor token data into a dataclass so we have clear expectations of the fields
51+
"""
3852
super().predict(x)
39-
return self.run_prediction(x)
53+
return self.run_prediction(x, ocr_tokens=ocr_tokens)
4054

4155
def initialize(
4256
self,
@@ -161,12 +175,18 @@ def run_prediction(
161175
self,
162176
x: Image,
163177
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
178+
ocr_tokens: Optional[List[Dict]] = None,
164179
):
165180
"""Predict table structure"""
166181
outputs_structure = self.get_structure(x, pad_for_structure_detection)
167-
tokens = self.get_tokens(x=x)
182+
if ocr_tokens is None:
183+
logger.warning(
184+
"Table OCR from get_tokens method will be deprecated. "
185+
"In the future the OCR tokens are expected to be passed in.",
186+
)
187+
ocr_tokens = self.get_tokens(x=x)
168188

169-
html = recognize(outputs_structure, x, tokens=tokens, out_html=True)["html"]
189+
html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"]
170190
prediction = html[0] if html else ""
171191
return prediction
172192

0 commit comments

Comments
 (0)