Skip to content

Commit 5d15683

Browse files
authored
Merge pull request #22 from longieirl/fix/rfc-17-inject-row-classifier
feat: inject shared RowClassifier chain (RFC #17)
2 parents 3d13406 + e59536d commit 5d15683

4 files changed

Lines changed: 38 additions & 3 deletions

File tree

packages/parser-core/src/bankstatements_core/extraction/boundary_detector.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,10 @@
99
import logging
1010
from dataclasses import dataclass
1111

12-
from bankstatements_core.extraction.row_classifiers import create_row_classifier_chain
12+
from bankstatements_core.extraction.row_classifiers import (
13+
RowClassifier,
14+
create_row_classifier_chain,
15+
)
1316

1417
logger = logging.getLogger(__name__)
1518

@@ -49,6 +52,7 @@ def __init__(
4952
min_section_gap: int = 50,
5053
structure_breakdown_threshold: int = 8,
5154
dynamic_boundary_threshold: int = 15,
55+
row_classifier: RowClassifier | None = None,
5256
):
5357
"""
5458
Initialize boundary detector.
@@ -60,11 +64,16 @@ def __init__(
6064
min_section_gap: Minimum gap in pixels to consider a section boundary
6165
structure_breakdown_threshold: Number of empty columns to consider structure broken
6266
dynamic_boundary_threshold: Consecutive non-transaction rows before ending extraction
67+
row_classifier: Optional RowClassifier chain; creates default if not provided
6368
"""
6469
self.columns = columns
6570
self.fallback_bottom_y = fallback_bottom_y
6671
self.table_top_y = table_top_y
67-
self._row_classifier = create_row_classifier_chain()
72+
self._row_classifier = (
73+
row_classifier
74+
if row_classifier is not None
75+
else create_row_classifier_chain()
76+
)
6877

6978
# Configuration parameters
7079
self.min_gap_threshold = min_section_gap

packages/parser-core/src/bankstatements_core/extraction/extraction_facade.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from bankstatements_core.extraction.extraction_params import TABLE_BOTTOM_Y, TABLE_TOP_Y
1515

1616
if TYPE_CHECKING:
17+
from bankstatements_core.extraction.row_classifiers import RowClassifier
1718
from bankstatements_core.templates.template_model import BankTemplate
1819

1920

@@ -25,6 +26,7 @@ def detect_table_end_boundary_smart(
2526
min_section_gap: int = 50,
2627
structure_breakdown_threshold: int = 8,
2728
dynamic_boundary_threshold: int = 15,
29+
row_classifier: "RowClassifier | None" = None,
2830
) -> int:
2931
"""
3032
Detect table end intelligently (facade).
@@ -39,6 +41,7 @@ def detect_table_end_boundary_smart(
3941
min_section_gap: Minimum gap in pixels to consider a section boundary
4042
structure_breakdown_threshold: Number of empty columns to consider structure broken
4143
dynamic_boundary_threshold: Consecutive non-transaction rows before ending extraction
44+
row_classifier: Optional RowClassifier chain; creates default if not provided
4245
4346
Returns:
4447
Detected bottom Y coordinate
@@ -52,6 +55,7 @@ def detect_table_end_boundary_smart(
5255
min_section_gap=min_section_gap,
5356
structure_breakdown_threshold=structure_breakdown_threshold,
5457
dynamic_boundary_threshold=dynamic_boundary_threshold,
58+
row_classifier=row_classifier,
5559
)
5660

5761
return detector.detect_boundary(words)

packages/parser-core/src/bankstatements_core/extraction/pdf_extractor.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,11 @@ def _determine_boundaries_and_extract(
248248
)
249249

250250
dynamic_bottom_y = detect_table_end_boundary_smart(
251-
all_words, table_top_y, self.columns, table_bottom_y
251+
all_words,
252+
table_top_y,
253+
self.columns,
254+
table_bottom_y,
255+
row_classifier=self._row_classifier,
252256
)
253257

254258
# Safety check: Cap dynamic boundary at static boundary to prevent over-extraction

packages/parser-core/tests/extraction/test_boundary_detector.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -379,3 +379,21 @@ def test_configuration_via_constructor(self):
379379
assert detector.min_gap_threshold == 100
380380
assert detector.structure_breakdown_threshold == 10
381381
assert detector.consecutive_threshold == 20
382+
383+
def test_injected_classifier_is_used(self):
384+
from unittest.mock import Mock
385+
386+
from bankstatements_core.extraction.row_classifiers import RowClassifier
387+
388+
mock_chain = Mock(spec=RowClassifier)
389+
mock_chain.classify.return_value = "metadata"
390+
detector = TableBoundaryDetector(
391+
columns=TEST_COLUMNS, row_classifier=mock_chain
392+
)
393+
words = [{"text": "Footer", "x0": 60, "top": 350}]
394+
detector.detect_boundary(words)
395+
mock_chain.classify.assert_called()
396+
397+
def test_default_classifier_created_when_not_injected(self):
398+
detector = TableBoundaryDetector(columns=TEST_COLUMNS)
399+
assert detector._row_classifier is not None

0 commit comments

Comments
 (0)