Vik Paruchuri committed on
Commit
1643ef3
·
1 Parent(s): b334fad

Misc fixes, benchmark updates

Browse files
.gitignore CHANGED
@@ -13,6 +13,7 @@ temp.md
13
  temp
14
  conversion_results
15
  uploads
 
16
 
17
  # Byte-compiled / optimized / DLL files
18
  __pycache__/
 
13
  temp
14
  conversion_results
15
  uploads
16
+ /cache
17
 
18
  # Byte-compiled / optimized / DLL files
19
  __pycache__/
benchmarks/overall/download/base.py CHANGED
@@ -32,10 +32,10 @@ class Downloader:
32
  "uuid": datasets.Value("string"),
33
  "time": datasets.Value("float"),
34
  }))
35
- out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}")
36
 
37
  def generate_data(self):
38
- max_rows = 2200
39
  for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
40
  cache_file = self.cache_path / f"{idx}.json"
41
  if cache_file.exists():
@@ -47,6 +47,9 @@ class Downloader:
47
  except JSONDecodeError as e:
48
  print(f"Error with sample {idx}: {e}")
49
  continue
 
 
 
50
  out_data["uuid"] = sample["uuid"]
51
 
52
  with cache_file.open("w") as f:
 
32
  "uuid": datasets.Value("string"),
33
  "time": datasets.Value("float"),
34
  }))
35
+ out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}", private=True)
36
 
37
  def generate_data(self):
38
+ max_rows = self.max_rows
39
  for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
40
  cache_file = self.cache_path / f"{idx}.json"
41
  if cache_file.exists():
 
47
  except JSONDecodeError as e:
48
  print(f"Error with sample {idx}: {e}")
49
  continue
50
+ except Exception as e:
51
+ print(f"Error with sample {idx}: {e}")
52
+ continue
53
  out_data["uuid"] = sample["uuid"]
54
 
55
  with cache_file.open("w") as f:
benchmarks/overall/download/llamaparse.py CHANGED
@@ -1,5 +1,4 @@
1
  import io
2
- import os
3
  import time
4
 
5
  import requests
 
1
  import io
 
2
  import time
3
 
4
  import requests
benchmarks/overall/download/main.py CHANGED
@@ -2,17 +2,19 @@ import click
2
 
3
  from benchmarks.overall.download.llamaparse import LlamaParseDownloader
4
  from benchmarks.overall.download.mathpix import MathpixDownloader
 
5
 
6
 
7
  @click.command("Download data from inference services")
8
- @click.argument("service", type=click.Choice(["mathpix", "llamaparse"]))
9
- @click.argument("--max_rows", type=int, default=2200)
10
- @click.argument("--api_key", type=str, default=None)
11
- @click.argument("--app_id", type=str, default=None)
12
  def main(service: str, max_rows: int, api_key: str, app_id: str):
13
  registry = {
14
  "mathpix": MathpixDownloader,
15
- "llamaparse": LlamaParseDownloader
 
16
  }
17
  downloader = registry[service](api_key, app_id, max_rows=max_rows)
18
 
 
2
 
3
  from benchmarks.overall.download.llamaparse import LlamaParseDownloader
4
  from benchmarks.overall.download.mathpix import MathpixDownloader
5
+ from benchmarks.overall.download.mistral import MistralDownloader
6
 
7
 
8
  @click.command("Download data from inference services")
9
+ @click.argument("service", type=click.Choice(["mathpix", "llamaparse", "mistral"]))
10
+ @click.option("--max_rows", type=int, default=2200)
11
+ @click.option("--api_key", type=str, default=None)
12
+ @click.option("--app_id", type=str, default=None)
13
  def main(service: str, max_rows: int, api_key: str, app_id: str):
14
  registry = {
15
  "mathpix": MathpixDownloader,
16
+ "llamaparse": LlamaParseDownloader,
17
+ "mistral": MistralDownloader,
18
  }
19
  downloader = registry[service](api_key, app_id, max_rows=max_rows)
20
 
benchmarks/overall/download/mistral.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import time
3
+ import requests
4
+
5
+ from benchmarks.overall.download.base import Downloader
6
+
7
+
8
+ class MistralDownloader(Downloader):
9
+ service = "mistral"
10
+
11
+ def get_html(self, pdf_bytes):
12
+ rand_name = str(time.time()) + ".pdf"
13
+ start = time.time()
14
+ buff = io.BytesIO(pdf_bytes)
15
+ md = upload_and_process_file(self.api_key, rand_name, buff)
16
+ end = time.time()
17
+ if isinstance(md, bytes):
18
+ md = md.decode("utf-8")
19
+
20
+ return {
21
+ "md": md,
22
+ "time": end - start,
23
+ }
24
+
25
+
26
+ def upload_and_process_file(api_key: str, fname: str, buff):
27
+ headers = {
28
+ "Authorization": f"Bearer {api_key}"
29
+ }
30
+
31
+ upload_headers = headers.copy()
32
+ files = {
33
+ 'file': (fname, buff, 'application/pdf'),
34
+ 'purpose': (None, 'ocr')
35
+ }
36
+
37
+ upload_response = requests.post(
38
+ 'https://api.mistral.ai/v1/files',
39
+ headers=upload_headers,
40
+ files=files
41
+ )
42
+ upload_response.raise_for_status()
43
+ file_id = upload_response.json()['id']
44
+
45
+ url_headers = headers.copy()
46
+ url_headers["Accept"] = "application/json"
47
+
48
+ url_response = requests.get(
49
+ f'https://api.mistral.ai/v1/files/{file_id}/url?expiry=24',
50
+ headers=url_headers
51
+ )
52
+ url_response.raise_for_status()
53
+ signed_url = url_response.json()['url']
54
+
55
+ ocr_headers = headers.copy()
56
+ ocr_headers["Content-Type"] = "application/json"
57
+
58
+ ocr_data = {
59
+ "model": "mistral-ocr-latest",
60
+ "document": {
61
+ "type": "document_url",
62
+ "document_url": signed_url
63
+ },
64
+ "include_image_base64": True
65
+ }
66
+ ocr_response = requests.post(
67
+ 'https://api.mistral.ai/v1/ocr',
68
+ headers=ocr_headers,
69
+ json=ocr_data
70
+ )
71
+ ocr_response.raise_for_status()
72
+ result = ocr_response.json()
73
+ return result["pages"][0]["markdown"]
benchmarks/overall/elo.py CHANGED
@@ -176,7 +176,7 @@ def display_win_rates_table(win_rates: dict):
176
  @click.argument("dataset", type=str)
177
  @click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
178
  @click.option("--row_samples", type=int, default=2, help="Number of samples per row")
179
- @click.option("--max_rows", type=int, default=100, help="Maximum number of rows to process")
180
  def main(
181
  dataset: str,
182
  methods: str,
@@ -187,8 +187,9 @@ def main(
187
  method_lst = methods.split(",")
188
  win_rates = {m: defaultdict(lambda: defaultdict(int)) for m in method_lst}
189
  comparer = Comparer()
 
190
 
191
- for i in tqdm(range(min(len(ds), max_rows)), desc="Calculating win rates..."):
192
  row = ds[i]
193
  # Avoid any bias in ordering
194
  random.shuffle(method_lst)
 
176
  @click.argument("dataset", type=str)
177
  @click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
178
  @click.option("--row_samples", type=int, default=2, help="Number of samples per row")
179
+ @click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process")
180
  def main(
181
  dataset: str,
182
  methods: str,
 
187
  method_lst = methods.split(",")
188
  win_rates = {m: defaultdict(lambda: defaultdict(int)) for m in method_lst}
189
  comparer = Comparer()
190
+ max_rows = max_rows or len(ds)
191
 
192
+ for i in tqdm(range(max_rows), desc="Calculating win rates..."):
193
  row = ds[i]
194
  # Avoid any bias in ordering
195
  random.shuffle(method_lst)
benchmarks/overall/methods/mistral.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+
3
+ from benchmarks.overall.methods import BaseMethod, BenchmarkResult
4
+
5
+
6
+ class MistralMethod(BaseMethod):
7
+ mistral_ds: datasets.Dataset = None
8
+
9
+ def __call__(self, sample) -> BenchmarkResult:
10
+ uuid = sample["uuid"]
11
+ data = None
12
+ for row in self.mistral_ds:
13
+ if str(row["uuid"]) == str(uuid):
14
+ data = row
15
+ break
16
+ if not data:
17
+ raise ValueError(f"Could not find data for uuid {uuid}")
18
+
19
+ return {
20
+ "markdown": data["md"],
21
+ "time": data["time"]
22
+ }
benchmarks/overall/overall.py CHANGED
@@ -89,7 +89,7 @@ def get_method_scores(benchmark_dataset: datasets.Dataset, methods: List[str], s
89
  @click.command(help="Benchmark PDF to MD conversion.")
90
  @click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
91
  @click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
92
- @click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling", default="marker")
93
  @click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
94
  @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
95
  @click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
@@ -145,6 +145,9 @@ def main(
145
  if "llamaparse" in methods:
146
  artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")
147
 
 
 
 
148
  if "olmocr" in methods:
149
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
150
  model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview",
@@ -167,7 +170,7 @@ def main(
167
  if use_llm:
168
  out_dataset += "_llm"
169
  dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
170
- dataset.push_to_hub(out_dataset)
171
 
172
 
173
  if __name__ == "__main__":
 
89
  @click.command(help="Benchmark PDF to MD conversion.")
90
  @click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
91
  @click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
92
+ @click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling,mistral", default="marker")
93
  @click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
94
  @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
95
  @click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
 
145
  if "llamaparse" in methods:
146
  artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")
147
 
148
+ if "mistral" in methods:
149
+ artifacts["mistral_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mistral", split="train")
150
+
151
  if "olmocr" in methods:
152
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
153
  model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview",
 
170
  if use_llm:
171
  out_dataset += "_llm"
172
  dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
173
+ dataset.push_to_hub(out_dataset, private=True)
174
 
175
 
176
  if __name__ == "__main__":
benchmarks/overall/registry.py CHANGED
@@ -3,6 +3,7 @@ from benchmarks.overall.methods.gt import GTMethod
3
  from benchmarks.overall.methods.llamaparse import LlamaParseMethod
4
  from benchmarks.overall.methods.marker import MarkerMethod
5
  from benchmarks.overall.methods.mathpix import MathpixMethod
 
6
  from benchmarks.overall.methods.olmocr import OlmOCRMethod
7
  from benchmarks.overall.scorers.heuristic import HeuristicScorer
8
  from benchmarks.overall.scorers.llm import LLMScorer
@@ -18,5 +19,6 @@ METHOD_REGISTRY = {
18
  "mathpix": MathpixMethod,
19
  "llamaparse": LlamaParseMethod,
20
  "docling": DoclingMethod,
21
- "olmocr": OlmOCRMethod
 
22
  }
 
3
  from benchmarks.overall.methods.llamaparse import LlamaParseMethod
4
  from benchmarks.overall.methods.marker import MarkerMethod
5
  from benchmarks.overall.methods.mathpix import MathpixMethod
6
+ from benchmarks.overall.methods.mistral import MistralMethod
7
  from benchmarks.overall.methods.olmocr import OlmOCRMethod
8
  from benchmarks.overall.scorers.heuristic import HeuristicScorer
9
  from benchmarks.overall.scorers.llm import LLMScorer
 
19
  "mathpix": MathpixMethod,
20
  "llamaparse": LlamaParseMethod,
21
  "docling": DoclingMethod,
22
+ "olmocr": OlmOCRMethod,
23
+ "mistral": MistralMethod
24
  }
marker/output.py CHANGED
@@ -7,9 +7,12 @@ from pydantic import BaseModel
7
  from marker.renderers.html import HTMLOutput
8
  from marker.renderers.json import JSONOutput, JSONBlockOutput
9
  from marker.renderers.markdown import MarkdownOutput
 
 
10
  from marker.settings import settings
11
 
12
- def json_to_html(block: JSONBlockOutput):
 
13
  # Utility function to take in json block output and give html for the block.
14
  if not getattr(block, "children", None):
15
  return block.html
 
7
  from marker.renderers.html import HTMLOutput
8
  from marker.renderers.json import JSONOutput, JSONBlockOutput
9
  from marker.renderers.markdown import MarkdownOutput
10
+ from marker.schema.blocks import Block, BlockOutput
11
+ from marker.schema.document import Document
12
  from marker.settings import settings
13
 
14
+
15
+ def json_to_html(block: JSONBlockOutput | BlockOutput):
16
  # Utility function to take in json block output and give html for the block.
17
  if not getattr(block, "children", None):
18
  return block.html
marker/processors/llm/llm_form.py CHANGED
@@ -2,6 +2,7 @@ from typing import List
2
 
3
  from pydantic import BaseModel
4
 
 
5
  from marker.processors.llm import PromptData, BaseLLMSimpleBlockProcessor, BlockData
6
 
7
  from marker.schema import BlockTypes
@@ -77,7 +78,7 @@ Comparison: The html representation has the labels in the first row and the valu
77
  prompt_data = []
78
  for block_data in self.inference_blocks(document):
79
  block = block_data["block"]
80
- block_html = block.render(document).html
81
  prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
82
  image = self.extract_image(document, block)
83
  prompt_data.append({
@@ -92,7 +93,7 @@ Comparison: The html representation has the labels in the first row and the valu
92
 
93
  def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
94
  block = prompt_data["block"]
95
- block_html = block.render(document).html
96
 
97
  if not response or "corrected_html" not in response:
98
  block.update_metadata(llm_error_count=1)
 
2
 
3
  from pydantic import BaseModel
4
 
5
+ from marker.output import json_to_html
6
  from marker.processors.llm import PromptData, BaseLLMSimpleBlockProcessor, BlockData
7
 
8
  from marker.schema import BlockTypes
 
78
  prompt_data = []
79
  for block_data in self.inference_blocks(document):
80
  block = block_data["block"]
81
+ block_html = json_to_html(block.render(document))
82
  prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
83
  image = self.extract_image(document, block)
84
  prompt_data.append({
 
93
 
94
  def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
95
  block = prompt_data["block"]
96
+ block_html = json_to_html(block.render(document))
97
 
98
  if not response or "corrected_html" not in response:
99
  block.update_metadata(llm_error_count=1)
marker/processors/llm/llm_mathblock.py CHANGED
@@ -5,6 +5,7 @@ from typing import List, Tuple, Annotated
5
  from pydantic import BaseModel
6
  from tqdm import tqdm
7
 
 
8
  from marker.processors.llm import BaseLLMComplexBlockProcessor
9
 
10
  from marker.schema import BlockTypes
@@ -27,8 +28,8 @@ class LLMMathBlockProcessor(BaseLLMComplexBlockProcessor):
27
  additional_block_types = (BlockTypes.Text, BlockTypes.Caption, BlockTypes.SectionHeader, BlockTypes.Footnote) # Secondary, can also contain math
28
 
29
  text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
30
- You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
31
- Your task is to correct any errors in the extracted block, including math, formatting, and other inaccuracies, and output the corrected block in html format. Stay as faithful to the original text as possible.
32
 
33
  **Instructions:**
34
 
@@ -39,7 +40,7 @@ Your task is to correct any errors in the extracted block, including math, forma
39
  5. If there are no errors in any of the extracted text, output "No corrections needed".
40
  6. Correct any errors in the extracted text, including:
41
  * Inline math: Ensure all mathematical expressions are correctly formatted and rendered. Surround them with <math>...</math> tags. The math expressions should be rendered in simple, concise, KaTeX-compatible LaTeX. Do not use $ or $$ as delimiters.
42
- * If a math expression is not in LaTeX format, convert it to LaTeX format, and surround it with <math>...</math> tags.
43
  * Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, subscripts/superscripts, and special characters. Use the <i>, <b>, <sup>, <sub>, and <span> tags to format the text as needed.
44
  * Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies.
45
  * Ensure lines wrap properly, and that newlines are not in the middle of sentences.
@@ -125,7 +126,7 @@ Adversarial training <i>(AT)</i> <a href='#page-9-1'>[23]</a>, which aims to min
125
  pbar.close()
126
 
127
  def get_block_text(self, block: Block, document: Document) -> str:
128
- html = block.render(document).html
129
  return html
130
 
131
  def get_block_lines(self, block: Block, document: Document) -> Tuple[list, list]:
 
5
  from pydantic import BaseModel
6
  from tqdm import tqdm
7
 
8
+ from marker.output import json_to_html
9
  from marker.processors.llm import BaseLLMComplexBlockProcessor
10
 
11
  from marker.schema import BlockTypes
 
28
  additional_block_types = (BlockTypes.Text, BlockTypes.Caption, BlockTypes.SectionHeader, BlockTypes.Footnote) # Secondary, can also contain math
29
 
30
  text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
31
+ You will receive an image of a text block and extracted text corresponding to the text in the image.
32
+ Your task is to correct any errors in the extracted text, including math, formatting, and other inaccuracies, and output the corrected block in html format. Stay as faithful to the text in the image as possible.
33
 
34
  **Instructions:**
35
 
 
40
  5. If there are no errors in any of the extracted text, output "No corrections needed".
41
  6. Correct any errors in the extracted text, including:
42
  * Inline math: Ensure all mathematical expressions are correctly formatted and rendered. Surround them with <math>...</math> tags. The math expressions should be rendered in simple, concise, KaTeX-compatible LaTeX. Do not use $ or $$ as delimiters.
43
+ * If a math expression is not in LaTeX format, convert it to LaTeX format, and surround it with <math>...</math> tags.
44
  * Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, subscripts/superscripts, and special characters. Use the <i>, <b>, <sup>, <sub>, and <span> tags to format the text as needed.
45
  * Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies.
46
  * Ensure lines wrap properly, and that newlines are not in the middle of sentences.
 
126
  pbar.close()
127
 
128
  def get_block_text(self, block: Block, document: Document) -> str:
129
+ html = json_to_html(block.render(document))
130
  return html
131
 
132
  def get_block_lines(self, block: Block, document: Document) -> Tuple[list, list]:
marker/processors/llm/llm_table_merge.py CHANGED
@@ -5,6 +5,7 @@ from pydantic import BaseModel
5
  from tqdm import tqdm
6
  from PIL import Image
7
 
 
8
  from marker.processors.llm import BaseLLMComplexBlockProcessor
9
  from marker.schema import BlockTypes
10
  from marker.schema.blocks import Block, TableCell
@@ -235,8 +236,8 @@ Table 2
235
 
236
  start_image = start_block.get_image(document, highres=False)
237
  curr_image = curr_block.get_image(document, highres=False)
238
- start_html = start_block.render(document).html
239
- curr_html = curr_block.render(document).html
240
 
241
  prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)
242
 
 
5
  from tqdm import tqdm
6
  from PIL import Image
7
 
8
+ from marker.output import json_to_html
9
  from marker.processors.llm import BaseLLMComplexBlockProcessor
10
  from marker.schema import BlockTypes
11
  from marker.schema.blocks import Block, TableCell
 
236
 
237
  start_image = start_block.get_image(document, highres=False)
238
  curr_image = curr_block.get_image(document, highres=False)
239
+ start_html = json_to_html(start_block.render(document))
240
+ curr_html = json_to_html(curr_block.render(document))
241
 
242
  prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)
243
 
marker/services/vertex.py CHANGED
@@ -17,11 +17,18 @@ class GoogleVertexService(BaseGeminiService):
17
  str,
18
  "The name of the Google model to use for the service."
19
  ] = "gemini-2.0-flash-001"
 
 
 
 
20
 
21
  def get_google_client(self, timeout: int):
 
 
 
22
  return genai.Client(
23
  vertexai=True,
24
  project=self.vertex_project_id,
25
  location=self.vertex_location,
26
- http_options={"timeout": timeout * 1000} # Convert to milliseconds
27
  )
 
17
  str,
18
  "The name of the Google model to use for the service."
19
  ] = "gemini-2.0-flash-001"
20
+ vertex_dedicated: Annotated[
21
+ bool,
22
+ "Whether to use a dedicated Vertex AI instance."
23
+ ] = False
24
 
25
  def get_google_client(self, timeout: int):
26
+ http_options = {"timeout": timeout * 1000} # Convert to milliseconds
27
+ if self.vertex_dedicated:
28
+ http_options["headers"] = {"x-vertex-ai-llm-request-type": "dedicated"}
29
  return genai.Client(
30
  vertexai=True,
31
  project=self.vertex_project_id,
32
  location=self.vertex_location,
33
+ http_options=http_options,
34
  )