Vik Paruchuri
committed on
Commit
·
1643ef3
1
Parent(s):
b334fad
Misc fixes, benchmark updates
Browse files
- .gitignore +1 -0
- benchmarks/overall/download/base.py +5 -2
- benchmarks/overall/download/llamaparse.py +0 -1
- benchmarks/overall/download/main.py +7 -5
- benchmarks/overall/download/mistral.py +73 -0
- benchmarks/overall/elo.py +3 -2
- benchmarks/overall/methods/mistral.py +22 -0
- benchmarks/overall/overall.py +5 -2
- benchmarks/overall/registry.py +3 -1
- marker/output.py +4 -1
- marker/processors/llm/llm_form.py +3 -2
- marker/processors/llm/llm_mathblock.py +5 -4
- marker/processors/llm/llm_table_merge.py +3 -2
- marker/services/vertex.py +8 -1
.gitignore
CHANGED
|
@@ -13,6 +13,7 @@ temp.md
|
|
| 13 |
temp
|
| 14 |
conversion_results
|
| 15 |
uploads
|
|
|
|
| 16 |
|
| 17 |
# Byte-compiled / optimized / DLL files
|
| 18 |
__pycache__/
|
|
|
|
| 13 |
temp
|
| 14 |
conversion_results
|
| 15 |
uploads
|
| 16 |
+
/cache
|
| 17 |
|
| 18 |
# Byte-compiled / optimized / DLL files
|
| 19 |
__pycache__/
|
benchmarks/overall/download/base.py
CHANGED
|
@@ -32,10 +32,10 @@ class Downloader:
|
|
| 32 |
"uuid": datasets.Value("string"),
|
| 33 |
"time": datasets.Value("float"),
|
| 34 |
}))
|
| 35 |
-
out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}")
|
| 36 |
|
| 37 |
def generate_data(self):
|
| 38 |
-
max_rows =
|
| 39 |
for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
|
| 40 |
cache_file = self.cache_path / f"{idx}.json"
|
| 41 |
if cache_file.exists():
|
|
@@ -47,6 +47,9 @@ class Downloader:
|
|
| 47 |
except JSONDecodeError as e:
|
| 48 |
print(f"Error with sample {idx}: {e}")
|
| 49 |
continue
|
|
|
|
|
|
|
|
|
|
| 50 |
out_data["uuid"] = sample["uuid"]
|
| 51 |
|
| 52 |
with cache_file.open("w") as f:
|
|
|
|
| 32 |
"uuid": datasets.Value("string"),
|
| 33 |
"time": datasets.Value("float"),
|
| 34 |
}))
|
| 35 |
+
out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}", private=True)
|
| 36 |
|
| 37 |
def generate_data(self):
|
| 38 |
+
max_rows = self.max_rows
|
| 39 |
for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
|
| 40 |
cache_file = self.cache_path / f"{idx}.json"
|
| 41 |
if cache_file.exists():
|
|
|
|
| 47 |
except JSONDecodeError as e:
|
| 48 |
print(f"Error with sample {idx}: {e}")
|
| 49 |
continue
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Error with sample {idx}: {e}")
|
| 52 |
+
continue
|
| 53 |
out_data["uuid"] = sample["uuid"]
|
| 54 |
|
| 55 |
with cache_file.open("w") as f:
|
benchmarks/overall/download/llamaparse.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import io
|
| 2 |
-
import os
|
| 3 |
import time
|
| 4 |
|
| 5 |
import requests
|
|
|
|
| 1 |
import io
|
|
|
|
| 2 |
import time
|
| 3 |
|
| 4 |
import requests
|
benchmarks/overall/download/main.py
CHANGED
|
@@ -2,17 +2,19 @@ import click
|
|
| 2 |
|
| 3 |
from benchmarks.overall.download.llamaparse import LlamaParseDownloader
|
| 4 |
from benchmarks.overall.download.mathpix import MathpixDownloader
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
@click.command("Download data from inference services")
|
| 8 |
-
@click.argument("service", type=click.Choice(["mathpix", "llamaparse"]))
|
| 9 |
-
@click.
|
| 10 |
-
@click.
|
| 11 |
-
@click.
|
| 12 |
def main(service: str, max_rows: int, api_key: str, app_id: str):
|
| 13 |
registry = {
|
| 14 |
"mathpix": MathpixDownloader,
|
| 15 |
-
"llamaparse": LlamaParseDownloader
|
|
|
|
| 16 |
}
|
| 17 |
downloader = registry[service](api_key, app_id, max_rows=max_rows)
|
| 18 |
|
|
|
|
| 2 |
|
| 3 |
from benchmarks.overall.download.llamaparse import LlamaParseDownloader
|
| 4 |
from benchmarks.overall.download.mathpix import MathpixDownloader
|
| 5 |
+
from benchmarks.overall.download.mistral import MistralDownloader
|
| 6 |
|
| 7 |
|
| 8 |
@click.command("Download data from inference services")
|
| 9 |
+
@click.argument("service", type=click.Choice(["mathpix", "llamaparse", "mistral"]))
|
| 10 |
+
@click.option("--max_rows", type=int, default=2200)
|
| 11 |
+
@click.option("--api_key", type=str, default=None)
|
| 12 |
+
@click.option("--app_id", type=str, default=None)
|
| 13 |
def main(service: str, max_rows: int, api_key: str, app_id: str):
|
| 14 |
registry = {
|
| 15 |
"mathpix": MathpixDownloader,
|
| 16 |
+
"llamaparse": LlamaParseDownloader,
|
| 17 |
+
"mistral": MistralDownloader,
|
| 18 |
}
|
| 19 |
downloader = registry[service](api_key, app_id, max_rows=max_rows)
|
| 20 |
|
benchmarks/overall/download/mistral.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import time
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from benchmarks.overall.download.base import Downloader
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MistralDownloader(Downloader):
|
| 9 |
+
service = "mistral"
|
| 10 |
+
|
| 11 |
+
def get_html(self, pdf_bytes):
|
| 12 |
+
rand_name = str(time.time()) + ".pdf"
|
| 13 |
+
start = time.time()
|
| 14 |
+
buff = io.BytesIO(pdf_bytes)
|
| 15 |
+
md = upload_and_process_file(self.api_key, rand_name, buff)
|
| 16 |
+
end = time.time()
|
| 17 |
+
if isinstance(md, bytes):
|
| 18 |
+
md = md.decode("utf-8")
|
| 19 |
+
|
| 20 |
+
return {
|
| 21 |
+
"md": md,
|
| 22 |
+
"time": end - start,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def upload_and_process_file(api_key: str, fname: str, buff):
|
| 27 |
+
headers = {
|
| 28 |
+
"Authorization": f"Bearer {api_key}"
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
upload_headers = headers.copy()
|
| 32 |
+
files = {
|
| 33 |
+
'file': (fname, buff, 'application/pdf'),
|
| 34 |
+
'purpose': (None, 'ocr')
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
upload_response = requests.post(
|
| 38 |
+
'https://api.mistral.ai/v1/files',
|
| 39 |
+
headers=upload_headers,
|
| 40 |
+
files=files
|
| 41 |
+
)
|
| 42 |
+
upload_response.raise_for_status()
|
| 43 |
+
file_id = upload_response.json()['id']
|
| 44 |
+
|
| 45 |
+
url_headers = headers.copy()
|
| 46 |
+
url_headers["Accept"] = "application/json"
|
| 47 |
+
|
| 48 |
+
url_response = requests.get(
|
| 49 |
+
f'https://api.mistral.ai/v1/files/{file_id}/url?expiry=24',
|
| 50 |
+
headers=url_headers
|
| 51 |
+
)
|
| 52 |
+
url_response.raise_for_status()
|
| 53 |
+
signed_url = url_response.json()['url']
|
| 54 |
+
|
| 55 |
+
ocr_headers = headers.copy()
|
| 56 |
+
ocr_headers["Content-Type"] = "application/json"
|
| 57 |
+
|
| 58 |
+
ocr_data = {
|
| 59 |
+
"model": "mistral-ocr-latest",
|
| 60 |
+
"document": {
|
| 61 |
+
"type": "document_url",
|
| 62 |
+
"document_url": signed_url
|
| 63 |
+
},
|
| 64 |
+
"include_image_base64": True
|
| 65 |
+
}
|
| 66 |
+
ocr_response = requests.post(
|
| 67 |
+
'https://api.mistral.ai/v1/ocr',
|
| 68 |
+
headers=ocr_headers,
|
| 69 |
+
json=ocr_data
|
| 70 |
+
)
|
| 71 |
+
ocr_response.raise_for_status()
|
| 72 |
+
result = ocr_response.json()
|
| 73 |
+
return result["pages"][0]["markdown"]
|
benchmarks/overall/elo.py
CHANGED
|
@@ -176,7 +176,7 @@ def display_win_rates_table(win_rates: dict):
|
|
| 176 |
@click.argument("dataset", type=str)
|
| 177 |
@click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
|
| 178 |
@click.option("--row_samples", type=int, default=2, help="Number of samples per row")
|
| 179 |
-
@click.option("--max_rows", type=int, default=
|
| 180 |
def main(
|
| 181 |
dataset: str,
|
| 182 |
methods: str,
|
|
@@ -187,8 +187,9 @@ def main(
|
|
| 187 |
method_lst = methods.split(",")
|
| 188 |
win_rates = {m: defaultdict(lambda: defaultdict(int)) for m in method_lst}
|
| 189 |
comparer = Comparer()
|
|
|
|
| 190 |
|
| 191 |
-
for i in tqdm(range(
|
| 192 |
row = ds[i]
|
| 193 |
# Avoid any bias in ordering
|
| 194 |
random.shuffle(method_lst)
|
|
|
|
| 176 |
@click.argument("dataset", type=str)
|
| 177 |
@click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
|
| 178 |
@click.option("--row_samples", type=int, default=2, help="Number of samples per row")
|
| 179 |
+
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process")
|
| 180 |
def main(
|
| 181 |
dataset: str,
|
| 182 |
methods: str,
|
|
|
|
| 187 |
method_lst = methods.split(",")
|
| 188 |
win_rates = {m: defaultdict(lambda: defaultdict(int)) for m in method_lst}
|
| 189 |
comparer = Comparer()
|
| 190 |
+
max_rows = max_rows or len(ds)
|
| 191 |
|
| 192 |
+
for i in tqdm(range(max_rows), desc="Calculating win rates..."):
|
| 193 |
row = ds[i]
|
| 194 |
# Avoid any bias in ordering
|
| 195 |
random.shuffle(method_lst)
|
benchmarks/overall/methods/mistral.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MistralMethod(BaseMethod):
|
| 7 |
+
mistral_ds: datasets.Dataset = None
|
| 8 |
+
|
| 9 |
+
def __call__(self, sample) -> BenchmarkResult:
|
| 10 |
+
uuid = sample["uuid"]
|
| 11 |
+
data = None
|
| 12 |
+
for row in self.mistral_ds:
|
| 13 |
+
if str(row["uuid"]) == str(uuid):
|
| 14 |
+
data = row
|
| 15 |
+
break
|
| 16 |
+
if not data:
|
| 17 |
+
raise ValueError(f"Could not find data for uuid {uuid}")
|
| 18 |
+
|
| 19 |
+
return {
|
| 20 |
+
"markdown": data["md"],
|
| 21 |
+
"time": data["time"]
|
| 22 |
+
}
|
benchmarks/overall/overall.py
CHANGED
|
@@ -89,7 +89,7 @@ def get_method_scores(benchmark_dataset: datasets.Dataset, methods: List[str], s
|
|
| 89 |
@click.command(help="Benchmark PDF to MD conversion.")
|
| 90 |
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
|
| 91 |
@click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
|
| 92 |
-
@click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling", default="marker")
|
| 93 |
@click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
|
| 94 |
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
|
| 95 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
|
|
@@ -145,6 +145,9 @@ def main(
|
|
| 145 |
if "llamaparse" in methods:
|
| 146 |
artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")
|
| 147 |
|
|
|
|
|
|
|
|
|
|
| 148 |
if "olmocr" in methods:
|
| 149 |
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
| 150 |
model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview",
|
|
@@ -167,7 +170,7 @@ def main(
|
|
| 167 |
if use_llm:
|
| 168 |
out_dataset += "_llm"
|
| 169 |
dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
|
| 170 |
-
dataset.push_to_hub(out_dataset)
|
| 171 |
|
| 172 |
|
| 173 |
if __name__ == "__main__":
|
|
|
|
| 89 |
@click.command(help="Benchmark PDF to MD conversion.")
|
| 90 |
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
|
| 91 |
@click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
|
| 92 |
+
@click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling,mistral", default="marker")
|
| 93 |
@click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
|
| 94 |
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
|
| 95 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
|
|
|
|
| 145 |
if "llamaparse" in methods:
|
| 146 |
artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")
|
| 147 |
|
| 148 |
+
if "mistral" in methods:
|
| 149 |
+
artifacts["mistral_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mistral", split="train")
|
| 150 |
+
|
| 151 |
if "olmocr" in methods:
|
| 152 |
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
| 153 |
model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview",
|
|
|
|
| 170 |
if use_llm:
|
| 171 |
out_dataset += "_llm"
|
| 172 |
dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
|
| 173 |
+
dataset.push_to_hub(out_dataset, private=True)
|
| 174 |
|
| 175 |
|
| 176 |
if __name__ == "__main__":
|
benchmarks/overall/registry.py
CHANGED
|
@@ -3,6 +3,7 @@ from benchmarks.overall.methods.gt import GTMethod
|
|
| 3 |
from benchmarks.overall.methods.llamaparse import LlamaParseMethod
|
| 4 |
from benchmarks.overall.methods.marker import MarkerMethod
|
| 5 |
from benchmarks.overall.methods.mathpix import MathpixMethod
|
|
|
|
| 6 |
from benchmarks.overall.methods.olmocr import OlmOCRMethod
|
| 7 |
from benchmarks.overall.scorers.heuristic import HeuristicScorer
|
| 8 |
from benchmarks.overall.scorers.llm import LLMScorer
|
|
@@ -18,5 +19,6 @@ METHOD_REGISTRY = {
|
|
| 18 |
"mathpix": MathpixMethod,
|
| 19 |
"llamaparse": LlamaParseMethod,
|
| 20 |
"docling": DoclingMethod,
|
| 21 |
-
"olmocr": OlmOCRMethod
|
|
|
|
| 22 |
}
|
|
|
|
| 3 |
from benchmarks.overall.methods.llamaparse import LlamaParseMethod
|
| 4 |
from benchmarks.overall.methods.marker import MarkerMethod
|
| 5 |
from benchmarks.overall.methods.mathpix import MathpixMethod
|
| 6 |
+
from benchmarks.overall.methods.mistral import MistralMethod
|
| 7 |
from benchmarks.overall.methods.olmocr import OlmOCRMethod
|
| 8 |
from benchmarks.overall.scorers.heuristic import HeuristicScorer
|
| 9 |
from benchmarks.overall.scorers.llm import LLMScorer
|
|
|
|
| 19 |
"mathpix": MathpixMethod,
|
| 20 |
"llamaparse": LlamaParseMethod,
|
| 21 |
"docling": DoclingMethod,
|
| 22 |
+
"olmocr": OlmOCRMethod,
|
| 23 |
+
"mistral": MistralMethod
|
| 24 |
}
|
marker/output.py
CHANGED
|
@@ -7,9 +7,12 @@ from pydantic import BaseModel
|
|
| 7 |
from marker.renderers.html import HTMLOutput
|
| 8 |
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 9 |
from marker.renderers.markdown import MarkdownOutput
|
|
|
|
|
|
|
| 10 |
from marker.settings import settings
|
| 11 |
|
| 12 |
-
|
|
|
|
| 13 |
# Utility function to take in json block output and give html for the block.
|
| 14 |
if not getattr(block, "children", None):
|
| 15 |
return block.html
|
|
|
|
| 7 |
from marker.renderers.html import HTMLOutput
|
| 8 |
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 9 |
from marker.renderers.markdown import MarkdownOutput
|
| 10 |
+
from marker.schema.blocks import Block, BlockOutput
|
| 11 |
+
from marker.schema.document import Document
|
| 12 |
from marker.settings import settings
|
| 13 |
|
| 14 |
+
|
| 15 |
+
def json_to_html(block: JSONBlockOutput | BlockOutput):
|
| 16 |
# Utility function to take in json block output and give html for the block.
|
| 17 |
if not getattr(block, "children", None):
|
| 18 |
return block.html
|
marker/processors/llm/llm_form.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import List
|
|
| 2 |
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
|
|
|
| 5 |
from marker.processors.llm import PromptData, BaseLLMSimpleBlockProcessor, BlockData
|
| 6 |
|
| 7 |
from marker.schema import BlockTypes
|
|
@@ -77,7 +78,7 @@ Comparison: The html representation has the labels in the first row and the valu
|
|
| 77 |
prompt_data = []
|
| 78 |
for block_data in self.inference_blocks(document):
|
| 79 |
block = block_data["block"]
|
| 80 |
-
block_html = block.render(document)
|
| 81 |
prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
|
| 82 |
image = self.extract_image(document, block)
|
| 83 |
prompt_data.append({
|
|
@@ -92,7 +93,7 @@ Comparison: The html representation has the labels in the first row and the valu
|
|
| 92 |
|
| 93 |
def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
|
| 94 |
block = prompt_data["block"]
|
| 95 |
-
block_html = block.render(document)
|
| 96 |
|
| 97 |
if not response or "corrected_html" not in response:
|
| 98 |
block.update_metadata(llm_error_count=1)
|
|
|
|
| 2 |
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
| 5 |
+
from marker.output import json_to_html
|
| 6 |
from marker.processors.llm import PromptData, BaseLLMSimpleBlockProcessor, BlockData
|
| 7 |
|
| 8 |
from marker.schema import BlockTypes
|
|
|
|
| 78 |
prompt_data = []
|
| 79 |
for block_data in self.inference_blocks(document):
|
| 80 |
block = block_data["block"]
|
| 81 |
+
block_html = json_to_html(block.render(document))
|
| 82 |
prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
|
| 83 |
image = self.extract_image(document, block)
|
| 84 |
prompt_data.append({
|
|
|
|
| 93 |
|
| 94 |
def rewrite_block(self, response: dict, prompt_data: PromptData, document: Document):
|
| 95 |
block = prompt_data["block"]
|
| 96 |
+
block_html = json_to_html(block.render(document))
|
| 97 |
|
| 98 |
if not response or "corrected_html" not in response:
|
| 99 |
block.update_metadata(llm_error_count=1)
|
marker/processors/llm/llm_mathblock.py
CHANGED
|
@@ -5,6 +5,7 @@ from typing import List, Tuple, Annotated
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from tqdm import tqdm
|
| 7 |
|
|
|
|
| 8 |
from marker.processors.llm import BaseLLMComplexBlockProcessor
|
| 9 |
|
| 10 |
from marker.schema import BlockTypes
|
|
@@ -27,8 +28,8 @@ class LLMMathBlockProcessor(BaseLLMComplexBlockProcessor):
|
|
| 27 |
additional_block_types = (BlockTypes.Text, BlockTypes.Caption, BlockTypes.SectionHeader, BlockTypes.Footnote) # Seconday, can also contain math
|
| 28 |
|
| 29 |
text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
|
| 30 |
-
You will receive an image of a text block and
|
| 31 |
-
Your task is to correct any errors in the extracted
|
| 32 |
|
| 33 |
**Instructions:**
|
| 34 |
|
|
@@ -39,7 +40,7 @@ Your task is to correct any errors in the extracted block, including math, forma
|
|
| 39 |
5. If there are no errors in any of the extracted text, output "No corrections needed".
|
| 40 |
6. Correct any errors in the extracted text, including:
|
| 41 |
* Inline math: Ensure all mathematical expressions are correctly formatted and rendered. Surround them with <math>...</math> tags. The math expressions should be rendered in simple, concise, KaTeX-compatible LaTeX. Do not use $ or $$ as delimiters.
|
| 42 |
-
|
| 43 |
* Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, subscripts/superscripts, and special characters. Use the <i>, <b>, <sup>, <sub>, and <span> tags to format the text as needed.
|
| 44 |
* Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies.
|
| 45 |
* Ensure lines wrap properly, and that newlines are not in the middle of sentences.
|
|
@@ -125,7 +126,7 @@ Adversarial training <i>(AT)</i> <a href='#page-9-1'>[23]</a>, which aims to min
|
|
| 125 |
pbar.close()
|
| 126 |
|
| 127 |
def get_block_text(self, block: Block, document: Document) -> str:
|
| 128 |
-
html = block.render(document)
|
| 129 |
return html
|
| 130 |
|
| 131 |
def get_block_lines(self, block: Block, document: Document) -> Tuple[list, list]:
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from tqdm import tqdm
|
| 7 |
|
| 8 |
+
from marker.output import json_to_html
|
| 9 |
from marker.processors.llm import BaseLLMComplexBlockProcessor
|
| 10 |
|
| 11 |
from marker.schema import BlockTypes
|
|
|
|
| 28 |
additional_block_types = (BlockTypes.Text, BlockTypes.Caption, BlockTypes.SectionHeader, BlockTypes.Footnote) # Seconday, can also contain math
|
| 29 |
|
| 30 |
text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
|
| 31 |
+
You will receive an image of a text block and extracted text corresponding to the text in the image.
|
| 32 |
+
Your task is to correct any errors in the extracted text, including math, formatting, and other inaccuracies, and output the corrected block in html format. Stay as faithful to the text in the image as possible.
|
| 33 |
|
| 34 |
**Instructions:**
|
| 35 |
|
|
|
|
| 40 |
5. If there are no errors in any of the extracted text, output "No corrections needed".
|
| 41 |
6. Correct any errors in the extracted text, including:
|
| 42 |
* Inline math: Ensure all mathematical expressions are correctly formatted and rendered. Surround them with <math>...</math> tags. The math expressions should be rendered in simple, concise, KaTeX-compatible LaTeX. Do not use $ or $$ as delimiters.
|
| 43 |
+
* If a math expression is not in LaTeX format, convert it to LaTeX format, and surround it with <math>...</math> tags.
|
| 44 |
* Formatting: Maintain consistent formatting with the text block image, including spacing, indentation, subscripts/superscripts, and special characters. Use the <i>, <b>, <sup>, <sub>, and <span> tags to format the text as needed.
|
| 45 |
* Other inaccuracies: If the image is handwritten then you may correct any spelling errors, or other discrepancies.
|
| 46 |
* Ensure lines wrap properly, and that newlines are not in the middle of sentences.
|
|
|
|
| 126 |
pbar.close()
|
| 127 |
|
| 128 |
def get_block_text(self, block: Block, document: Document) -> str:
|
| 129 |
+
html = json_to_html(block.render(document))
|
| 130 |
return html
|
| 131 |
|
| 132 |
def get_block_lines(self, block: Block, document: Document) -> Tuple[list, list]:
|
marker/processors/llm/llm_table_merge.py
CHANGED
|
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
from PIL import Image
|
| 7 |
|
|
|
|
| 8 |
from marker.processors.llm import BaseLLMComplexBlockProcessor
|
| 9 |
from marker.schema import BlockTypes
|
| 10 |
from marker.schema.blocks import Block, TableCell
|
|
@@ -235,8 +236,8 @@ Table 2
|
|
| 235 |
|
| 236 |
start_image = start_block.get_image(document, highres=False)
|
| 237 |
curr_image = curr_block.get_image(document, highres=False)
|
| 238 |
-
start_html = start_block.render(document)
|
| 239 |
-
curr_html = curr_block.render(document)
|
| 240 |
|
| 241 |
prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)
|
| 242 |
|
|
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
+
from marker.output import json_to_html
|
| 9 |
from marker.processors.llm import BaseLLMComplexBlockProcessor
|
| 10 |
from marker.schema import BlockTypes
|
| 11 |
from marker.schema.blocks import Block, TableCell
|
|
|
|
| 236 |
|
| 237 |
start_image = start_block.get_image(document, highres=False)
|
| 238 |
curr_image = curr_block.get_image(document, highres=False)
|
| 239 |
+
start_html = json_to_html(start_block.render(document))
|
| 240 |
+
curr_html = json_to_html(curr_block.render(document))
|
| 241 |
|
| 242 |
prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)
|
| 243 |
|
marker/services/vertex.py
CHANGED
|
@@ -17,11 +17,18 @@ class GoogleVertexService(BaseGeminiService):
|
|
| 17 |
str,
|
| 18 |
"The name of the Google model to use for the service."
|
| 19 |
] = "gemini-2.0-flash-001"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def get_google_client(self, timeout: int):
|
|
|
|
|
|
|
|
|
|
| 22 |
return genai.Client(
|
| 23 |
vertexai=True,
|
| 24 |
project=self.vertex_project_id,
|
| 25 |
location=self.vertex_location,
|
| 26 |
-
http_options=
|
| 27 |
)
|
|
|
|
| 17 |
str,
|
| 18 |
"The name of the Google model to use for the service."
|
| 19 |
] = "gemini-2.0-flash-001"
|
| 20 |
+
vertex_dedicated: Annotated[
|
| 21 |
+
bool,
|
| 22 |
+
"Whether to use a dedicated Vertex AI instance."
|
| 23 |
+
] = False
|
| 24 |
|
| 25 |
def get_google_client(self, timeout: int):
|
| 26 |
+
http_options = {"timeout": timeout * 1000} # Convert to milliseconds
|
| 27 |
+
if self.vertex_dedicated:
|
| 28 |
+
http_options["headers"] = {"x-vertex-ai-llm-request-type": "dedicated"}
|
| 29 |
return genai.Client(
|
| 30 |
vertexai=True,
|
| 31 |
project=self.vertex_project_id,
|
| 32 |
location=self.vertex_location,
|
| 33 |
+
http_options=http_options,
|
| 34 |
)
|