Vik Paruchuri
committed on
Commit
·
d090d63
1
Parent(s):
c85fe35
Improve table benchmark, parsing
Browse files- README.md +8 -0
- benchmarks/table.py +3 -1
- marker/benchmark/table.py +29 -12
- marker/tables/cells.py +105 -105
- marker/tables/table.py +0 -1
- marker/tables/utils.py +1 -1
README.md
CHANGED
|
@@ -209,6 +209,14 @@ This will benchmark marker against other text extraction methods. It sets up ba
|
|
| 209 |
|
| 210 |
Omit `--nougat` to exclude nougat from the benchmark. I don't recommend running nougat on CPU, since it is very slow.
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
# Thanks
|
| 213 |
|
| 214 |
This work would not have been possible without amazing open source models and datasets, including (but not limited to):
|
|
|
|
| 209 |
|
| 210 |
Omit `--nougat` to exclude nougat from the benchmark. I don't recommend running nougat on CPU, since it is very slow.
|
| 211 |
|
| 212 |
+
### Table benchmark
|
| 213 |
+
|
| 214 |
+
There is a benchmark for table parsing, which you can run with:
|
| 215 |
+
|
| 216 |
+
```shell
|
| 217 |
+
python benchmarks/table.py test_data/tables.json
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
# Thanks
|
| 221 |
|
| 222 |
This work would not have been possible without amazing open source models and datasets, including (but not limited to):
|
benchmarks/table.py
CHANGED
|
@@ -3,6 +3,7 @@ import json
|
|
| 3 |
|
| 4 |
import datasets
|
| 5 |
from surya.schema import LayoutResult, LayoutBox
|
|
|
|
| 6 |
|
| 7 |
from marker.benchmark.table import score_table
|
| 8 |
from marker.schema.bbox import rescale_bbox
|
|
@@ -20,7 +21,7 @@ def main():
|
|
| 20 |
ds = datasets.load_dataset(args.dataset, split="train")
|
| 21 |
|
| 22 |
results = []
|
| 23 |
-
for i in range(len(ds)):
|
| 24 |
row = ds[i]
|
| 25 |
marker_page = Page(**json.loads(row["marker_page"]))
|
| 26 |
table_bbox = row["table_bbox"]
|
|
@@ -55,6 +56,7 @@ def main():
|
|
| 55 |
|
| 56 |
table_block = table_blocks[0]
|
| 57 |
table_md = table_block.lines[0].spans[0].text
|
|
|
|
| 58 |
results.append({
|
| 59 |
"score": score_table(table_md, gpt4_table),
|
| 60 |
"arxiv_id": row["arxiv_id"],
|
|
|
|
| 3 |
|
| 4 |
import datasets
|
| 5 |
from surya.schema import LayoutResult, LayoutBox
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
|
| 8 |
from marker.benchmark.table import score_table
|
| 9 |
from marker.schema.bbox import rescale_bbox
|
|
|
|
| 21 |
ds = datasets.load_dataset(args.dataset, split="train")
|
| 22 |
|
| 23 |
results = []
|
| 24 |
+
for i in tqdm(range(len(ds)), desc="Evaluating tables"):
|
| 25 |
row = ds[i]
|
| 26 |
marker_page = Page(**json.loads(row["marker_page"]))
|
| 27 |
table_bbox = row["table_bbox"]
|
|
|
|
| 56 |
|
| 57 |
table_block = table_blocks[0]
|
| 58 |
table_md = table_block.lines[0].spans[0].text
|
| 59 |
+
|
| 60 |
results.append({
|
| 61 |
"score": score_table(table_md, gpt4_table),
|
| 62 |
"arxiv_id": row["arxiv_id"],
|
marker/benchmark/table.py
CHANGED
|
@@ -2,23 +2,40 @@ from rapidfuzz import fuzz
|
|
| 2 |
import re
|
| 3 |
|
| 4 |
|
| 5 |
-
def
|
| 6 |
table = table.strip()
|
| 7 |
table = re.sub(r" {2,}", "", table)
|
| 8 |
table_rows = table.split("\n")
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def score_table(hypothesis, reference):
|
| 13 |
-
hypothesis =
|
| 14 |
-
reference =
|
| 15 |
|
| 16 |
alignments = []
|
| 17 |
-
for
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
alignment = fuzz.ratio(hrow, row, score_cutoff=30) / 100
|
| 21 |
-
if alignment > max_alignment:
|
| 22 |
-
max_alignment = alignment
|
| 23 |
-
alignments.append(max_alignment)
|
| 24 |
-
return sum(alignments) / len(reference)
|
|
|
|
| 2 |
import re
|
| 3 |
|
| 4 |
|
| 5 |
+
def split_to_cells(table):
    """Break a pipe-delimited table string into rows of raw cell strings.

    The input is stripped, every run of two or more spaces is deleted
    outright, and lines that are blank after stripping are dropped. Each
    remaining line is split on "|". Cells are not stripped, so a leading or
    trailing pipe produces an empty-string cell at that end of the row.

    Returns a list of rows, each row being a list of cell strings.
    """
    compact = re.sub(r" {2,}", "", table.strip())
    cells = []
    for line in compact.split("\n"):
        if not line.strip():
            continue
        cells.append(line.split("|"))
    return cells
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def align_rows(hypothesis, ref_row):
    """Return per-cell similarity scores for the hypothesis row that best matches ``ref_row``.

    Each hypothesis row is compared to ``ref_row`` cell by cell, position by
    position, using ``fuzz.ratio`` scaled to 0..1 (called with
    ``score_cutoff=30``). Reference cells with no counterpart in a shorter
    hypothesis row score 0. The row with the highest mean score wins; on a
    tie the later row is kept.

    Returns the winning row's list of scores, or an empty list when there are
    no hypothesis rows or ``ref_row`` has no cells.
    """
    best_scores = []
    best_mean = 0
    for hyp_row in hypothesis:
        scores = [
            fuzz.ratio(hyp_row[idx], ref_cell, score_cutoff=30) / 100 if idx < len(hyp_row) else 0
            for idx, ref_cell in enumerate(ref_row)
        ]
        if not scores:
            continue
        mean_score = sum(scores) / len(scores)
        # ">=" so that a later row with an equal mean replaces an earlier one.
        if mean_score >= best_mean:
            best_scores, best_mean = scores, mean_score
    return best_scores
|
| 32 |
|
| 33 |
|
| 34 |
def score_table(hypothesis, reference):
    """Score a predicted table against a reference table, both given as pipe-delimited text.

    Both strings are split into cells with ``split_to_cells``. For every
    reference row, ``align_rows`` supplies the per-cell scores of the best
    matching hypothesis row. The result is the mean of all those scores.

    Returns a float; 0.0 when no cell scores were produced (for example an
    empty reference or an empty hypothesis) instead of dividing by zero.
    """
    hyp_cells = split_to_cells(hypothesis)
    ref_cells = split_to_cells(reference)

    alignments = []
    for ref_row in ref_cells:
        alignments.extend(align_rows(hyp_cells, ref_row))

    # Nothing was comparable: report the worst score rather than raising
    # ZeroDivisionError.
    if not alignments:
        return 0.0
    return sum(alignments) / len(alignments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
marker/tables/cells.py
CHANGED
|
@@ -1,116 +1,116 @@
|
|
| 1 |
from PIL import Image, ImageDraw
|
| 2 |
import copy
|
| 3 |
|
|
|
|
|
|
|
| 4 |
from marker.tables.edges import get_vertical_lines
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
-
def
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
for cell in row:
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
line_bbox[0] = line_bbox[2]
|
| 22 |
-
case "c":
|
| 23 |
-
line_bbox[0] = line_bbox[0] + (line_bbox[2] - line_bbox[0]) / 2
|
| 24 |
-
line_bbox[2] = line_bbox[0]
|
| 25 |
-
|
| 26 |
-
line_bbox[1] -= y_tolerance
|
| 27 |
-
line_bbox[3] += y_tolerance
|
| 28 |
-
line_bbox[0] -= table_box[0]
|
| 29 |
-
line_bbox[2] -= table_box[0]
|
| 30 |
-
line_bbox[1] -= table_box[1]
|
| 31 |
-
line_bbox[3] -= table_box[1]
|
| 32 |
-
draw.rectangle(line_bbox, outline="red", width=x_tolerance)
|
| 33 |
-
|
| 34 |
-
np_img = np.array(draw_img, dtype=np.float32) / 255.0
|
| 35 |
-
columns = get_vertical_lines(np_img, divisor=2, x_tolerance=10, y_tolerance=1)
|
| 36 |
-
columns = sorted(columns, key=lambda x: x[0])
|
| 37 |
-
|
| 38 |
-
# Remove short columns (single cells, probably)
|
| 39 |
-
# Rescale coordinates back to image
|
| 40 |
-
rescaled = []
|
| 41 |
-
for c in columns:
|
| 42 |
-
if c[3] - c[1] < table_height / 5:
|
| 43 |
-
continue
|
| 44 |
-
c[0] += table_box[0]
|
| 45 |
-
c[2] += table_box[0]
|
| 46 |
-
c[1] += table_box[1]
|
| 47 |
-
c[3] += table_box[1]
|
| 48 |
-
rescaled.append(c)
|
| 49 |
-
return rescaled
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def assign_cells_to_columns(page, table_box, rows, tolerance=5):
|
| 53 |
-
alignments = ["l", "r", "c"]
|
| 54 |
-
columns = {}
|
| 55 |
-
for align in alignments:
|
| 56 |
-
columns[align] = get_column_lines(page, table_box, rows, align=align)
|
| 57 |
-
|
| 58 |
-
# Find the column alignment that is closest to the number of columns
|
| 59 |
-
max_cols = max([len(r) for r in rows])
|
| 60 |
-
columns = min(columns.items(), key=lambda x: abs(len(x) - max_cols))[1]
|
| 61 |
-
|
| 62 |
-
formatted_rows = []
|
| 63 |
-
for table_row in rows:
|
| 64 |
-
formatted_row = []
|
| 65 |
-
for cell_idx in range(len(table_row) - 1, -1, -1):
|
| 66 |
-
cell = copy.deepcopy(table_row[cell_idx])
|
| 67 |
-
cell_bbox = cell[0]
|
| 68 |
-
|
| 69 |
-
found = False
|
| 70 |
-
for j in range(len(columns) - 1, -1, -1):
|
| 71 |
-
if columns[j][0] - tolerance < cell_bbox[0]:
|
| 72 |
-
if len(formatted_row) > 0:
|
| 73 |
-
prev_column = formatted_row[-1][0]
|
| 74 |
-
blanks = prev_column - j
|
| 75 |
-
if blanks > 1:
|
| 76 |
-
for b in range(1, blanks):
|
| 77 |
-
formatted_row.append([prev_column - b, ""])
|
| 78 |
-
formatted_row.append([j, cell[1]])
|
| 79 |
-
found = True
|
| 80 |
break
|
| 81 |
-
if
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
if col_idx <= prev_col:
|
| 93 |
-
col[0] = prev_col + 1
|
| 94 |
-
prev_col = col[0]
|
| 95 |
-
col_count = max(col_count, col[0] + 1)
|
| 96 |
-
|
| 97 |
-
# Assign cells to correct column positions
|
| 98 |
-
clean_rows = []
|
| 99 |
-
for row in formatted_rows:
|
| 100 |
-
clean_row = []
|
| 101 |
-
for col in range(col_count):
|
| 102 |
-
found = False
|
| 103 |
-
for cell in row:
|
| 104 |
-
if cell[0] == col:
|
| 105 |
-
clean_row.append(cell)
|
| 106 |
-
found = True
|
| 107 |
-
break
|
| 108 |
-
if not found:
|
| 109 |
-
clean_row.append((col, ""))
|
| 110 |
-
clean_rows.append([cell[1] for cell in clean_row])
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
row.append("")
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from PIL import Image, ImageDraw
|
| 2 |
import copy
|
| 3 |
|
| 4 |
+
from marker.schema.bbox import rescale_bbox, box_intersection_pct
|
| 5 |
+
from marker.schema.page import Page
|
| 6 |
from marker.tables.edges import get_vertical_lines
|
| 7 |
import numpy as np
|
| 8 |
+
from sklearn.cluster import DBSCAN
|
| 9 |
+
from marker.settings import settings
|
| 10 |
|
| 11 |
|
| 12 |
+
def cluster_coords(coords, row_count):
    """Cluster 1-D coordinates with DBSCAN and return one mean per label, ascending.

    Duplicate coordinates are removed before clustering. DBSCAN runs with a
    fixed ``eps`` of 0.01 and ``min_samples`` equal to a quarter of
    ``row_count`` (integer division), with a floor of 2.

    Returns an empty list when ``coords`` is empty, otherwise a sorted list
    holding the mean coordinate of each distinct label.
    """
    if len(coords) == 0:
        return []

    points = np.array(sorted(set(coords))).reshape(-1, 1)
    min_samples = max(2, row_count // 4)
    labels = DBSCAN(eps=.01, min_samples=min_samples).fit(points).labels_

    # NOTE(review): every distinct label contributes a mean, including
    # DBSCAN's noise label (-1) when present - confirm that is intended.
    return sorted(np.mean(points[labels == label]) for label in set(labels))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def find_column_separators(page: Page, table_box, rows, round_factor=.002, min_count=1):
    """Estimate column separator x-positions for a table, as fractions of page width.

    Text-line boxes from ``page.text_lines`` are rescaled into page
    coordinates and kept only if their overlap with ``table_box`` exceeds
    ``settings.BBOX_INTERSECTION_THRESH``. Their left edges, right edges and
    horizontal centers are normalised by the page width, values that do not
    occur more than ``min_count`` times are discarded, and each of the three
    lists is clustered with ``cluster_coords`` (using ``len(rows)`` as the row
    count).

    Returns the longest of the three clustered lists, with 0 prepended and 1
    appended.
    """
    boxes = []
    for detected in page.text_lines.bboxes:
        box = rescale_bbox(page.text_lines.image_bbox, page.bbox, detected.bbox)
        if box_intersection_pct(box, table_box) > settings.BBOX_INTERSECTION_THRESH:
            boxes.append(box)

    page_width = page.bbox[2] - page.bbox[0]
    page_height = page.bbox[3] - page.bbox[1]

    lefts, rights, mids = [], [], []
    for box in boxes:
        norm = [box[0] / page_width, box[1] / page_height, box[2] / page_width, box[3] / page_height]
        # NOTE(review): dividing and then multiplying by round_factor does not
        # quantise the value - confirm whether a round() was intended here.
        lefts.append(norm[0] / round_factor * round_factor)
        rights.append(norm[2] / round_factor * round_factor)
        mids.append((norm[0] + norm[2]) / 2 * round_factor / round_factor)

    def keep_repeated(values):
        # Keep only values that occur more than min_count times (exact equality).
        return [v for v in values if values.count(v) > min_count]

    row_count = len(rows)
    candidates = [
        cluster_coords(keep_repeated(lefts), row_count),
        cluster_coords(keep_repeated(rights), row_count),
        cluster_coords(keep_repeated(mids), row_count),
    ]

    # Use the candidate with the most separators; on a tie the earlier one
    # (left, then right, then center) wins.
    longest = max(candidates, key=len)
    return [0, *longest, 1]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def assign_cells_to_columns(page, table_box, rows, round_factor=.002, tolerance=.01):
    """Place the cells of each table row into a shared column grid.

    ``rows`` is a sequence of rows; each cell is indexed as ``cell[0][0]``
    for its left x-coordinate and ``cell[1]`` for its text (presumably a
    (bbox, text) pair - verify against the caller). Column separators come
    from ``find_column_separators``. A cell goes into the first column whose
    separator lies to the right of the cell's normalised left edge minus
    ``tolerance``, and which is after the column used by the previous cell in
    the same row. Cells that match no separator are given fresh indices past
    the last separator.

    Returns a list of rows of strings, all the same length, with blanks
    filled in as "" and columns that are empty in every row removed. Returns
    an empty list when ``rows`` is empty; rows without cells become all-blank
    rows instead of raising.
    """
    if len(rows) == 0:
        return []

    separators = find_column_separators(page, table_box, rows, round_factor=round_factor)
    page_width = page.bbox[2] - page.bbox[0]

    row_dicts = []
    for row in rows:
        new_row = {}
        last_col_index = -1
        overflow = 0  # count of cells in this row that matched no separator
        for cell in row:
            left_edge = cell[0][0] / page_width
            column_index = -1
            for i, separator in enumerate(separators):
                if left_edge - tolerance < separator and last_col_index < i:
                    column_index = i
                    break
            if column_index == -1:
                column_index = len(separators) + overflow
                overflow += 1
            new_row[column_index] = cell[1]
            last_col_index = column_index
        row_dicts.append(new_row)

    # default=0 keeps rows with no cells from raising ValueError.
    max_col_idx = max((max(row, default=0) for row in row_dicts), default=0)

    # Assign sorted cells to columns, account for blanks.
    # NOTE(review): flattening starts at index 1, so a cell assigned index 0
    # (left edge within `tolerance` of the first separator) is not emitted -
    # confirm that is intended.
    new_rows = [
        [row.get(col_idx, "") for col_idx in range(1, max_col_idx + 1)]
        for row in row_dicts
    ]

    # Every flattened row already has max_col_idx entries, so no padding is needed.
    cols_to_remove = {
        idx
        for idx, col in enumerate(zip(*new_rows))
        if not any(cell.strip() for cell in col)
    }

    return [
        [col for idx, col in enumerate(row) if idx not in cols_to_remove]
        for row in new_rows
    ]
|
marker/tables/table.py
CHANGED
|
@@ -7,7 +7,6 @@ from typing import List
|
|
| 7 |
from marker.settings import settings
|
| 8 |
from marker.tables.cells import assign_cells_to_columns
|
| 9 |
from marker.tables.utils import sort_table_blocks, replace_dots, replace_newlines
|
| 10 |
-
from marker.schema.bbox import BboxElement
|
| 11 |
|
| 12 |
|
| 13 |
def get_table_surya(page, table_box, space_tol=.01) -> List[List[str]]:
|
|
|
|
| 7 |
from marker.settings import settings
|
| 8 |
from marker.tables.cells import assign_cells_to_columns
|
| 9 |
from marker.tables.utils import sort_table_blocks, replace_dots, replace_newlines
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def get_table_surya(page, table_box, space_tol=.01) -> List[List[str]]:
|
marker/tables/utils.py
CHANGED
|
@@ -34,4 +34,4 @@ def replace_dots(text):
|
|
| 34 |
def replace_newlines(text):
|
| 35 |
# Replace all newlines
|
| 36 |
newline_pattern = re.compile(r'[\r\n]+')
|
| 37 |
-
return newline_pattern.sub(' ', text.strip()
|
|
|
|
| 34 |
def replace_newlines(text):
    """Collapse each run of CR/LF characters in ``text`` into one space, then strip the result."""
    flattened = re.sub(r'[\r\n]+', ' ', text)
    return flattened.strip()
|