Vik Paruchuri
committed on
Commit
·
ed65502
1
Parent(s):
e6cc383
Fix ocr converter
Browse files- README.md +21 -0
- marker/builders/ocr.py +4 -2
- marker/converters/ocr.py +44 -0
- marker/output.py +3 -0
- marker/providers/pdf.py +1 -1
- marker/renderers/ocr_json.py +128 -0
- marker/schema/groups/page.py +8 -2
- marker/schema/text/char.py +1 -1
- tests/conftest.py +16 -6
- tests/converters/test_ocr_converter.py +19 -0
README.md
CHANGED
|
@@ -227,6 +227,27 @@ You can also run this via the CLI with
|
|
| 227 |
marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
|
| 228 |
```
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
# Output Formats
|
| 231 |
|
| 232 |
## Markdown
|
|
|
|
| 227 |
marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
|
| 228 |
```
|
| 229 |
|
| 230 |
+
### OCR Only
|
| 231 |
+
|
| 232 |
+
If you only want to run OCR, you can also do that through the `OCRConverter`.
|
| 233 |
+
|
| 234 |
+
```python
|
| 235 |
+
from marker.converters.ocr import OCRConverter
|
| 236 |
+
from marker.models import create_model_dict
|
| 237 |
+
|
| 238 |
+
converter = OCRConverter(
|
| 239 |
+
artifact_dict=create_model_dict(),
|
| 240 |
+
)
|
| 241 |
+
rendered = converter("FILEPATH")
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
This accepts all of the same configuration options as the `PdfConverter`.
|
| 245 |
+
|
| 246 |
+
You can also run this via the CLI with
|
| 247 |
+
```shell
|
| 248 |
+
marker_single FILENAME --converter_cls marker.converters.ocr.OCRConverter
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
# Output Formats
|
| 252 |
|
| 253 |
## Markdown
|
marker/builders/ocr.py
CHANGED
|
@@ -171,10 +171,12 @@ class OcrBuilder(BaseBuilder):
|
|
| 171 |
before_span, after_span = None, None
|
| 172 |
if before_text:
|
| 173 |
before_span = copy.deepcopy(span)
|
|
|
|
| 174 |
before_span.text = before_text
|
| 175 |
if after_text:
|
| 176 |
after_span = copy.deepcopy(span)
|
| 177 |
after_span.text = after_text
|
|
|
|
| 178 |
|
| 179 |
match_span = copy.deepcopy(span)
|
| 180 |
match_span.text = match_text
|
|
@@ -214,7 +216,6 @@ class OcrBuilder(BaseBuilder):
|
|
| 214 |
if not matched:
|
| 215 |
remaining_span = copy.deepcopy(original_span)
|
| 216 |
remaining_span.text = remaining_text
|
| 217 |
-
remaining_span.structure = []
|
| 218 |
final_new_spans.append(remaining_span)
|
| 219 |
break
|
| 220 |
|
|
@@ -287,10 +288,11 @@ class OcrBuilder(BaseBuilder):
|
|
| 287 |
current_span.html = (
|
| 288 |
f'<math display="inline">{current_span.text}</math>'
|
| 289 |
)
|
|
|
|
|
|
|
| 290 |
spans.append(current_span)
|
| 291 |
current_span = None
|
| 292 |
|
| 293 |
-
current_chars = self.assign_chars(current_span, current_chars)
|
| 294 |
continue
|
| 295 |
|
| 296 |
if not current_span:
|
|
|
|
| 171 |
before_span, after_span = None, None
|
| 172 |
if before_text:
|
| 173 |
before_span = copy.deepcopy(span)
|
| 174 |
+
before_span.structure = [] # Avoid duplicate characters
|
| 175 |
before_span.text = before_text
|
| 176 |
if after_text:
|
| 177 |
after_span = copy.deepcopy(span)
|
| 178 |
after_span.text = after_text
|
| 179 |
+
after_span.structure = [] # Avoid duplicate characters
|
| 180 |
|
| 181 |
match_span = copy.deepcopy(span)
|
| 182 |
match_span.text = match_text
|
|
|
|
| 216 |
if not matched:
|
| 217 |
remaining_span = copy.deepcopy(original_span)
|
| 218 |
remaining_span.text = remaining_text
|
|
|
|
| 219 |
final_new_spans.append(remaining_span)
|
| 220 |
break
|
| 221 |
|
|
|
|
| 288 |
current_span.html = (
|
| 289 |
f'<math display="inline">{current_span.text}</math>'
|
| 290 |
)
|
| 291 |
+
|
| 292 |
+
current_chars = self.assign_chars(current_span, current_chars)
|
| 293 |
spans.append(current_span)
|
| 294 |
current_span = None
|
| 295 |
|
|
|
|
| 296 |
continue
|
| 297 |
|
| 298 |
if not current_span:
|
marker/converters/ocr.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
from marker.builders.document import DocumentBuilder
|
| 4 |
+
from marker.builders.line import LineBuilder
|
| 5 |
+
from marker.builders.ocr import OcrBuilder
|
| 6 |
+
from marker.converters.pdf import PdfConverter
|
| 7 |
+
from marker.processors import BaseProcessor
|
| 8 |
+
from marker.processors.equation import EquationProcessor
|
| 9 |
+
from marker.providers.registry import provider_from_filepath
|
| 10 |
+
from marker.renderers.ocr_json import OCRJSONRenderer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OCRConverter(PdfConverter):
    """Converter that builds a document and renders it with `OCRJSONRenderer`.

    Subclasses `PdfConverter`, restricts the default processors to
    `EquationProcessor`, and forces the `format_lines` and `keep_chars`
    config options on.
    """

    default_processors: Tuple[BaseProcessor, ...] = (EquationProcessor,)

    def __init__(self, *args, **kwargs):
        """Initialise the parent converter, then apply the OCR-only settings.

        All positional and keyword arguments are forwarded unchanged to
        `PdfConverter.__init__`.
        """
        super().__init__(*args, **kwargs)

        # A missing/empty config is replaced by a fresh dict so the forced
        # options below always have somewhere to live.
        self.config = self.config if self.config else {}

        for forced_option in ("format_lines", "keep_chars"):
            self.config[forced_option] = True

        # The renderer is always the OCR JSON renderer, whatever the parent
        # constructor selected.
        self.renderer = OCRJSONRenderer

    def build_document(self, filepath: str):
        """Build the document for `filepath` and run every processor on it.

        Args:
            filepath: Path of the input file; also used to pick the provider
                class via `provider_from_filepath`.

        Returns:
            The document produced by `DocumentBuilder` after each entry of
            `self.processor_list` has been called on it.
        """
        provider_type = provider_from_filepath(filepath)
        layout = self.resolve_dependencies(self.layout_builder_class)
        lines = self.resolve_dependencies(LineBuilder)
        ocr = self.resolve_dependencies(OcrBuilder)
        build = DocumentBuilder(self.config)

        source = provider_type(filepath, self.config)
        document = build(source, layout, lines, ocr)

        # NOTE(review): processor_list is presumably populated by the parent
        # constructor; it is not assigned anywhere in this class.
        for run_processor in self.processor_list:
            run_processor(document)

        return document

    def __call__(self, filepath: str):
        """Convert `filepath` and return the renderer's output for it."""
        built = self.build_document(filepath)
        render = self.resolve_dependencies(self.renderer)
        return render(built)
|
marker/output.py
CHANGED
|
@@ -8,6 +8,7 @@ from PIL import Image
|
|
| 8 |
from marker.renderers.html import HTMLOutput
|
| 9 |
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 10 |
from marker.renderers.markdown import MarkdownOutput
|
|
|
|
| 11 |
from marker.schema.blocks import BlockOutput
|
| 12 |
from marker.settings import settings
|
| 13 |
|
|
@@ -57,6 +58,8 @@ def text_from_rendered(rendered: BaseModel):
|
|
| 57 |
return rendered.html, "html", rendered.images
|
| 58 |
elif isinstance(rendered, JSONOutput):
|
| 59 |
return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
|
|
|
|
|
|
|
| 60 |
else:
|
| 61 |
raise ValueError("Invalid output type")
|
| 62 |
|
|
|
|
| 8 |
from marker.renderers.html import HTMLOutput
|
| 9 |
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 10 |
from marker.renderers.markdown import MarkdownOutput
|
| 11 |
+
from marker.renderers.ocr_json import OCRJSONOutput
|
| 12 |
from marker.schema.blocks import BlockOutput
|
| 13 |
from marker.settings import settings
|
| 14 |
|
|
|
|
| 58 |
return rendered.html, "html", rendered.images
|
| 59 |
elif isinstance(rendered, JSONOutput):
|
| 60 |
return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
|
| 61 |
+
elif isinstance(rendered, OCRJSONOutput):
|
| 62 |
+
return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
|
| 63 |
else:
|
| 64 |
raise ValueError("Invalid output type")
|
| 65 |
|
marker/providers/pdf.py
CHANGED
|
@@ -239,7 +239,7 @@ class PdfProvider(BaseProvider):
|
|
| 239 |
)
|
| 240 |
span_chars = [
|
| 241 |
CharClass(
|
| 242 |
-
|
| 243 |
polygon=PolygonBox.from_bbox(
|
| 244 |
c["bbox"], ensure_nonzero_area=True
|
| 245 |
),
|
|
|
|
| 239 |
)
|
| 240 |
span_chars = [
|
| 241 |
CharClass(
|
| 242 |
+
text=c["char"],
|
| 243 |
polygon=PolygonBox.from_bbox(
|
| 244 |
c["bbox"], ensure_nonzero_area=True
|
| 245 |
),
|
marker/renderers/ocr_json.py
CHANGED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated, List, Tuple
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from marker.renderers import BaseRenderer
|
| 6 |
+
from marker.schema import BlockTypes
|
| 7 |
+
from marker.schema.document import Document
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class OCRJSONCharOutput(BaseModel):
|
| 11 |
+
id: str
|
| 12 |
+
block_type: str
|
| 13 |
+
text: str
|
| 14 |
+
polygon: List[List[float]]
|
| 15 |
+
bbox: List[float]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OCRJSONLineOutput(BaseModel):
|
| 19 |
+
id: str
|
| 20 |
+
block_type: str
|
| 21 |
+
html: str
|
| 22 |
+
polygon: List[List[float]]
|
| 23 |
+
bbox: List[float]
|
| 24 |
+
children: List["OCRJSONCharOutput"] | None = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class OCRJSONPageOutput(BaseModel):
|
| 28 |
+
id: str
|
| 29 |
+
block_type: str
|
| 30 |
+
polygon: List[List[float]]
|
| 31 |
+
bbox: List[float]
|
| 32 |
+
children: List[OCRJSONLineOutput] | None = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class OCRJSONOutput(BaseModel):
|
| 36 |
+
children: List[OCRJSONPageOutput]
|
| 37 |
+
block_type: str = str(BlockTypes.Document)
|
| 38 |
+
metadata: dict | None = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class OCRJSONRenderer(BaseRenderer):
|
| 42 |
+
"""
|
| 43 |
+
A renderer for OCR JSON output.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
image_blocks: Annotated[
|
| 47 |
+
Tuple[BlockTypes],
|
| 48 |
+
"The list of block types to consider as images.",
|
| 49 |
+
] = (BlockTypes.Picture, BlockTypes.Figure)
|
| 50 |
+
page_blocks: Annotated[
|
| 51 |
+
Tuple[BlockTypes],
|
| 52 |
+
"The list of block types to consider as pages.",
|
| 53 |
+
] = (BlockTypes.Page,)
|
| 54 |
+
|
| 55 |
+
def extract_json(self, document: Document) -> List[OCRJSONPageOutput]:
|
| 56 |
+
pages = []
|
| 57 |
+
for page in document.pages:
|
| 58 |
+
page_equations = [
|
| 59 |
+
b for b in page.children if b.block_type == BlockTypes.Equation
|
| 60 |
+
]
|
| 61 |
+
equation_lines = []
|
| 62 |
+
for equation in page_equations:
|
| 63 |
+
if not equation.structure:
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
equation_lines += [
|
| 67 |
+
line
|
| 68 |
+
for line in equation.structure
|
| 69 |
+
if line.block_type == BlockTypes.Line
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
page_lines = [
|
| 73 |
+
block
|
| 74 |
+
for block in page.children
|
| 75 |
+
if block.block_type == BlockTypes.Line
|
| 76 |
+
and block.id not in equation_lines
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
lines = []
|
| 80 |
+
for line in page_lines + page_equations:
|
| 81 |
+
line_obj = OCRJSONLineOutput(
|
| 82 |
+
id=str(line.id),
|
| 83 |
+
block_type=str(line.block_type),
|
| 84 |
+
html="",
|
| 85 |
+
polygon=line.polygon.polygon,
|
| 86 |
+
bbox=line.polygon.bbox,
|
| 87 |
+
)
|
| 88 |
+
if line in page_equations:
|
| 89 |
+
line_obj.html = line.html
|
| 90 |
+
else:
|
| 91 |
+
line_obj.html = line.formatted_text(document)
|
| 92 |
+
spans = [document.get_block(span_id) for span_id in line.structure]
|
| 93 |
+
children = []
|
| 94 |
+
for span in spans:
|
| 95 |
+
if not span.structure:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
span_chars = [
|
| 99 |
+
document.get_block(char_id) for char_id in span.structure
|
| 100 |
+
]
|
| 101 |
+
children.extend(
|
| 102 |
+
[
|
| 103 |
+
OCRJSONCharOutput(
|
| 104 |
+
id=str(char.id),
|
| 105 |
+
block_type=str(char.block_type),
|
| 106 |
+
text=char.text,
|
| 107 |
+
polygon=char.polygon.polygon,
|
| 108 |
+
bbox=char.polygon.bbox,
|
| 109 |
+
)
|
| 110 |
+
for char in span_chars
|
| 111 |
+
]
|
| 112 |
+
)
|
| 113 |
+
line_obj.children = children
|
| 114 |
+
lines.append(line_obj)
|
| 115 |
+
|
| 116 |
+
page = OCRJSONPageOutput(
|
| 117 |
+
id=str(page.id),
|
| 118 |
+
block_type=str(page.block_type),
|
| 119 |
+
polygon=page.polygon.polygon,
|
| 120 |
+
bbox=page.polygon.bbox,
|
| 121 |
+
children=lines,
|
| 122 |
+
)
|
| 123 |
+
pages.append(page)
|
| 124 |
+
|
| 125 |
+
return pages
|
| 126 |
+
|
| 127 |
+
def __call__(self, document: Document) -> OCRJSONOutput:
|
| 128 |
+
return OCRJSONOutput(children=self.extract_json(document), metadata=None)
|
marker/schema/groups/page.py
CHANGED
|
@@ -253,14 +253,20 @@ class PageGroup(Group):
|
|
| 253 |
block.add_structure(line)
|
| 254 |
block.polygon = block.polygon.merge([line.polygon])
|
| 255 |
block.text_extraction_method = text_extraction_method
|
| 256 |
-
for span in spans:
|
| 257 |
self.add_full_block(span)
|
| 258 |
line.add_structure(span)
|
| 259 |
|
| 260 |
if not keep_chars:
|
| 261 |
continue
|
| 262 |
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
self.add_full_block(char)
|
| 265 |
span.add_structure(char)
|
| 266 |
|
|
|
|
| 253 |
block.add_structure(line)
|
| 254 |
block.polygon = block.polygon.merge([line.polygon])
|
| 255 |
block.text_extraction_method = text_extraction_method
|
| 256 |
+
for span_idx, span in enumerate(spans):
|
| 257 |
self.add_full_block(span)
|
| 258 |
line.add_structure(span)
|
| 259 |
|
| 260 |
if not keep_chars:
|
| 261 |
continue
|
| 262 |
|
| 263 |
+
# Provider doesn't have chars
|
| 264 |
+
if len(provider_output.chars) == 0:
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
# Loop through characters associated with the span
|
| 268 |
+
for char in provider_output.chars[span_idx]:
|
| 269 |
+
char.page_id = self.page_id
|
| 270 |
self.add_full_block(char)
|
| 271 |
span.add_structure(char)
|
| 272 |
|
marker/schema/text/char.py
CHANGED
|
@@ -6,5 +6,5 @@ class Char(Block):
|
|
| 6 |
block_type: BlockTypes = BlockTypes.Char
|
| 7 |
block_description: str = "A single character inside a span."
|
| 8 |
|
| 9 |
-
|
| 10 |
idx: int
|
|
|
|
| 6 |
block_type: BlockTypes = BlockTypes.Char
|
| 7 |
block_description: str = "A single character inside a span."
|
| 8 |
|
| 9 |
+
text: str
|
| 10 |
idx: int
|
tests/conftest.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from marker.providers.pdf import PdfProvider
|
| 2 |
import tempfile
|
| 3 |
from typing import Dict, Type
|
| 4 |
|
|
@@ -19,7 +18,6 @@ from marker.schema.blocks import Block
|
|
| 19 |
from marker.renderers.markdown import MarkdownRenderer
|
| 20 |
from marker.renderers.json import JSONRenderer
|
| 21 |
from marker.schema.registry import register_block_class
|
| 22 |
-
from marker.services.gemini import GoogleGeminiService
|
| 23 |
from marker.util import classes_to_strings, strings_to_classes
|
| 24 |
|
| 25 |
|
|
@@ -54,6 +52,7 @@ def table_rec_model(model_dict):
|
|
| 54 |
def ocr_error_model(model_dict):
|
| 55 |
yield model_dict["ocr_error_model"]
|
| 56 |
|
|
|
|
| 57 |
@pytest.fixture(scope="function")
|
| 58 |
def config(request):
|
| 59 |
config_mark = request.node.get_closest_marker("config")
|
|
@@ -65,20 +64,22 @@ def config(request):
|
|
| 65 |
|
| 66 |
return config
|
| 67 |
|
|
|
|
| 68 |
@pytest.fixture(scope="session")
|
| 69 |
def pdf_dataset():
|
| 70 |
return datasets.load_dataset("datalab-to/pdfs", split="train")
|
| 71 |
|
|
|
|
| 72 |
@pytest.fixture(scope="function")
|
| 73 |
def temp_doc(request, pdf_dataset):
|
| 74 |
filename_mark = request.node.get_closest_marker("filename")
|
| 75 |
filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"
|
| 76 |
|
| 77 |
-
idx = pdf_dataset[
|
| 78 |
suffix = filename.split(".")[-1]
|
| 79 |
|
| 80 |
temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
|
| 81 |
-
temp_pdf.write(pdf_dataset[
|
| 82 |
temp_pdf.flush()
|
| 83 |
yield temp_pdf
|
| 84 |
|
|
@@ -88,8 +89,17 @@ def doc_provider(request, config, temp_doc):
|
|
| 88 |
provider_cls = provider_from_filepath(temp_doc.name)
|
| 89 |
yield provider_cls(temp_doc.name, config)
|
| 90 |
|
|
|
|
| 91 |
@pytest.fixture(scope="function")
|
| 92 |
-
def pdf_document(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
layout_builder = LayoutBuilder(layout_model, config)
|
| 94 |
line_builder = LineBuilder(detection_model, ocr_error_model, config)
|
| 95 |
ocr_builder = OcrBuilder(recognition_model, config)
|
|
@@ -107,7 +117,7 @@ def pdf_converter(request, config, model_dict, renderer, llm_service):
|
|
| 107 |
processor_list=None,
|
| 108 |
renderer=classes_to_strings([renderer])[0],
|
| 109 |
config=config,
|
| 110 |
-
llm_service=llm_service
|
| 111 |
)
|
| 112 |
|
| 113 |
|
|
|
|
|
|
|
| 1 |
import tempfile
|
| 2 |
from typing import Dict, Type
|
| 3 |
|
|
|
|
| 18 |
from marker.renderers.markdown import MarkdownRenderer
|
| 19 |
from marker.renderers.json import JSONRenderer
|
| 20 |
from marker.schema.registry import register_block_class
|
|
|
|
| 21 |
from marker.util import classes_to_strings, strings_to_classes
|
| 22 |
|
| 23 |
|
|
|
|
| 52 |
def ocr_error_model(model_dict):
|
| 53 |
yield model_dict["ocr_error_model"]
|
| 54 |
|
| 55 |
+
|
| 56 |
@pytest.fixture(scope="function")
|
| 57 |
def config(request):
|
| 58 |
config_mark = request.node.get_closest_marker("config")
|
|
|
|
| 64 |
|
| 65 |
return config
|
| 66 |
|
| 67 |
+
|
| 68 |
@pytest.fixture(scope="session")
|
| 69 |
def pdf_dataset():
|
| 70 |
return datasets.load_dataset("datalab-to/pdfs", split="train")
|
| 71 |
|
| 72 |
+
|
| 73 |
@pytest.fixture(scope="function")
|
| 74 |
def temp_doc(request, pdf_dataset):
|
| 75 |
filename_mark = request.node.get_closest_marker("filename")
|
| 76 |
filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"
|
| 77 |
|
| 78 |
+
idx = pdf_dataset["filename"].index(filename)
|
| 79 |
suffix = filename.split(".")[-1]
|
| 80 |
|
| 81 |
temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
|
| 82 |
+
temp_pdf.write(pdf_dataset["pdf"][idx])
|
| 83 |
temp_pdf.flush()
|
| 84 |
yield temp_pdf
|
| 85 |
|
|
|
|
| 89 |
provider_cls = provider_from_filepath(temp_doc.name)
|
| 90 |
yield provider_cls(temp_doc.name, config)
|
| 91 |
|
| 92 |
+
|
| 93 |
@pytest.fixture(scope="function")
|
| 94 |
+
def pdf_document(
|
| 95 |
+
request,
|
| 96 |
+
config,
|
| 97 |
+
doc_provider,
|
| 98 |
+
layout_model,
|
| 99 |
+
ocr_error_model,
|
| 100 |
+
recognition_model,
|
| 101 |
+
detection_model,
|
| 102 |
+
):
|
| 103 |
layout_builder = LayoutBuilder(layout_model, config)
|
| 104 |
line_builder = LineBuilder(detection_model, ocr_error_model, config)
|
| 105 |
ocr_builder = OcrBuilder(recognition_model, config)
|
|
|
|
| 117 |
processor_list=None,
|
| 118 |
renderer=classes_to_strings([renderer])[0],
|
| 119 |
config=config,
|
| 120 |
+
llm_service=llm_service,
|
| 121 |
)
|
| 122 |
|
| 123 |
|
tests/converters/test_ocr_converter.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from marker.converters.ocr import OCRConverter
|
| 4 |
+
from marker.renderers.ocr_json import OCRJSONOutput
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _ocr_converter(config, model_dict, temp_pdf):
    """Run `OCRConverter` on a temporary document and check the page count.

    Args:
        config: Configuration passed to the converter as `config`.
        model_dict: Model artifacts passed to the converter as
            `artifact_dict`.
        temp_pdf: File-like object whose `.name` is the path to convert.

    Raises:
        AssertionError: If the rendered output does not contain exactly one
            page.
    """
    converter = OCRConverter(artifact_dict=model_dict, config=config)

    ocr_json: OCRJSONOutput = converter(temp_pdf.name)
    # `OCRJSONOutput` exposes its pages through `children`; it has no `pages`
    # field, so reading `.pages` would raise before the assertion is reached.
    pages = ocr_json.children

    # No `breakpoint()` here: a debugger stop would halt non-interactive runs.
    assert len(pages) == 1
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.mark.config({"page_range": [0]})
|
| 18 |
+
def test_ocr_converter(config, model_dict, temp_doc):
|
| 19 |
+
_ocr_converter(config, model_dict, temp_doc)
|