Vik Paruchuri committed on
Commit
d090d63
·
1 Parent(s): c85fe35

Improve table benchmark, parsing

Browse files
README.md CHANGED
@@ -209,6 +209,14 @@ This will benchmark marker against other text extraction methods. It sets up ba
209
 
210
  Omit `--nougat` to exclude nougat from the benchmark. I don't recommend running nougat on CPU, since it is very slow.
211
 
 
 
 
 
 
 
 
 
212
  # Thanks
213
 
214
  This work would not have been possible without amazing open source models and datasets, including (but not limited to):
 
209
 
210
  Omit `--nougat` to exclude nougat from the benchmark. I don't recommend running nougat on CPU, since it is very slow.
211
 
212
+ ### Table benchmark
213
+
214
+ There is a benchmark for table parsing, which you can run with:
215
+
216
+ ```shell
217
+ python benchmarks/table.py test_data/tables.json
218
+ ```
219
+
220
  # Thanks
221
 
222
  This work would not have been possible without amazing open source models and datasets, including (but not limited to):
benchmarks/table.py CHANGED
@@ -3,6 +3,7 @@ import json
3
 
4
  import datasets
5
  from surya.schema import LayoutResult, LayoutBox
 
6
 
7
  from marker.benchmark.table import score_table
8
  from marker.schema.bbox import rescale_bbox
@@ -20,7 +21,7 @@ def main():
20
  ds = datasets.load_dataset(args.dataset, split="train")
21
 
22
  results = []
23
- for i in range(len(ds)):
24
  row = ds[i]
25
  marker_page = Page(**json.loads(row["marker_page"]))
26
  table_bbox = row["table_bbox"]
@@ -55,6 +56,7 @@ def main():
55
 
56
  table_block = table_blocks[0]
57
  table_md = table_block.lines[0].spans[0].text
 
58
  results.append({
59
  "score": score_table(table_md, gpt4_table),
60
  "arxiv_id": row["arxiv_id"],
 
3
 
4
  import datasets
5
  from surya.schema import LayoutResult, LayoutBox
6
+ from tqdm import tqdm
7
 
8
  from marker.benchmark.table import score_table
9
  from marker.schema.bbox import rescale_bbox
 
21
  ds = datasets.load_dataset(args.dataset, split="train")
22
 
23
  results = []
24
+ for i in tqdm(range(len(ds)), desc="Evaluating tables"):
25
  row = ds[i]
26
  marker_page = Page(**json.loads(row["marker_page"]))
27
  table_bbox = row["table_bbox"]
 
56
 
57
  table_block = table_blocks[0]
58
  table_md = table_block.lines[0].spans[0].text
59
+
60
  results.append({
61
  "score": score_table(table_md, gpt4_table),
62
  "arxiv_id": row["arxiv_id"],
marker/benchmark/table.py CHANGED
@@ -2,23 +2,40 @@ from rapidfuzz import fuzz
2
  import re
3
 
4
 
5
- def split_to_rows(table):
6
  table = table.strip()
7
  table = re.sub(r" {2,}", "", table)
8
  table_rows = table.split("\n")
9
- return [t for t in table_rows if t.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def score_table(hypothesis, reference):
13
- hypothesis = split_to_rows(hypothesis)
14
- reference = split_to_rows(reference)
15
 
16
  alignments = []
17
- for row in reference:
18
- max_alignment = 0
19
- for hrow in hypothesis:
20
- alignment = fuzz.ratio(hrow, row, score_cutoff=30) / 100
21
- if alignment > max_alignment:
22
- max_alignment = alignment
23
- alignments.append(max_alignment)
24
- return sum(alignments) / len(reference)
 
2
  import re
3
 
4
 
5
def split_to_cells(table):
    """Split a pipe-delimited (markdown-style) table string into rows of cells.

    Args:
        table: The table text, one row per line, cells separated by ``|``.

    Returns:
        A list of rows, each row being a list of cell strings. Blank lines are
        skipped, one leading and one trailing border pipe per row is dropped,
        runs of spaces inside a cell are collapsed to a single space, and each
        cell is stripped of surrounding whitespace.
    """
    table_cells = []
    for line in table.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        # Drop the border pipes so "| a | b |" does not produce phantom empty
        # first/last cells (which would always match each other when scored).
        if line.startswith("|"):
            line = line[1:]
        if line.endswith("|"):
            line = line[:-1]
        # Collapse padding to one space rather than deleting it, so that words
        # separated by several spaces are not fused together.
        table_cells.append([re.sub(r" {2,}", " ", cell).strip() for cell in line.split("|")])
    return table_cells
12
+
13
+
14
def align_rows(hypothesis, ref_row):
    """Find the hypothesis row that best matches one reference row, cell by cell.

    Each hypothesis row is compared to ``ref_row`` position by position with
    ``fuzz.ratio`` (scaled to 0-1, using ``score_cutoff=30``). Reference cells
    with no counterpart in the hypothesis row score 0.

    Args:
        hypothesis: List of rows, each a list of cell strings.
        ref_row: List of cell strings for a single reference row.

    Returns:
        The per-cell scores of the hypothesis row with the highest mean score
        (the last such row on ties). Always has ``len(ref_row)`` entries; all
        zeros when ``hypothesis`` is empty, so unmatched reference cells count
        against the overall score instead of being dropped.
    """
    # Start from "nothing matched" so an empty hypothesis still yields one
    # score per reference cell.
    best_alignment = [0] * len(ref_row)
    best_alignment_score = 0
    for hyp_row in hypothesis:
        alignments = [
            fuzz.ratio(hyp_row[i], ref_cell, score_cutoff=30) / 100 if i < len(hyp_row) else 0
            for i, ref_cell in enumerate(ref_row)
        ]
        if not alignments:
            continue
        alignment_score = sum(alignments) / len(alignments)
        if alignment_score >= best_alignment_score:
            best_alignment = alignments
            best_alignment_score = alignment_score
    return best_alignment
32
 
33
 
def score_table(hypothesis, reference):
    """Score a hypothesis table against a reference table.

    Both tables are parsed with ``split_to_cells``; every reference row is then
    matched with ``align_rows`` and the resulting per-cell scores are averaged.

    Args:
        hypothesis: Table text produced by the system under test.
        reference: Ground-truth table text.

    Returns:
        The mean of all per-cell alignment scores, or 0.0 when no scores were
        produced (for example, an empty reference table).
    """
    hypothesis = split_to_cells(hypothesis)
    reference = split_to_cells(reference)

    alignments = []
    for ref_row in reference:
        alignments.extend(align_rows(hypothesis, ref_row))
    # Avoid ZeroDivisionError when there is nothing to score.
    if not alignments:
        return 0.0
    return sum(alignments) / len(alignments)
 
 
 
 
 
marker/tables/cells.py CHANGED
@@ -1,116 +1,116 @@
1
  from PIL import Image, ImageDraw
2
  import copy
3
 
 
 
4
  from marker.tables.edges import get_vertical_lines
5
  import numpy as np
 
 
6
 
7
 
8
- def get_column_lines(page, table_box, table_rows, align="l", y_tolerance=10, x_tolerance=4):
9
- table_height = (table_box[3] - table_box[1]) * 2
10
- table_width = table_box[2] - table_box[0]
11
- img_size = (int(table_width), int(table_height))
12
- draw_img = Image.new("RGB", img_size)
13
- draw = ImageDraw.Draw(draw_img)
14
- for row in table_rows:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  for cell in row:
16
- line_bbox = list(copy.deepcopy(cell[0]))
17
- match align:
18
- case "l":
19
- line_bbox[2] = line_bbox[0]
20
- case "r":
21
- line_bbox[0] = line_bbox[2]
22
- case "c":
23
- line_bbox[0] = line_bbox[0] + (line_bbox[2] - line_bbox[0]) / 2
24
- line_bbox[2] = line_bbox[0]
25
-
26
- line_bbox[1] -= y_tolerance
27
- line_bbox[3] += y_tolerance
28
- line_bbox[0] -= table_box[0]
29
- line_bbox[2] -= table_box[0]
30
- line_bbox[1] -= table_box[1]
31
- line_bbox[3] -= table_box[1]
32
- draw.rectangle(line_bbox, outline="red", width=x_tolerance)
33
-
34
- np_img = np.array(draw_img, dtype=np.float32) / 255.0
35
- columns = get_vertical_lines(np_img, divisor=2, x_tolerance=10, y_tolerance=1)
36
- columns = sorted(columns, key=lambda x: x[0])
37
-
38
- # Remove short columns (single cells, probably)
39
- # Rescale coordinates back to image
40
- rescaled = []
41
- for c in columns:
42
- if c[3] - c[1] < table_height / 5:
43
- continue
44
- c[0] += table_box[0]
45
- c[2] += table_box[0]
46
- c[1] += table_box[1]
47
- c[3] += table_box[1]
48
- rescaled.append(c)
49
- return rescaled
50
-
51
-
52
- def assign_cells_to_columns(page, table_box, rows, tolerance=5):
53
- alignments = ["l", "r", "c"]
54
- columns = {}
55
- for align in alignments:
56
- columns[align] = get_column_lines(page, table_box, rows, align=align)
57
-
58
- # Find the column alignment that is closest to the number of columns
59
- max_cols = max([len(r) for r in rows])
60
- columns = min(columns.items(), key=lambda x: abs(len(x) - max_cols))[1]
61
-
62
- formatted_rows = []
63
- for table_row in rows:
64
- formatted_row = []
65
- for cell_idx in range(len(table_row) - 1, -1, -1):
66
- cell = copy.deepcopy(table_row[cell_idx])
67
- cell_bbox = cell[0]
68
-
69
- found = False
70
- for j in range(len(columns) - 1, -1, -1):
71
- if columns[j][0] - tolerance < cell_bbox[0]:
72
- if len(formatted_row) > 0:
73
- prev_column = formatted_row[-1][0]
74
- blanks = prev_column - j
75
- if blanks > 1:
76
- for b in range(1, blanks):
77
- formatted_row.append([prev_column - b, ""])
78
- formatted_row.append([j, cell[1]])
79
- found = True
80
  break
81
- if not found:
82
- formatted_row.append([cell_idx, cell[1]])
83
- formatted_rows.append(formatted_row[::-1])
84
-
85
- # Ensure rows have sequential column indices
86
- # Also identify the total number of columns
87
- col_count = 0
88
- for row in formatted_rows:
89
- prev_col = -1
90
- for col in row:
91
- col_idx = col[0]
92
- if col_idx <= prev_col:
93
- col[0] = prev_col + 1
94
- prev_col = col[0]
95
- col_count = max(col_count, col[0] + 1)
96
-
97
- # Assign cells to correct column positions
98
- clean_rows = []
99
- for row in formatted_rows:
100
- clean_row = []
101
- for col in range(col_count):
102
- found = False
103
- for cell in row:
104
- if cell[0] == col:
105
- clean_row.append(cell)
106
- found = True
107
- break
108
- if not found:
109
- clean_row.append((col, ""))
110
- clean_rows.append([cell[1] for cell in clean_row])
111
 
112
- max_cols = max([len(r) for r in clean_rows])
113
- for row in clean_rows:
114
- while len(row) < max_cols:
 
 
 
 
 
 
 
 
 
 
 
 
115
  row.append("")
116
- return clean_rows
 
 
 
 
 
 
 
 
 
 
 
 
1
  from PIL import Image, ImageDraw
2
  import copy
3
 
4
+ from marker.schema.bbox import rescale_bbox, box_intersection_pct
5
+ from marker.schema.page import Page
6
  from marker.tables.edges import get_vertical_lines
7
  import numpy as np
8
+ from sklearn.cluster import DBSCAN
9
+ from marker.settings import settings
10
 
11
 
12
def cluster_coords(coords, row_count):
    """Cluster 1-D coordinates with DBSCAN and return one value per cluster.

    Args:
        coords: Iterable of numeric coordinates; duplicates are removed before
            clustering.
        row_count: Number of table rows; a cluster needs at least
            ``max(2, row_count // 4)`` points.

    Returns:
        Sorted list with the mean coordinate of each cluster. Points DBSCAN
        marks as noise (label ``-1``) are ignored. Empty input yields ``[]``.
    """
    if len(coords) == 0:
        return []
    coords = np.array(sorted(set(coords))).reshape(-1, 1)

    clustering = DBSCAN(eps=.01, min_samples=max(2, row_count // 4)).fit(coords)
    clusters = clustering.labels_

    separators = []
    for label in set(clusters):
        # -1 is DBSCAN's noise label, not a cluster: averaging unrelated
        # outliers would produce a meaningless separator.
        if label == -1:
            continue
        clustered_points = coords[clusters == label]
        separators.append(np.mean(clustered_points))

    return sorted(separators)
27
+
28
+
29
def find_column_separators(page: Page, table_box, rows, round_factor=.002, min_count=1):
    """Estimate column separator positions for a table from text-line boxes.

    The boxes in ``page.text_lines.bboxes`` are rescaled with ``rescale_bbox``
    and kept when ``box_intersection_pct`` with ``table_box`` exceeds
    ``settings.BBOX_INTERSECTION_THRESH``. Their left edges, right edges and
    horizontal centers, divided by the page width, are snapped to multiples of
    ``round_factor``, filtered to values that repeat, and clustered with
    ``cluster_coords``.

    Args:
        page: Page whose ``bbox`` and ``text_lines`` are read.
        table_box: Bounding box of the table.
        rows: Table rows; only ``len(rows)`` is used (passed to the clustering).
        round_factor: Grid size the normalized x coordinates are snapped to.
        min_count: A snapped value is kept only if it occurs more than this
            many times.

    Returns:
        A list of separator positions starting with 0 and ending with 1, with
        the candidates from whichever alignment (left, right, center) produced
        the most clusters in between.
    """
    # Local import: keeps the block self-contained without touching file imports.
    from collections import Counter

    line_boxes = [p.bbox for p in page.text_lines.bboxes]
    line_boxes = [rescale_bbox(page.text_lines.image_bbox, page.bbox, l) for l in line_boxes]
    line_boxes = [l for l in line_boxes if box_intersection_pct(l, table_box) > settings.BBOX_INTERSECTION_THRESH]

    pwidth = page.bbox[2] - page.bbox[0]

    def snap(value):
        # Round to the nearest multiple of round_factor so nearly identical
        # edges compare equal in the repetition filter below.
        return round(value / round_factor) * round_factor

    def keep_repeated(values):
        counts = Counter(values)
        return [v for v in values if counts[v] > min_count]

    left_edges = []
    right_edges = []
    centers = []
    for cell in line_boxes:
        left = cell[0] / pwidth
        right = cell[2] / pwidth
        left_edges.append(snap(left))
        right_edges.append(snap(right))
        centers.append(snap((left + right) / 2))

    sorted_left = cluster_coords(keep_repeated(left_edges), len(rows))
    sorted_right = cluster_coords(keep_repeated(right_edges), len(rows))
    sorted_center = cluster_coords(keep_repeated(centers), len(rows))

    # Use the alignment that produced the most separators
    separators = max([sorted_left, sorted_right, sorted_center], key=len)
    return [0, *separators, 1]
59
+
60
+
61
def assign_cells_to_columns(page, table_box, rows, round_factor=.002, tolerance=.01):
    """Place the cells of each table row into a shared set of columns.

    Column separators come from ``find_column_separators``. Each cell goes to
    the first separator index that lies to the right of its left edge (within
    ``tolerance``) and after the column used by the previous cell in the row.
    Cells that fit no separator are appended as extra columns at the end.

    Args:
        page: Page whose ``bbox`` gives the width used to normalize x values.
        table_box: Bounding box of the table (forwarded to the separator search).
        rows: List of rows; each cell is indexed as ``cell[0]`` (a bbox whose
            first item is the left x coordinate) and ``cell[1]`` (the text).
        round_factor: Forwarded to ``find_column_separators``.
        tolerance: Slack, in page-width fractions, when comparing a cell's left
            edge with a separator.

    Returns:
        A list of rows of equal length containing the cell text, with ``""`` in
        unfilled positions and columns that are blank in every row removed.
        Returns ``[]`` when ``rows`` is empty.
    """
    separators = find_column_separators(page, table_box, rows, round_factor=round_factor)
    pwidth = page.bbox[2] - page.bbox[0]
    row_dicts = []

    for row in rows:
        new_row = {}
        last_col_index = -1
        # Counts overflow columns for this row only
        additional_column_index = 0
        for cell in row:
            left_edge = cell[0][0] / pwidth
            column_index = -1
            for i, separator in enumerate(separators):
                if left_edge - tolerance < separator and last_col_index < i:
                    column_index = i
                    break
            if column_index == -1:
                column_index = len(separators) + additional_column_index
                additional_column_index += 1
            new_row[column_index] = cell[1]
            last_col_index = column_index
        row_dicts.append(new_row)

    # Rows without cells contribute no index; default covers "no cells at all"
    max_col_idx = max((max(row) for row in row_dicts if row), default=-1)

    # Assign sorted cells to columns, account for blanks. Index 0 is included
    # so a cell at the very left edge of the page is not dropped; when unused,
    # that column is blank everywhere and is pruned below.
    new_rows = [
        [row.get(col_idx, "") for col_idx in range(max_col_idx + 1)]
        for row in row_dicts
    ]

    # Drop columns that are blank in every row
    keep = [idx for idx, col in enumerate(zip(*new_rows)) if any(cell.strip() for cell in col)]
    return [[row[idx] for idx in keep] for row in new_rows]
marker/tables/table.py CHANGED
@@ -7,7 +7,6 @@ from typing import List
7
  from marker.settings import settings
8
  from marker.tables.cells import assign_cells_to_columns
9
  from marker.tables.utils import sort_table_blocks, replace_dots, replace_newlines
10
- from marker.schema.bbox import BboxElement
11
 
12
 
13
  def get_table_surya(page, table_box, space_tol=.01) -> List[List[str]]:
 
7
  from marker.settings import settings
8
  from marker.tables.cells import assign_cells_to_columns
9
  from marker.tables.utils import sort_table_blocks, replace_dots, replace_newlines
 
10
 
11
 
12
  def get_table_surya(page, table_box, space_tol=.01) -> List[List[str]]:
marker/tables/utils.py CHANGED
@@ -34,4 +34,4 @@ def replace_dots(text):
34
  def replace_newlines(text):
35
  # Replace all newlines
36
  newline_pattern = re.compile(r'[\r\n]+')
37
- return newline_pattern.sub(' ', text.strip())
 
def replace_newlines(text):
    """Turn each run of carriage-return/newline characters into one space, then trim both ends."""
    collapsed = re.sub(r'[\r\n]+', ' ', text)
    return collapsed.strip()