Vik Paruchuri committed on
Commit
5c027fe
·
1 Parent(s): cc1d60d

Additional fixes

Browse files
benchmarks/overall/overall.py CHANGED
@@ -31,7 +31,7 @@ def get_method_scores(ds, model_dict, max_rows=None, score_func=marker_scoring_f
31
  doc_type = sample["classification"]
32
 
33
  try:
34
- gt_html = [block["html"] for block in gt_blocks]
35
  scores = score_func(model_dict, sample, gt_html, **kwargs)
36
  except ValueError as e:
37
  print(f"Error with sample {idx}: {e}")
 
31
  doc_type = sample["classification"]
32
 
33
  try:
34
+ gt_html = [block["html"] for block in gt_blocks if len(block["html"]) > 0]
35
  scores = score_func(model_dict, sample, gt_html, **kwargs)
36
  except ValueError as e:
37
  print(f"Error with sample {idx}: {e}")
benchmarks/overall/scoring.py CHANGED
@@ -69,8 +69,10 @@ def standardize_markdown(markdown):
69
  markdown = re.sub(pattern, standardize_math, markdown)
70
 
71
  # Replace image urls
72
- pattern = r'!\[(.*?)\]\((.*?)(?:\?.*?width=(\d+).*?height=(\d+).*?)\)'
73
- markdown = re.sub(pattern, r'![/api/placeholder]', markdown)
 
 
74
 
75
  # Clean up html tags
76
  markdown = markdown.replace("<br>", "\n")
@@ -84,10 +86,35 @@ def standardize_markdown(markdown):
84
  markdown = re.sub("\\.+", ".", markdown) # Replace repeated periods with a single period, like in table of contents
85
  markdown = re.sub("#+", "#", markdown) # Replace repeated headers with a single header
86
  markdown = re.sub(r"\$", "", markdown) # Remove equation delimiters
87
- markdown = markdown.encode().decode('unicode-escape') # Decode unicode characters properly
88
  return markdown.strip().lower()
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def standardize_math(match):
92
  try:
93
  delim = "$$" if match.group(0).startswith('$$') else "$"
 
69
  markdown = re.sub(pattern, standardize_math, markdown)
70
 
71
  # Replace image urls
72
+ pattern = r'!\[(.*?)\]\((https?://[^\s\)]+)\)'
73
+ markdown = re.sub(pattern, r'![link]', markdown)
74
+ markdown = strip_latex_symbols(markdown)
75
+ markdown = replace_centered_lines(markdown)
76
 
77
  # Clean up html tags
78
  markdown = markdown.replace("<br>", "\n")
 
86
  markdown = re.sub("\\.+", ".", markdown) # Replace repeated periods with a single period, like in table of contents
87
  markdown = re.sub("#+", "#", markdown) # Replace repeated headers with a single header
88
  markdown = re.sub(r"\$", "", markdown) # Remove equation delimiters
89
+ markdown = markdown.encode().decode('unicode-escape', errors="ignore") # Decode unicode characters properly
90
  return markdown.strip().lower()
91
 
92
 
93
def replace_centered_lines(text):
    """Flatten centered markdown table separators into plain dashes.

    Each ``:---:`` alignment marker is replaced by a run of dashes of the
    same dash count, dropping the surrounding colons.
    """
    # Substitution callback keeps the dash run length of every match.
    return re.sub(r':-+:', lambda m: '-' * m.group(0).count('-'), text)
101
+
102
+
103
def strip_latex_symbols(text):
    """Remove short inline-math fragments (e.g. ``$\\alpha$``, ``$45\\circ$``) from *text*.

    Runs two passes: a first sweep for brief ``$ ... $`` symbol sequences,
    then a combined sweep over the remaining common math-mode shapes.
    """
    short_math = r'\$\s*\\?[a-zA-Z]+\d?\s*\$'

    # Pass 1: only match $ $ with brief single-symbol content.
    text = re.sub(short_math, '', text)

    # Pass 2: common patterns inside remaining math mode (the short form is
    # re-applied deliberately — deletions above can form new matches).
    combined = '|'.join([
        short_math,                            # \alpha or \alpha2 in math mode
        r'\$\s*\d+\\[a-zA-Z]+\s*\$',           # 45\circ in math mode
        r'\$\s*[a-zA-Z0-9]\\[a-zA-Z]+\s*\$',   # x\dagger in math mode
    ])
    return re.sub(combined, '', text)
116
+
117
+
118
  def standardize_math(match):
119
  try:
120
  delim = "$$" if match.group(0).startswith('$$') else "$"
benchmarks/table/inference.py CHANGED
@@ -1,4 +1,5 @@
1
- import datasets
 
2
  import numpy as np
3
  from bs4 import BeautifulSoup
4
  import pypdfium2 as pdfium
@@ -10,18 +11,27 @@ from benchmarks.table.gemini import gemini_table_rec
10
  from marker.config.parser import ConfigParser
11
  from marker.converters.table import TableConverter
12
  from marker.models import create_model_dict
 
 
13
  from marker.util import matrix_intersection_area
14
 
15
 
 
 
 
 
 
 
 
 
 
 
16
  def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
17
  models = create_model_dict()
18
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
19
  total_unaligned = 0
20
  results = []
21
 
22
- dataset = datasets.load_dataset(dataset, split='train')
23
- dataset = dataset.shuffle(seed=0)
24
-
25
  iterations = len(dataset)
26
  if max_rows is not None:
27
  iterations = min(max_rows, len(dataset))
@@ -45,7 +55,8 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
45
  marker_json = converter(temp_pdf_file.name).children
46
 
47
  doc = pdfium.PdfDocument(temp_pdf_file.name)
48
- page_image = doc[0].render(scale=92 / 72).to_pil()
 
49
 
50
  if len(marker_json) == 0 or len(gt_tables) == 0:
51
  print(f'No tables detected, skipping...')
@@ -55,10 +66,17 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
55
  marker_tables = extract_tables(marker_json)
56
  marker_table_boxes = [table.bbox for table in marker_tables]
57
  page_bbox = marker_json[0].bbox
58
- w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]
59
  table_images = [
60
- page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox
61
- in marker_table_boxes]
 
 
 
 
 
 
 
62
 
63
  # Normalize the bboxes
64
  for bbox in marker_table_boxes:
 
1
+ from typing import List
2
+
3
  import numpy as np
4
  from bs4 import BeautifulSoup
5
  import pypdfium2 as pdfium
 
11
  from marker.config.parser import ConfigParser
12
  from marker.converters.table import TableConverter
13
  from marker.models import create_model_dict
14
+ from marker.renderers.json import JSONBlockOutput
15
+ from marker.schema.polygon import PolygonBox
16
  from marker.util import matrix_intersection_area
17
 
18
 
19
def extract_tables(children: List[JSONBlockOutput]):
    """Collect every block of type ``'Table'`` from a JSON block tree.

    Walks the tree depth-first in document order. A table's own children
    are not descended into — only non-table blocks are searched further.
    """
    def _walk(nodes):
        for node in nodes:
            if node.block_type == 'Table':
                yield node
            elif node.children:
                yield from _walk(node.children)

    return list(_walk(children))
27
+
28
+
29
  def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
30
  models = create_model_dict()
31
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
32
  total_unaligned = 0
33
  results = []
34
 
 
 
 
35
  iterations = len(dataset)
36
  if max_rows is not None:
37
  iterations = min(max_rows, len(dataset))
 
55
  marker_json = converter(temp_pdf_file.name).children
56
 
57
  doc = pdfium.PdfDocument(temp_pdf_file.name)
58
+ page_image = doc[0].render(scale=96/72).to_pil()
59
+ doc.close()
60
 
61
  if len(marker_json) == 0 or len(gt_tables) == 0:
62
  print(f'No tables detected, skipping...')
 
66
  marker_tables = extract_tables(marker_json)
67
  marker_table_boxes = [table.bbox for table in marker_tables]
68
  page_bbox = marker_json[0].bbox
69
+
70
  table_images = [
71
+ page_image.crop(
72
+ PolygonBox.from_bbox(bbox)
73
+ .rescale(
74
+ (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
75
+ ).bbox
76
+ )
77
+ for bbox
78
+ in marker_table_boxes
79
+ ]
80
 
81
  # Normalize the bboxes
82
  for bbox in marker_table_boxes:
benchmarks/table/table.py CHANGED
@@ -1,7 +1,4 @@
1
  import os
2
-
3
- from benchmarks.table.inference import inference_tables
4
-
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
6
 
7
  from pathlib import Path
@@ -15,11 +12,9 @@ import click
15
  from tabulate import tabulate
16
  import json
17
  from concurrent.futures import ProcessPoolExecutor
18
- from marker.renderers.json import JSONBlockOutput
19
- from marker.settings import settings
20
 
21
- from marker.config.parser import ConfigParser
22
- from marker.models import create_model_dict
23
 
24
  from scoring import wrap_table_html, similarity_eval_html
25
 
@@ -31,16 +26,6 @@ def update_teds_score(result, prefix: str = "marker"):
31
  return result
32
 
33
 
34
- def extract_tables(children: List[JSONBlockOutput]):
35
- tables = []
36
- for child in children:
37
- if child.block_type == 'Table':
38
- tables.append(child)
39
- elif child.children:
40
- tables.extend(extract_tables(child.children))
41
- return tables
42
-
43
-
44
  @click.command(help="Benchmark Table to HTML Conversion")
45
  @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
46
  @click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
 
1
  import os
 
 
 
2
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
3
 
4
  from pathlib import Path
 
12
  from tabulate import tabulate
13
  import json
14
  from concurrent.futures import ProcessPoolExecutor
 
 
15
 
16
+ from marker.settings import settings
17
+ from benchmarks.table.inference import inference_tables
18
 
19
  from scoring import wrap_table_html, similarity_eval_html
20
 
 
26
  return result
27
 
28
 
 
 
 
 
 
 
 
 
 
 
29
  @click.command(help="Benchmark Table to HTML Conversion")
30
  @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
31
  @click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
marker/renderers/markdown.py CHANGED
@@ -128,7 +128,7 @@ class Markdownify(MarkdownConverter):
128
  grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
129
  except IndexError:
130
  # Sometimes the colspan/rowspan predictions can overflow
131
- print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
132
  continue
133
 
134
  col_idx += colspan
 
128
  grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
129
  except IndexError:
130
  # Sometimes the colspan/rowspan predictions can overflow
131
+ print(f"Overflow in columns: {col_idx + c} >= {total_cols} or rows: {row_idx + r} >= {total_rows}")
132
  continue
133
 
134
  col_idx += colspan