Vik Paruchuri committed
Commit 5c027fe · Parent(s): cc1d60d

Additional fixes

Changed files:
- benchmarks/overall/overall.py +1 -1
- benchmarks/overall/scoring.py +30 -3
- benchmarks/table/inference.py +26 -8
- benchmarks/table/table.py +2 -17
- marker/renderers/markdown.py +1 -1
benchmarks/overall/overall.py CHANGED

@@ -31,7 +31,7 @@ def get_method_scores(ds, model_dict, max_rows=None, score_func=marker_scoring_f
         doc_type = sample["classification"]

         try:
-            gt_html = [block["html"] for block in gt_blocks]
+            gt_html = [block["html"] for block in gt_blocks if len(block["html"]) > 0]
             scores = score_func(model_dict, sample, gt_html, **kwargs)
         except ValueError as e:
             print(f"Error with sample {idx}: {e}")
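Only non-empty ground-truth HTML blocks are scored now. A minimal sketch of the effect of that filter (the gt_blocks sample below is invented; the comprehension is the one from the diff):

gt_blocks = [
    {"html": "<p>Some text</p>"},
    {"html": ""},  # an empty block like this used to leak an empty GT entry
    {"html": "<table><tr><td>1</td></tr></table>"},
]

# Before: every block's html was kept, empty strings included.
gt_html_old = [block["html"] for block in gt_blocks]

# After: empty html strings are dropped before scoring.
gt_html_new = [block["html"] for block in gt_blocks if len(block["html"]) > 0]

assert "" in gt_html_old and "" not in gt_html_new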
benchmarks/overall/scoring.py CHANGED

@@ -69,8 +69,10 @@ def standardize_markdown(markdown):
     markdown = re.sub(pattern, standardize_math, markdown)

     # Replace image urls
-    pattern = r'!\[(.*?)\]\((…
-    markdown = re.sub(pattern, r'![…
+    pattern = r'!\[(.*?)\]\((https?://[^\s\)]+)\)'
+    markdown = re.sub(pattern, r'![link]', markdown)
+    markdown = strip_latex_symbols(markdown)
+    markdown = replace_centered_lines(markdown)

     # Clean up html tags
     markdown = markdown.replace("<br>", "\n")

@@ -84,10 +86,35 @@ def standardize_markdown(markdown):
     markdown = re.sub("\\.+", ".", markdown) # Replace repeated periods with a single period, like in table of contents
     markdown = re.sub("#+", "#", markdown) # Replace repeated headers with a single header
     markdown = re.sub(r"\$", "", markdown) # Remove equation delimiters
-    markdown = markdown.encode().decode('unicode-escape') # Decode unicode characters properly
+    markdown = markdown.encode().decode('unicode-escape', errors="ignore") # Decode unicode characters properly
     return markdown.strip().lower()


+def replace_centered_lines(text):
+    def replace_match(m):
+        content = m.group(0)
+        dash_count = content.count('-')
+        return '-' * dash_count
+
+    pattern = r':-+:'
+    return re.sub(pattern, replace_match, text)
+
+
+def strip_latex_symbols(text):
+    # Handle short math mode sequences first - only match $ $ with brief content
+    text = re.sub(r'\$\s*\\?[a-zA-Z]+\d?\s*\$', '', text)
+
+    # Handle common patterns inside remaining math mode
+    patterns = [
+        r'\$\s*\\?[a-zA-Z]+\d?\s*\$', # \alpha or \alpha2 in math mode
+        r'\$\s*\d+\\[a-zA-Z]+\s*\$', # 45\circ in math mode
+        r'\$\s*[a-zA-Z0-9]\\[a-zA-Z]+\s*\$' # x\dagger in math mode
+    ]
+
+    pattern = '|'.join(patterns)
+    return re.sub(pattern, '', text)
+
+
 def standardize_math(match):
     try:
         delim = "$$" if match.group(0).startswith('$$') else "$"
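A quick usage sketch of the helpers added above (the sample strings are invented; the results follow from the regexes in this commit, and the import assumes the repo root is on the Python path):

from benchmarks.overall.scoring import replace_centered_lines, strip_latex_symbols

# ':---:' table-alignment markers collapse to plain dashes
replace_centered_lines("|:---:|:----:|")                   # -> "|---|----|"

# short inline math like $\alpha$ or $45\circ$ is stripped entirely
strip_latex_symbols("angle of $45\\circ$ and $\\alpha$")   # -> "angle of  and "

# unicode-escape decoding now ignores sequences it cannot decode instead of raising
"caf\\xe9".encode().decode('unicode-escape', errors="ignore")  # -> "café"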
benchmarks/table/inference.py CHANGED

@@ -1,4 +1,5 @@
-import …
+from typing import List
+
 import numpy as np
 from bs4 import BeautifulSoup
 import pypdfium2 as pdfium

@@ -10,18 +11,27 @@ from benchmarks.table.gemini import gemini_table_rec
 from marker.config.parser import ConfigParser
 from marker.converters.table import TableConverter
 from marker.models import create_model_dict
+from marker.renderers.json import JSONBlockOutput
+from marker.schema.polygon import PolygonBox
 from marker.util import matrix_intersection_area


+def extract_tables(children: List[JSONBlockOutput]):
+    tables = []
+    for child in children:
+        if child.block_type == 'Table':
+            tables.append(child)
+        elif child.children:
+            tables.extend(extract_tables(child.children))
+    return tables
+
+
 def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
     models = create_model_dict()
     config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
     total_unaligned = 0
     results = []

-    dataset = datasets.load_dataset(dataset, split='train')
-    dataset = dataset.shuffle(seed=0)
-
     iterations = len(dataset)
     if max_rows is not None:
         iterations = min(max_rows, len(dataset))

@@ -45,7 +55,8 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
         marker_json = converter(temp_pdf_file.name).children

         doc = pdfium.PdfDocument(temp_pdf_file.name)
-        page_image = doc[0].render(scale=…
+        page_image = doc[0].render(scale=96/72).to_pil()
+        doc.close()

         if len(marker_json) == 0 or len(gt_tables) == 0:
             print(f'No tables detected, skipping...')

@@ -55,10 +66,17 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
         marker_tables = extract_tables(marker_json)
         marker_table_boxes = [table.bbox for table in marker_tables]
         page_bbox = marker_json[0].bbox
-
+
         table_images = [
-            page_image.crop(…
-            …
+            page_image.crop(
+                PolygonBox.from_bbox(bbox)
+                .rescale(
+                    (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
+                ).bbox
+            )
+            for bbox
+            in marker_table_boxes
+        ]

         # Normalize the bboxes
         for bbox in marker_table_boxes:
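The crop above maps each table bbox from PDF page coordinates into pixel coordinates of the page image rendered at 96/72 scale. A standalone sketch of that arithmetic with plain tuples (rescale_bbox is a hypothetical helper for illustration; marker's PolygonBox.rescale is assumed to scale proportionally in the same way):

def rescale_bbox(bbox, page_size, image_size):
    # Scale an (x0, y0, x1, y1) box from page space into rendered-image pixel space.
    page_w, page_h = page_size
    img_w, img_h = image_size
    sx, sy = img_w / page_w, img_h / page_h
    x0, y0, x1, y1 = bbox
    return (x0 * sx, y0 * sy, x1 * sx, y1 * sy)

# Example: a 612x792 point page rendered at 96/72 scale becomes 816x1056 pixels.
page_bbox = (0, 0, 612, 792)
image_size = (816, 1056)
print(rescale_bbox((100, 200, 300, 400), (page_bbox[2], page_bbox[3]), image_size))
# -> roughly (133.3, 266.7, 400.0, 533.3)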
benchmarks/table/table.py CHANGED

@@ -1,7 +1,4 @@
 import os
-
-from benchmarks.table.inference import inference_tables
-
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS

 from pathlib import Path

@@ -15,11 +12,9 @@ import click
 from tabulate import tabulate
 import json
 from concurrent.futures import ProcessPoolExecutor
-from marker.renderers.json import JSONBlockOutput
-from marker.settings import settings

-from marker.…
-from …
+from marker.settings import settings
+from benchmarks.table.inference import inference_tables

 from scoring import wrap_table_html, similarity_eval_html

@@ -31,16 +26,6 @@ def update_teds_score(result, prefix: str = "marker"):
     return result


-def extract_tables(children: List[JSONBlockOutput]):
-    tables = []
-    for child in children:
-        if child.block_type == 'Table':
-            tables.append(child)
-        elif child.children:
-            tables.extend(extract_tables(child.children))
-    return tables
-
-
 @click.command(help="Benchmark Table to HTML Conversion")
 @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
 @click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
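extract_tables, now defined in benchmarks/table/inference.py, recursively walks the JSON block tree and collects every Table block, including ones nested inside other blocks. A toy run with stand-in blocks (the SimpleNamespace objects are only illustrative; the real inputs are JSONBlockOutput instances):

from types import SimpleNamespace

def extract_tables(children):
    # Same traversal as the helper moved into benchmarks/table/inference.py.
    tables = []
    for child in children:
        if child.block_type == 'Table':
            tables.append(child)
        elif child.children:
            tables.extend(extract_tables(child.children))
    return tables

table = SimpleNamespace(block_type='Table', children=None)
group = SimpleNamespace(block_type='Group', children=[table])
text = SimpleNamespace(block_type='Text', children=None)
page = SimpleNamespace(block_type='Page', children=[group, text])

print(len(extract_tables([page])))  # -> 1: the nested Table is found through the Group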
marker/renderers/markdown.py CHANGED

@@ -128,7 +128,7 @@ class Markdownify(MarkdownConverter):
                        grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
                    except IndexError:
                        # Sometimes the colspan/rowspan predictions can overflow
-                       print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
+                       print(f"Overflow in columns: {col_idx + c} >= {total_cols} or rows: {row_idx + r} >= {total_rows}")
                        continue

            col_idx += colspan
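The new message reports row overflow as well as column overflow. A minimal, self-contained sketch of that guard (the grid size and span values are invented; only the try/except pattern and the message format follow the renderer):

total_rows, total_cols = 2, 2
grid = [[None] * total_cols for _ in range(total_rows)]

row_idx, col_idx = 1, 1
rowspan, colspan = 2, 2  # deliberately overflows the 2x2 grid

for r in range(rowspan):
    for c in range(colspan):
        try:
            grid[row_idx + r][col_idx + c] = ''  # empty cell covered by the span
        except IndexError:
            # Mirrors the renderer: report which index overflowed instead of crashing.
            print(f"Overflow in columns: {col_idx + c} >= {total_cols} or rows: {row_idx + r} >= {total_rows}")
            continue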