@@ -28,7 +28,8 @@ class TableProcessor(BaseProcessor):
2828 A processor for recognizing tables in the document.
2929 """
3030
31- block_types = (BlockTypes .Table , BlockTypes .TableOfContents , BlockTypes .Form )
31+ block_types = (BlockTypes .Table ,
32+ BlockTypes .TableOfContents , BlockTypes .Form )
3233 detect_boxes : Annotated [
3334 bool ,
3435 "Whether to detect boxes for the table recognition model." ,
@@ -64,7 +65,8 @@ class TableProcessor(BaseProcessor):
6465 bool ,
6566 "Whether to disable the tqdm progress bar." ,
6667 ] = False
67- drop_repeated_text : Annotated [bool , "Drop repeated text in OCR results." ] = False
68+ drop_repeated_text : Annotated [bool ,
69+ "Drop repeated text in OCR results." ] = False
6870
6971 def __init__ (
7072 self ,
@@ -128,7 +130,8 @@ def __call__(self, document: Document):
128130 )
129131 self .assign_text_to_cells (tables , table_data )
130132 self .split_combined_rows (tables ) # Split up rows that were combined
131- self .combine_dollar_column (tables ) # Combine columns that are just dollar signs
133+ # Combine columns that are just dollar signs
134+ self .combine_dollar_column (tables )
132135
133136 # Assign table cells to the table
134137 table_idx = 0
@@ -174,7 +177,8 @@ def __call__(self, document: Document):
174177 )
175178 for child , intersection in zip (child_contained_blocks , intersections ):
176179 # Adjust this to percentage of the child block that is enclosed by the table
177- intersection_pct = intersection / max (child .polygon .area , 1 )
180+ intersection_pct = intersection / \
181+ max (child .polygon .area , 1 )
178182 if intersection_pct > 0.95 and child .id in page .structure :
179183 page .structure .remove (child .id )
180184
@@ -186,7 +190,8 @@ def finalize_cell_text(self, cell: SuryaTableCell):
186190 if not text or text == "." :
187191 continue
188192 text = re .sub (r"(\s\.){2,}" , "" , text ) # Replace . . .
189- text = re .sub (r"\.{2,}" , "" , text ) # Replace ..., like in table of contents
193+ # Replace ..., like in table of contents
194+ text = re .sub (r"\.{2,}" , "" , text )
190195 text = self .normalize_spaces (fix_text (text ))
191196 fixed_text .append (text )
192197 return fixed_text
@@ -236,7 +241,8 @@ def combine_dollar_column(self, tables: List[TableResult]):
236241 col < max_col ,
237242 ]
238243 ):
239- next_col_cells = [c for c in table .cells if c .col_id == col + 1 ]
244+ next_col_cells = [
245+ c for c in table .cells if c .col_id == col + 1 ]
240246 next_col_rows = [c .row_id for c in next_col_cells ]
241247 col_rows = [c .row_id for c in col_cells ]
242248 if (
@@ -293,7 +299,8 @@ def split_combined_rows(self, tables: List[TableResult]):
293299 # Cells in this row
294300 # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
295301 # making them be processed twice
296- row_cells = deepcopy ([c for c in table .cells if c .row_id == row ])
302+ row_cells = deepcopy (
303+ [c for c in table .cells if c .row_id == row ])
297304 rowspans = [c .rowspan for c in row_cells ]
298305 line_lens = [
299306 len (c .text_lines ) if isinstance (c .text_lines , list ) else 1
@@ -312,14 +319,16 @@ def split_combined_rows(self, tables: List[TableResult]):
312319 len (rowspan_cells ) == 0 ,
313320 all ([rowspan == 1 for rowspan in rowspans ]),
314321 all ([line_len > 1 for line_len in line_lens ]),
315- all ([line_len == line_lens [0 ] for line_len in line_lens ]),
322+ all ([line_len == line_lens [0 ]
323+ for line_len in line_lens ]),
316324 ]
317325 )
318326 line_lens_counter = Counter (line_lens )
319327 counter_keys = sorted (list (line_lens_counter .keys ()))
320328 should_split_partial_row = all (
321329 [
322- len (row_cells ) > 3 , # Only split if there are more than 3 cells
330+ # Only split if there are more than 3 cells
331+ len (row_cells ) > 3 ,
323332 len (rowspan_cells ) == 0 ,
324333 all ([r == 1 for r in rowspans ]),
325334 len (line_lens_counter ) == 2
@@ -420,8 +429,10 @@ def assign_text_to_cells(self, tables: List[TableResult], table_data: list):
420429 for k in cell_text :
421430 # TODO: see if the text needs to be sorted (based on rotation)
422431 text = cell_text [k ]
423- assert all ("text" in t for t in text ), "All text lines must have text"
424- assert all ("bbox" in t for t in text ), "All text lines must have a bbox"
432+ assert all (
433+ "text" in t for t in text ), "All text lines must have text"
434+ assert all (
435+ "bbox" in t for t in text ), "All text lines must have a bbox"
425436 table_cells [k ].text_lines = text
426437
427438 def assign_pdftext_lines (self , extract_blocks : list , filepath : str ):
@@ -491,13 +502,16 @@ def get_detection_batch_size(self):
491502 return self .detection_batch_size
492503 elif settings .TORCH_DEVICE_MODEL == "cuda" :
493504 return 10
505+ elif settings .TORCH_DEVICE_MODEL == "mps" :
506+ # CPU fallback under MPS; modestly higher than plain CPU default
507+ return 6
494508 return 4
495509
496510 def get_table_rec_batch_size (self ):
497511 if self .table_rec_batch_size is not None :
498512 return self .table_rec_batch_size
499513 elif settings .TORCH_DEVICE_MODEL == "mps" :
500- return 6
514+ return 8
501515 elif settings .TORCH_DEVICE_MODEL == "cuda" :
502516 return 14
503517 return 6
@@ -506,7 +520,7 @@ def get_recognition_batch_size(self):
506520 if self .recognition_batch_size is not None :
507521 return self .recognition_batch_size
508522 elif settings .TORCH_DEVICE_MODEL == "mps" :
509- return 32
523+ return 24
510524 elif settings .TORCH_DEVICE_MODEL == "cuda" :
511525 return 32
512526 return 32
0 commit comments