Vik Paruchuri
committed on
Commit
·
e332feb
1
Parent(s):
0f59b51
Merge plus form processor
Browse files- marker/config/parser.py +17 -23
- marker/converters/pdf.py +5 -1
- marker/llm.py +55 -0
- marker/processors/llm/__init__.py +0 -0
- marker/processors/llm/highqualityformprocessor.py +151 -0
- marker/processors/llm/highqualitytableprocessor.py +188 -0
- marker/schema/blocks/form.py +5 -0
- marker/settings.py +3 -0
- pyproject.toml +1 -0
marker/config/parser.py
CHANGED
|
@@ -34,45 +34,39 @@ class ConfigParser:
|
|
| 34 |
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
|
| 35 |
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
|
| 36 |
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
|
| 37 |
-
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with
|
| 38 |
return fn
|
| 39 |
|
| 40 |
def generate_config_dict(self) -> Dict[str, any]:
|
| 41 |
config = {}
|
| 42 |
output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
|
| 43 |
for k, v in self.cli_options.items():
|
|
|
|
|
|
|
|
|
|
| 44 |
match k:
|
| 45 |
case "debug":
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
config["debug_data_folder"] = output_dir
|
| 51 |
case "page_range":
|
| 52 |
-
|
| 53 |
-
config["page_range"] = parse_range_str(v)
|
| 54 |
case "force_ocr":
|
| 55 |
-
|
| 56 |
-
config["force_ocr"] = True
|
| 57 |
case "languages":
|
| 58 |
-
|
| 59 |
-
config["languages"] = v.split(",")
|
| 60 |
case "config_json":
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
config.update(json.load(f))
|
| 64 |
case "disable_multiprocessing":
|
| 65 |
-
|
| 66 |
-
config["pdftext_workers"] = 1
|
| 67 |
case "paginate_output":
|
| 68 |
-
|
| 69 |
-
config["paginate_output"] = True
|
| 70 |
case "disable_image_extraction":
|
| 71 |
-
|
| 72 |
-
config["extract_images"] = False
|
| 73 |
case "high_quality":
|
| 74 |
-
|
| 75 |
-
config["high_quality"] = True
|
| 76 |
return config
|
| 77 |
|
| 78 |
def get_renderer(self):
|
|
|
|
| 34 |
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
|
| 35 |
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
|
| 36 |
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
|
| 37 |
+
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with LLMs.")(fn)
|
| 38 |
return fn
|
| 39 |
|
| 40 |
def generate_config_dict(self) -> Dict[str, any]:
|
| 41 |
config = {}
|
| 42 |
output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
|
| 43 |
for k, v in self.cli_options.items():
|
| 44 |
+
if not v:
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
match k:
|
| 48 |
case "debug":
|
| 49 |
+
config["debug_pdf_images"] = True
|
| 50 |
+
config["debug_layout_images"] = True
|
| 51 |
+
config["debug_json"] = True
|
| 52 |
+
config["debug_data_folder"] = output_dir
|
|
|
|
| 53 |
case "page_range":
|
| 54 |
+
config["page_range"] = parse_range_str(v)
|
|
|
|
| 55 |
case "force_ocr":
|
| 56 |
+
config["force_ocr"] = True
|
|
|
|
| 57 |
case "languages":
|
| 58 |
+
config["languages"] = v.split(",")
|
|
|
|
| 59 |
case "config_json":
|
| 60 |
+
with open(v, "r") as f:
|
| 61 |
+
config.update(json.load(f))
|
|
|
|
| 62 |
case "disable_multiprocessing":
|
| 63 |
+
config["pdftext_workers"] = 1
|
|
|
|
| 64 |
case "paginate_output":
|
| 65 |
+
config["paginate_output"] = True
|
|
|
|
| 66 |
case "disable_image_extraction":
|
| 67 |
+
config["extract_images"] = False
|
|
|
|
| 68 |
case "high_quality":
|
| 69 |
+
config["high_quality"] = True
|
|
|
|
| 70 |
return config
|
| 71 |
|
| 72 |
def get_renderer(self):
|
marker/converters/pdf.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
|
| 4 |
import inspect
|
| 5 |
from collections import defaultdict
|
|
@@ -17,6 +17,8 @@ from marker.processors.debug import DebugProcessor
|
|
| 17 |
from marker.processors.document_toc import DocumentTOCProcessor
|
| 18 |
from marker.processors.equation import EquationProcessor
|
| 19 |
from marker.processors.footnote import FootnoteProcessor
|
|
|
|
|
|
|
| 20 |
from marker.processors.high_quality_text import HighQualityTextProcessor
|
| 21 |
from marker.processors.ignoretext import IgnoreTextProcessor
|
| 22 |
from marker.processors.line_numbers import LineNumbersProcessor
|
|
@@ -68,6 +70,8 @@ class PdfConverter(BaseConverter):
|
|
| 68 |
PageHeaderProcessor,
|
| 69 |
SectionHeaderProcessor,
|
| 70 |
TableProcessor,
|
|
|
|
|
|
|
| 71 |
TextProcessor,
|
| 72 |
HighQualityTextProcessor,
|
| 73 |
DebugProcessor,
|
|
|
|
| 1 |
import os
|
| 2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
|
| 3 |
|
| 4 |
import inspect
|
| 5 |
from collections import defaultdict
|
|
|
|
| 17 |
from marker.processors.document_toc import DocumentTOCProcessor
|
| 18 |
from marker.processors.equation import EquationProcessor
|
| 19 |
from marker.processors.footnote import FootnoteProcessor
|
| 20 |
+
from marker.processors.llm.highqualityformprocessor import HighQualityFormProcessor
|
| 21 |
+
from marker.processors.llm.highqualitytableprocessor import HighQualityTableProcessor
|
| 22 |
from marker.processors.high_quality_text import HighQualityTextProcessor
|
| 23 |
from marker.processors.ignoretext import IgnoreTextProcessor
|
| 24 |
from marker.processors.line_numbers import LineNumbersProcessor
|
|
|
|
| 70 |
PageHeaderProcessor,
|
| 71 |
SectionHeaderProcessor,
|
| 72 |
TableProcessor,
|
| 73 |
+
HighQualityTableProcessor,
|
| 74 |
+
HighQualityFormProcessor,
|
| 75 |
TextProcessor,
|
| 76 |
HighQualityTextProcessor,
|
| 77 |
DebugProcessor,
|
marker/llm.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import PIL
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
from google.ai.generativelanguage_v1beta.types import content
|
| 7 |
+
from google.api_core.exceptions import ResourceExhausted
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GoogleModel:
    """Thin wrapper around a Google Gemini generative model.

    Configures the SDK with the given API key and exposes a retrying,
    JSON-schema-constrained generation call.
    """

    def __init__(self, api_key: str, model_name: str):
        # Fail fast: the SDK would otherwise error much later, mid-run.
        if api_key is None:
            raise ValueError("Google API key is not set")

        self.api_key = api_key
        self.model_name = model_name
        self.model = self.configure_google_model()

    def configure_google_model(self):
        """Configure the genai SDK with our key and build the model handle."""
        genai.configure(api_key=self.api_key)
        return genai.GenerativeModel(self.model_name)

    def generate_response(
            self,
            prompt: str,
            image: PIL.Image.Image,
            response_schema: content.Schema,
            max_retries: int = 3,
            timeout: int = 60
    ):
        """Send prompt + image to Gemini and parse the JSON response.

        Retries with linear backoff (3s, 6s, ...) on rate limiting
        (ResourceExhausted). Any other failure is printed and treated as
        best-effort.

        Returns:
            The parsed JSON response as a dict, or {} on failure.
        """
        tries = 0
        while tries < max_retries:
            try:
                responses = self.model.generate_content(
                    [prompt, image],
                    stream=False,
                    generation_config={
                        "temperature": 0,
                        "response_schema": response_schema,
                        "response_mime_type": "application/json",
                    },
                    request_options={'timeout': timeout}
                )
                output = responses.candidates[0].content.parts[0].text
                return json.loads(output)
            except ResourceExhausted as e:
                tries += 1
                if tries >= max_retries:
                    # Out of retries - don't sleep before giving up.
                    print(f"ResourceExhausted: {e}. Giving up after {max_retries} attempts.")
                    break
                wait_time = tries * 3
                print(f"ResourceExhausted: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{max_retries})")
                time.sleep(wait_time)
            except Exception as e:
                # Best-effort: any other error falls through to the empty result.
                print(e)
                break

        return {}
|
marker/processors/llm/__init__.py
ADDED
|
File without changes
|
marker/processors/llm/highqualityformprocessor.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import markdown2
|
| 2 |
+
|
| 3 |
+
from marker.llm import GoogleModel
|
| 4 |
+
from marker.processors import BaseProcessor
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from google.ai.generativelanguage_v1beta.types import content
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from tabled.formats import markdown_format
|
| 11 |
+
|
| 12 |
+
from marker.schema import BlockTypes
|
| 13 |
+
from marker.schema.blocks import Block
|
| 14 |
+
from marker.schema.document import Document
|
| 15 |
+
from marker.schema.groups.page import PageGroup
|
| 16 |
+
from marker.settings import settings
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HighQualityFormProcessor(BaseProcessor):
    """
    A processor for converting form blocks in a document to markdown.
    Attributes:
        google_api_key (str):
            The Google API key to use for the Gemini model.
            Default is None.
        model_name (str):
            The name of the Gemini model to use.
            Default is "gemini-1.5-flash".
        max_retries (int):
            The maximum number of retries to use for the Gemini model.
            Default is 3.
        max_concurrency (int):
            The maximum number of concurrent requests to make to the Gemini model.
            Default is 3.
        timeout (int):
            The timeout for requests to the Gemini model.
        gemini_rewriting_prompt (str):
            The prompt to use for rewriting text.
            Default is a string containing the Gemini rewriting prompt.
    """

    block_types = (BlockTypes.Form,)
    google_api_key: Optional[str] = settings.GOOGLE_API_KEY
    model_name: str = "gemini-1.5-flash"
    high_quality: bool = False
    max_retries: int = 3
    max_concurrency: int = 3
    timeout: int = 60

    gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a markdown representation of the form in the image.
Your task is to correct any errors in the markdown representation, and format it properly.
Values and labels should appear in markdown tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables.
**Instructions:**
1. Carefully examine the provided form block image.
2. Analyze the markdown representation of the form.
3. If the markdown representation is largely correct, then write "No corrections needed."
4. If the markdown representation contains errors, generate the corrected markdown representation.
5. Output only either the corrected markdown representation or "No corrections needed."
**Example:**
Input:
```markdown
| Label 1 | Label 2 | Label 3 |
|----------|----------|----------|
| Value 1 | Value 2 | Value 3 |
```
Output:
```markdown
| Labels | Values |
|--------|--------|
| Label 1 | Value 1 |
| Label 2 | Value 2 |
| Label 3 | Value 3 |
```
**Input:**
"""

    def __init__(self, config=None):
        super().__init__(config)

        self.model = None
        if not self.high_quality:
            return

        self.model = GoogleModel(self.google_api_key, self.model_name)

    def __call__(self, document: Document):
        # No-op unless high-quality mode is enabled and the model was built.
        if not self.high_quality or self.model is None:
            return

        self.rewrite_blocks(document)

    def rewrite_blocks(self, document: Document):
        """Rewrite all form blocks concurrently via the Gemini model."""
        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
            # Collect jobs first so the progress bar has an accurate total.
            futures = [
                executor.submit(self.process_rewriting, page, block)
                for page in document.pages
                for block in page.contained_blocks(document, self.block_types)
            ]
            pbar = tqdm(total=len(futures), desc="High quality form processor")
            try:
                for future in as_completed(futures):
                    future.result()  # Raise exceptions if any occurred
                    pbar.update(1)
            finally:
                # Close the bar even if a worker raised.
                pbar.close()

    def process_rewriting(self, page: PageGroup, block: Block):
        """Ask Gemini to correct the markdown for one form block.

        On success, stores the corrected markup as HTML on the block.
        """
        cells = block.cells
        if cells is None:
            # Happens if table/form processors didn't run
            return

        prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
        image = self.extract_image(page, block)
        response_schema = content.Schema(
            type=content.Type.OBJECT,
            enum=[],
            required=["corrected_markdown"],
            properties={
                "corrected_markdown": content.Schema(
                    type=content.Type.STRING
                )
            },
        )

        response = self.model.generate_response(prompt, image, response_schema)

        if not response or "corrected_markdown" not in response:
            return

        corrected_markdown = response["corrected_markdown"]

        # The original table is okay
        if "no corrections" in corrected_markdown.lower():
            return

        orig_cell_text = "".join([cell.text for cell in cells])

        # Potentially a partial response
        if len(corrected_markdown) < len(orig_cell_text) * .5:
            return

        # Convert LLM markdown to html
        block.html = markdown2.markdown(corrected_markdown)

    def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
        """Crop the block's region (slightly expanded) from the page image."""
        page_img = page.lowres_image
        image_box = image_block.polygon\
            .rescale(page.polygon.size, page_img.size)\
            .expand(expand, expand)
        cropped = page_img.crop(image_box.bbox)
        return cropped
|
marker/processors/llm/highqualitytableprocessor.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tabled.schema import SpanTableCell
|
| 2 |
+
|
| 3 |
+
from marker.llm import GoogleModel
|
| 4 |
+
from marker.processors import BaseProcessor
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
from google.ai.generativelanguage_v1beta.types import content
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from tabled.formats import markdown_format
|
| 11 |
+
|
| 12 |
+
from marker.schema import BlockTypes
|
| 13 |
+
from marker.schema.blocks import Block
|
| 14 |
+
from marker.schema.document import Document
|
| 15 |
+
from marker.schema.groups.page import PageGroup
|
| 16 |
+
from marker.schema.polygon import PolygonBox
|
| 17 |
+
from marker.settings import settings
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class HighQualityTableProcessor(BaseProcessor):
    """
    A processor for converting table blocks in a document to markdown.
    Attributes:
        google_api_key (str):
            The Google API key to use for the Gemini model.
            Default is None.
        model_name (str):
            The name of the Gemini model to use.
            Default is "gemini-1.5-flash".
        max_retries (int):
            The maximum number of retries to use for the Gemini model.
            Default is 3.
        max_concurrency (int):
            The maximum number of concurrent requests to make to the Gemini model.
            Default is 3.
        timeout (int):
            The timeout for requests to the Gemini model.
        gemini_rewriting_prompt (str):
            The prompt to use for rewriting text.
            Default is a string containing the Gemini rewriting prompt.
    """

    block_types = (BlockTypes.Table,)
    google_api_key: Optional[str] = settings.GOOGLE_API_KEY
    model_name: str = "gemini-1.5-flash"
    high_quality: bool = False
    max_retries: int = 3
    max_concurrency: int = 3
    timeout: int = 60

    gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a markdown representation of the table in the image.
Your task is to correct any errors in the markdown representation. The markdown representation should be as faithful to the original table as possible.
**Instructions:**
1. Carefully examine the provided text block image.
2. Analyze the markdown representation of the table.
3. If the markdown representation is largely correct, then write "No corrections needed."
4. If the markdown representation contains errors, generate the corrected markdown representation.
5. Output only either the corrected markdown representation or "No corrections needed."
**Example:**
Input:
```markdown
| Column 1 | Column 2 | Column 3 |
|----------|----------|----------|
| Value 1 | Value 2 | Value 3 |
```
Output:
```markdown
No corrections needed.
```
**Input:**
"""

    def __init__(self, config=None):
        super().__init__(config)

        self.model = None
        if not self.high_quality:
            return

        self.model = GoogleModel(self.google_api_key, self.model_name)

    def __call__(self, document: Document):
        # No-op unless high-quality mode is enabled and the model was built.
        if not self.high_quality or self.model is None:
            return

        self.rewrite_blocks(document)

    def rewrite_blocks(self, document: Document):
        """Rewrite all table blocks concurrently via the Gemini model."""
        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
            # Collect jobs first so the progress bar has an accurate total.
            futures = [
                executor.submit(self.process_rewriting, page, block)
                for page in document.pages
                for block in page.contained_blocks(document, self.block_types)
            ]
            pbar = tqdm(total=len(futures), desc="High quality table processor")
            try:
                for future in as_completed(futures):
                    future.result()  # Raise exceptions if any occurred
                    pbar.update(1)
            finally:
                # Close the bar even if a worker raised.
                pbar.close()

    def process_rewriting(self, page: PageGroup, block: Block):
        """Ask Gemini to correct the markdown for one table block.

        On success, replaces the block's cells with the parsed correction.
        """
        cells = block.cells
        if cells is None:
            # Happens if table/form processors didn't run
            return

        prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
        image = self.extract_image(page, block)
        response_schema = content.Schema(
            type=content.Type.OBJECT,
            enum=[],
            required=["corrected_markdown"],
            properties={
                "corrected_markdown": content.Schema(
                    type=content.Type.STRING
                )
            },
        )

        response = self.model.generate_response(prompt, image, response_schema)

        if not response or "corrected_markdown" not in response:
            return

        corrected_markdown = response["corrected_markdown"]

        # The original table is okay
        if "no corrections" in corrected_markdown.lower():
            return

        parsed_cells = self.parse_markdown_table(corrected_markdown, block)
        if len(parsed_cells) <= 1:
            return

        parsed_cell_text = "".join([cell.text for cell in parsed_cells])
        orig_cell_text = "".join([cell.text for cell in cells])

        # Potentially a partial response
        if len(parsed_cell_text) < len(orig_cell_text) * .5:
            return

        block.cells = parsed_cells

    def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
        """Crop the block's region (slightly expanded) from the page image."""
        page_img = page.lowres_image
        image_box = image_block.polygon\
            .rescale(page.polygon.size, page_img.size)\
            .expand(expand, expand)
        cropped = page_img.crop(image_box.bbox)
        return cropped

    def parse_markdown_table(self, markdown_text: str, block: Block) -> List[SpanTableCell]:
        """Parse a markdown table into SpanTableCell objects.

        Cell bounding boxes are synthetic 1x1 boxes offset from the block's
        top-left corner by (col, row) - they only preserve ordering.
        """
        lines = [line.strip() for line in markdown_text.splitlines() if line.strip()]

        # Remove header separator rows. Also handles alignment colons
        # (e.g. |:---:|), which the naive pipe/dash check would miss.
        lines = [line for line in lines if not set(line) <= set('|-: ')]

        rows = []
        for line in lines:
            # Remove leading/trailing pipes and split by remaining pipes
            cells = line.strip('|').split('|')
            # Clean whitespace from each cell
            cells = [cell.strip() for cell in cells]
            rows.append(cells)

        cells = []
        for i, row in enumerate(rows):
            for j, cell in enumerate(row):
                cell_bbox = [
                    block.polygon.bbox[0] + j,
                    block.polygon.bbox[1] + i,
                    block.polygon.bbox[0] + j + 1,
                    block.polygon.bbox[1] + i + 1
                ]
                cell_polygon = PolygonBox.from_bbox(cell_bbox)
                cells.append(
                    SpanTableCell(
                        text=cell,
                        row_ids=[i],
                        col_ids=[j],
                        bbox=cell_polygon.bbox
                    )
                )

        return cells
|
marker/schema/blocks/form.py
CHANGED
|
@@ -10,6 +10,11 @@ from marker.schema.blocks import Block
|
|
| 10 |
class Form(Block):
|
| 11 |
block_type: str = BlockTypes.Form
|
| 12 |
cells: List[SpanTableCell] | None = None
|
|
|
|
| 13 |
|
| 14 |
def assemble_html(self, child_blocks, parent_structure=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
return str(html_format(self.cells))
|
|
|
|
| 10 |
class Form(Block):
    block_type: str = BlockTypes.Form
    cells: List[SpanTableCell] | None = None
    # Set when an upstream processor has already rendered the form as HTML.
    html: str | None = None

    def assemble_html(self, child_blocks, parent_structure=None):
        """Render the form as HTML, preferring pre-rendered markup."""
        # Some processors convert the form to html
        if self.html is None:
            return str(html_format(self.cells))
        return self.html
|
marker/settings.py
CHANGED
|
@@ -18,6 +18,9 @@ class Settings(BaseSettings):
|
|
| 18 |
OUTPUT_ENCODING: str = "utf-8"
|
| 19 |
OUTPUT_IMAGE_FORMAT: str = "JPEG"
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
# General models
|
| 22 |
TORCH_DEVICE: Optional[str] = None # Note: MPS device does not work for text detection, and will default to CPU
|
| 23 |
GOOGLE_API_KEY: Optional[str] = None
|
|
|
|
| 18 |
OUTPUT_ENCODING: str = "utf-8"
|
| 19 |
OUTPUT_IMAGE_FORMAT: str = "JPEG"
|
| 20 |
|
| 21 |
+
# LLM
|
| 22 |
+
GOOGLE_API_KEY: Optional[str] = None
|
| 23 |
+
|
| 24 |
# General models
|
| 25 |
TORCH_DEVICE: Optional[str] = None # Note: MPS device does not work for text detection, and will default to CPU
|
| 26 |
GOOGLE_API_KEY: Optional[str] = None
|
pyproject.toml
CHANGED
|
@@ -40,6 +40,7 @@ tabled-pdf = "~0.2.0"
|
|
| 40 |
markdownify = "^0.13.1"
|
| 41 |
click = "^8.1.7"
|
| 42 |
google-generativeai = "^0.8.3"
|
|
|
|
| 43 |
|
| 44 |
[tool.poetry.group.dev.dependencies]
|
| 45 |
jupyter = "^1.0.0"
|
|
|
|
| 40 |
markdownify = "^0.13.1"
|
| 41 |
click = "^8.1.7"
|
| 42 |
google-generativeai = "^0.8.3"
|
| 43 |
+
markdown2 = "^2.5.2"
|
| 44 |
|
| 45 |
[tool.poetry.group.dev.dependencies]
|
| 46 |
jupyter = "^1.0.0"
|