Merge remote-tracking branch 'origin/dev' into highquality-processors
Browse files- marker/builders/layout.py +40 -4
- marker/models.py +11 -0
- marker/renderers/markdown.py +9 -0
- poetry.lock +22 -22
- tests/builders/test_blank_page.py +2 -2
- tests/builders/test_garbled_pdf.py +29 -2
- tests/conftest.py +11 -5
marker/builders/layout.py
CHANGED
|
@@ -5,6 +5,10 @@ from surya.layout import batch_layout_detection
|
|
| 5 |
from surya.schema import LayoutResult
|
| 6 |
from surya.model.layout.encoderdecoder import SuryaLayoutModel
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from marker.settings import settings
|
| 9 |
from marker.builders import BaseBuilder
|
| 10 |
from marker.providers import ProviderOutput, ProviderPageLines
|
|
@@ -37,15 +41,21 @@ class LayoutBuilder(BaseBuilder):
|
|
| 37 |
document_ocr_threshold (float):
|
| 38 |
The minimum ratio of pages that must pass the layout coverage check
|
| 39 |
to avoid OCR. Default is 0.8.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
"""
|
| 41 |
batch_size = None
|
| 42 |
layout_coverage_min_lines = 1
|
| 43 |
layout_coverage_threshold = .1
|
| 44 |
document_ocr_threshold = .8
|
|
|
|
| 45 |
excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
|
| 46 |
|
| 47 |
-
def __init__(self, layout_model: SuryaLayoutModel, config=None):
|
| 48 |
self.layout_model = layout_model
|
|
|
|
| 49 |
|
| 50 |
super().__init__(config)
|
| 51 |
|
|
@@ -71,6 +81,31 @@ class LayoutBuilder(BaseBuilder):
|
|
| 71 |
)
|
| 72 |
return layout_results
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
|
| 75 |
for page, layout_result in zip(pages, layout_results):
|
| 76 |
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
|
|
@@ -92,16 +127,17 @@ class LayoutBuilder(BaseBuilder):
|
|
| 92 |
page.children = []
|
| 93 |
|
| 94 |
def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
|
|
|
|
|
|
|
| 95 |
good_pages = []
|
| 96 |
-
for document_page in document_pages:
|
| 97 |
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 98 |
-
good_pages.append(self.check_layout_coverage(document_page, provider_lines))
|
| 99 |
|
| 100 |
ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
|
| 101 |
for idx, document_page in enumerate(document_pages):
|
| 102 |
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 103 |
needs_ocr = not good_pages[idx]
|
| 104 |
-
|
| 105 |
if needs_ocr and ocr_document:
|
| 106 |
document_page.text_extraction_method = "surya"
|
| 107 |
continue
|
|
|
|
| 5 |
from surya.schema import LayoutResult
|
| 6 |
from surya.model.layout.encoderdecoder import SuryaLayoutModel
|
| 7 |
|
| 8 |
+
from surya.ocr_error import batch_ocr_error_detection
|
| 9 |
+
from surya.schema import OCRErrorDetectionResult
|
| 10 |
+
from surya.model.ocr_error.model import DistilBertForSequenceClassification
|
| 11 |
+
|
| 12 |
from marker.settings import settings
|
| 13 |
from marker.builders import BaseBuilder
|
| 14 |
from marker.providers import ProviderOutput, ProviderPageLines
|
|
|
|
| 41 |
document_ocr_threshold (float):
|
| 42 |
The minimum ratio of pages that must pass the layout coverage check
|
| 43 |
to avoid OCR. Default is 0.8.
|
| 44 |
+
|
| 45 |
+
error_model_segment_length (int):
|
| 46 |
+
The maximum number of characters to send to the OCR error model.
|
| 47 |
+
Default is 1024.
|
| 48 |
"""
|
| 49 |
batch_size = None
|
| 50 |
layout_coverage_min_lines = 1
|
| 51 |
layout_coverage_threshold = .1
|
| 52 |
document_ocr_threshold = .8
|
| 53 |
+
error_model_segment_length = 512
|
| 54 |
excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
|
| 55 |
|
| 56 |
+
def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
|
| 57 |
self.layout_model = layout_model
|
| 58 |
+
self.ocr_error_model = ocr_error_model
|
| 59 |
|
| 60 |
super().__init__(config)
|
| 61 |
|
|
|
|
| 81 |
)
|
| 82 |
return layout_results
|
| 83 |
|
| 84 |
+
def surya_ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines) -> OCRErrorDetectionResult:
|
| 85 |
+
page_texts = []
|
| 86 |
+
for document_page in pages:
|
| 87 |
+
page_text = ''
|
| 88 |
+
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 89 |
+
for line in provider_lines:
|
| 90 |
+
page_text += ' '.join([s.text for s in line.spans])
|
| 91 |
+
|
| 92 |
+
# Sample text from the middle
|
| 93 |
+
if len(page_text) > 0:
|
| 94 |
+
page_text_middle = len(page_text) // 2
|
| 95 |
+
page_text_start = max(0, page_text_middle - self.error_model_segment_length // 2)
|
| 96 |
+
page_text_end = page_text_start + self.error_model_segment_length
|
| 97 |
+
page_text = page_text[page_text_start:page_text_end]
|
| 98 |
+
|
| 99 |
+
page_texts.append(page_text)
|
| 100 |
+
|
| 101 |
+
ocr_error_detection_results = batch_ocr_error_detection(
|
| 102 |
+
page_texts,
|
| 103 |
+
self.ocr_error_model,
|
| 104 |
+
self.ocr_error_model.tokenizer,
|
| 105 |
+
batch_size=int(self.get_batch_size()) #TODO Better Multiplier
|
| 106 |
+
)
|
| 107 |
+
return ocr_error_detection_results
|
| 108 |
+
|
| 109 |
def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
|
| 110 |
for page, layout_result in zip(pages, layout_results):
|
| 111 |
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
|
|
|
|
| 127 |
page.children = []
|
| 128 |
|
| 129 |
def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
|
| 130 |
+
ocr_error_detection_labels = self.surya_ocr_error_detection(document_pages, provider_page_lines).labels
|
| 131 |
+
|
| 132 |
good_pages = []
|
| 133 |
+
for (document_page, ocr_error_detection_label) in zip(document_pages, ocr_error_detection_labels):
|
| 134 |
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 135 |
+
good_pages.append(self.check_layout_coverage(document_page, provider_lines) and (ocr_error_detection_label != "bad"))
|
| 136 |
|
| 137 |
ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
|
| 138 |
for idx, document_page in enumerate(document_pages):
|
| 139 |
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 140 |
needs_ocr = not good_pages[idx]
|
|
|
|
| 141 |
if needs_ocr and ocr_document:
|
| 142 |
document_page.text_extraction_method = "surya"
|
| 143 |
continue
|
marker/models.py
CHANGED
|
@@ -12,12 +12,15 @@ from surya.model.recognition.model import load_model as load_recognition_model
|
|
| 12 |
from surya.model.recognition.processor import load_processor as load_recognition_processor
|
| 13 |
from surya.model.table_rec.model import load_model as load_table_model
|
| 14 |
from surya.model.table_rec.processor import load_processor as load_table_processor
|
|
|
|
|
|
|
| 15 |
|
| 16 |
from texify.model.model import GenerateVisionEncoderDecoderModel
|
| 17 |
from surya.model.layout.encoderdecoder import SuryaLayoutModel
|
| 18 |
from surya.model.detection.model import EfficientViTForSemanticSegmentation
|
| 19 |
from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
|
| 20 |
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
|
|
@@ -64,6 +67,13 @@ def setup_layout_model(device=None, dtype=None) -> SuryaLayoutModel:
|
|
| 64 |
model.processor = load_layout_processor()
|
| 65 |
return model
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def create_model_dict(device=None, dtype=None) -> dict:
|
| 69 |
return {
|
|
@@ -72,4 +82,5 @@ def create_model_dict(device=None, dtype=None) -> dict:
|
|
| 72 |
"recognition_model": setup_recognition_model(device, dtype),
|
| 73 |
"table_rec_model": setup_table_rec_model(device, dtype),
|
| 74 |
"detection_model": setup_detection_model(device, dtype),
|
|
|
|
| 75 |
}
|
|
|
|
| 12 |
from surya.model.recognition.processor import load_processor as load_recognition_processor
|
| 13 |
from surya.model.table_rec.model import load_model as load_table_model
|
| 14 |
from surya.model.table_rec.processor import load_processor as load_table_processor
|
| 15 |
+
from surya.model.ocr_error.model import load_model as load_ocr_error_model
|
| 16 |
+
from surya.model.ocr_error.model import load_tokenizer as load_ocr_error_tokenizer
|
| 17 |
|
| 18 |
from texify.model.model import GenerateVisionEncoderDecoderModel
|
| 19 |
from surya.model.layout.encoderdecoder import SuryaLayoutModel
|
| 20 |
from surya.model.detection.model import EfficientViTForSemanticSegmentation
|
| 21 |
from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
|
| 22 |
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
|
| 23 |
+
from surya.model.ocr_error.model import DistilBertForSequenceClassification
|
| 24 |
|
| 25 |
|
| 26 |
def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
|
|
|
|
| 67 |
model.processor = load_layout_processor()
|
| 68 |
return model
|
| 69 |
|
| 70 |
+
def setup_ocr_error_model(device=None, dtype=None) -> DistilBertForSequenceClassification:
|
| 71 |
+
if device:
|
| 72 |
+
model = load_ocr_error_model(device=device, dtype=dtype)
|
| 73 |
+
else:
|
| 74 |
+
model = load_ocr_error_model()
|
| 75 |
+
model.tokenizer = load_ocr_error_tokenizer()
|
| 76 |
+
return model
|
| 77 |
|
| 78 |
def create_model_dict(device=None, dtype=None) -> dict:
|
| 79 |
return {
|
|
|
|
| 82 |
"recognition_model": setup_recognition_model(device, dtype),
|
| 83 |
"table_rec_model": setup_table_rec_model(device, dtype),
|
| 84 |
"detection_model": setup_detection_model(device, dtype),
|
| 85 |
+
"ocr_error_model": setup_ocr_error_model(device,dtype)
|
| 86 |
}
|
marker/renderers/markdown.py
CHANGED
|
@@ -53,6 +53,15 @@ class Markdownify(MarkdownConverter):
|
|
| 53 |
else:
|
| 54 |
return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n\n"
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
class MarkdownOutput(BaseModel):
|
| 58 |
markdown: str
|
|
|
|
| 53 |
else:
|
| 54 |
return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n\n"
|
| 55 |
|
| 56 |
+
def convert_td(self, el, text, convert_as_inline):
|
| 57 |
+
text = text.replace("|", " ").replace("\n", " ")
|
| 58 |
+
return super().convert_td(el, text, convert_as_inline)
|
| 59 |
+
|
| 60 |
+
def convert_th(self, el, text, convert_as_inline):
|
| 61 |
+
text = text.replace("|", " ").replace("\n", " ")
|
| 62 |
+
return super().convert_th(el, text, convert_as_inline)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
|
| 66 |
class MarkdownOutput(BaseModel):
|
| 67 |
markdown: str
|
poetry.lock
CHANGED
|
@@ -1447,13 +1447,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio
|
|
| 1447 |
|
| 1448 |
[[package]]
|
| 1449 |
name = "ipython"
|
| 1450 |
-
version = "8.
|
| 1451 |
description = "IPython: Productive Interactive Computing"
|
| 1452 |
optional = false
|
| 1453 |
python-versions = ">=3.10"
|
| 1454 |
files = [
|
| 1455 |
-
{file = "ipython-8.
|
| 1456 |
-
{file = "ipython-8.
|
| 1457 |
]
|
| 1458 |
|
| 1459 |
[package.dependencies]
|
|
@@ -1759,13 +1759,13 @@ jupyter-server = ">=1.1.2"
|
|
| 1759 |
|
| 1760 |
[[package]]
|
| 1761 |
name = "jupyter-server"
|
| 1762 |
-
version = "2.
|
| 1763 |
description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
|
| 1764 |
optional = false
|
| 1765 |
-
python-versions = ">=3.
|
| 1766 |
files = [
|
| 1767 |
-
{file = "jupyter_server-2.
|
| 1768 |
-
{file = "jupyter_server-2.
|
| 1769 |
]
|
| 1770 |
|
| 1771 |
[package.dependencies]
|
|
@@ -1774,7 +1774,7 @@ argon2-cffi = ">=21.1"
|
|
| 1774 |
jinja2 = ">=3.0.3"
|
| 1775 |
jupyter-client = ">=7.4.4"
|
| 1776 |
jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
|
| 1777 |
-
jupyter-events = ">=0.
|
| 1778 |
jupyter-server-terminals = ">=0.4.4"
|
| 1779 |
nbconvert = ">=6.4.4"
|
| 1780 |
nbformat = ">=5.3.0"
|
|
@@ -3392,24 +3392,24 @@ diagrams = ["jinja2", "railroad-diagrams"]
|
|
| 3392 |
|
| 3393 |
[[package]]
|
| 3394 |
name = "pypdfium2"
|
| 3395 |
-
version = "4.30.
|
| 3396 |
description = "Python bindings to PDFium"
|
| 3397 |
optional = false
|
| 3398 |
python-versions = ">=3.6"
|
| 3399 |
files = [
|
| 3400 |
-
{file = "pypdfium2-4.30.
|
| 3401 |
-
{file = "pypdfium2-4.30.
|
| 3402 |
-
{file = "pypdfium2-4.30.
|
| 3403 |
-
{file = "pypdfium2-4.30.
|
| 3404 |
-
{file = "pypdfium2-4.30.
|
| 3405 |
-
{file = "pypdfium2-4.30.
|
| 3406 |
-
{file = "pypdfium2-4.30.
|
| 3407 |
-
{file = "pypdfium2-4.30.
|
| 3408 |
-
{file = "pypdfium2-4.30.
|
| 3409 |
-
{file = "pypdfium2-4.30.
|
| 3410 |
-
{file = "pypdfium2-4.30.
|
| 3411 |
-
{file = "pypdfium2-4.30.
|
| 3412 |
-
{file = "pypdfium2-4.30.
|
| 3413 |
]
|
| 3414 |
|
| 3415 |
[[package]]
|
|
|
|
| 1447 |
|
| 1448 |
[[package]]
|
| 1449 |
name = "ipython"
|
| 1450 |
+
version = "8.31.0"
|
| 1451 |
description = "IPython: Productive Interactive Computing"
|
| 1452 |
optional = false
|
| 1453 |
python-versions = ">=3.10"
|
| 1454 |
files = [
|
| 1455 |
+
{file = "ipython-8.31.0-py3-none-any.whl", hash = "sha256:46ec58f8d3d076a61d128fe517a51eb730e3aaf0c184ea8c17d16e366660c6a6"},
|
| 1456 |
+
{file = "ipython-8.31.0.tar.gz", hash = "sha256:b6a2274606bec6166405ff05e54932ed6e5cfecaca1fc05f2cacde7bb074d70b"},
|
| 1457 |
]
|
| 1458 |
|
| 1459 |
[package.dependencies]
|
|
|
|
| 1759 |
|
| 1760 |
[[package]]
|
| 1761 |
name = "jupyter-server"
|
| 1762 |
+
version = "2.15.0"
|
| 1763 |
description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
|
| 1764 |
optional = false
|
| 1765 |
+
python-versions = ">=3.9"
|
| 1766 |
files = [
|
| 1767 |
+
{file = "jupyter_server-2.15.0-py3-none-any.whl", hash = "sha256:872d989becf83517012ee669f09604aa4a28097c0bd90b2f424310156c2cdae3"},
|
| 1768 |
+
{file = "jupyter_server-2.15.0.tar.gz", hash = "sha256:9d446b8697b4f7337a1b7cdcac40778babdd93ba614b6d68ab1c0c918f1c4084"},
|
| 1769 |
]
|
| 1770 |
|
| 1771 |
[package.dependencies]
|
|
|
|
| 1774 |
jinja2 = ">=3.0.3"
|
| 1775 |
jupyter-client = ">=7.4.4"
|
| 1776 |
jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
|
| 1777 |
+
jupyter-events = ">=0.11.0"
|
| 1778 |
jupyter-server-terminals = ">=0.4.4"
|
| 1779 |
nbconvert = ">=6.4.4"
|
| 1780 |
nbformat = ">=5.3.0"
|
|
|
|
| 3392 |
|
| 3393 |
[[package]]
|
| 3394 |
name = "pypdfium2"
|
| 3395 |
+
version = "4.30.1"
|
| 3396 |
description = "Python bindings to PDFium"
|
| 3397 |
optional = false
|
| 3398 |
python-versions = ">=3.6"
|
| 3399 |
files = [
|
| 3400 |
+
{file = "pypdfium2-4.30.1-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:e07c47633732cc18d890bb7e965ad28a9c5a932e548acb928596f86be2e5ae37"},
|
| 3401 |
+
{file = "pypdfium2-4.30.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5ea2d44e96d361123b67b00f527017aa9c847c871b5714e013c01c3eb36a79fe"},
|
| 3402 |
+
{file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de7a3a36803171b3f66911131046d65a732f9e7834438191cb58235e6163c4e"},
|
| 3403 |
+
{file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8a4231efb13170354f568c722d6540b8d5b476b08825586d48ef70c40d16e03"},
|
| 3404 |
+
{file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f434a4934e8244aa95343ffcf24e9ad9f120dbb4785f631bb40a88c39292493"},
|
| 3405 |
+
{file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f454032a0bc7681900170f67d8711b3942824531e765f91c2f5ce7937f999794"},
|
| 3406 |
+
{file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:bbf9130a72370ee9d602e39949b902db669a2a1c24746a91e5586eb829055d9f"},
|
| 3407 |
+
{file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:5cb52884b1583b96e94fd78542c63bb42e06df5e8f9e52f8f31f5ad5a1e53367"},
|
| 3408 |
+
{file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:1a9e372bd4867ff223cc8c338e33fe11055dad12f22885950fc27646cc8d9122"},
|
| 3409 |
+
{file = "pypdfium2-4.30.1-py3-none-win32.whl", hash = "sha256:421f1cf205e213e07c1f2934905779547f4f4a2ff2f59dde29da3d511d3fc806"},
|
| 3410 |
+
{file = "pypdfium2-4.30.1-py3-none-win_amd64.whl", hash = "sha256:598a7f20264ab5113853cba6d86c4566e4356cad037d7d1f849c8c9021007e05"},
|
| 3411 |
+
{file = "pypdfium2-4.30.1-py3-none-win_arm64.whl", hash = "sha256:c2b6d63f6d425d9416c08d2511822b54b8e3ac38e639fc41164b1d75584b3a8c"},
|
| 3412 |
+
{file = "pypdfium2-4.30.1.tar.gz", hash = "sha256:5f5c7c6d03598e107d974f66b220a49436aceb191da34cda5f692be098a814ce"},
|
| 3413 |
]
|
| 3414 |
|
| 3415 |
[[package]]
|
tests/builders/test_blank_page.py
CHANGED
|
@@ -5,8 +5,8 @@ from marker.builders.layout import LayoutBuilder
|
|
| 5 |
from marker.builders.ocr import OcrBuilder
|
| 6 |
|
| 7 |
|
| 8 |
-
def test_blank_page(config, pdf_provider, layout_model, recognition_model, detection_model):
|
| 9 |
-
layout_builder = LayoutBuilder(layout_model, config)
|
| 10 |
builder = DocumentBuilder(config)
|
| 11 |
document = builder.build_document(pdf_provider)
|
| 12 |
|
|
|
|
| 5 |
from marker.builders.ocr import OcrBuilder
|
| 6 |
|
| 7 |
|
| 8 |
+
def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model):
|
| 9 |
+
layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
|
| 10 |
builder = DocumentBuilder(config)
|
| 11 |
document = builder.build_document(pdf_provider)
|
| 12 |
|
tests/builders/test_garbled_pdf.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import pytest
|
| 2 |
-
from marker.schema import BlockTypes
|
| 3 |
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
@pytest.mark.skip(reason="This is failing because we need better garbled text detection")
|
| 6 |
@pytest.mark.filename("water_damage.pdf")
|
| 7 |
def test_garbled_pdf(pdf_document):
|
| 8 |
assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
|
|
@@ -18,3 +19,29 @@ def test_garbled_pdf(pdf_document):
|
|
| 18 |
span = pdf_document.pages[0].get_block(table_cell.structure[0])
|
| 19 |
assert span.block_type == BlockTypes.Span
|
| 20 |
assert "комплекс" in span.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pytest
|
|
|
|
| 2 |
|
| 3 |
+
from marker.builders.document import DocumentBuilder
|
| 4 |
+
from marker.builders.layout import LayoutBuilder
|
| 5 |
+
from marker.schema import BlockTypes
|
| 6 |
|
|
|
|
| 7 |
@pytest.mark.filename("water_damage.pdf")
|
| 8 |
def test_garbled_pdf(pdf_document):
|
| 9 |
assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
|
|
|
|
| 19 |
span = pdf_document.pages[0].get_block(table_cell.structure[0])
|
| 20 |
assert span.block_type == BlockTypes.Span
|
| 21 |
assert "комплекс" in span.text
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.mark.filename("hindi_judgement.pdf")
|
| 25 |
+
@pytest.mark.config({"page_range": [2, 3]})
|
| 26 |
+
def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model):
|
| 27 |
+
layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
|
| 28 |
+
builder = DocumentBuilder(config)
|
| 29 |
+
document = builder.build_document(pdf_provider)
|
| 30 |
+
|
| 31 |
+
bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
|
| 32 |
+
assert len(bad_ocr_results.labels) == 2
|
| 33 |
+
assert all([l == "bad" for l in bad_ocr_results.labels])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.mark.filename("adversarial.pdf")
|
| 37 |
+
@pytest.mark.config({"page_range": [2, 3]})
|
| 38 |
+
def test_nongarbled_builder(config, pdf_provider, layout_model, ocr_error_model):
|
| 39 |
+
layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
|
| 40 |
+
builder = DocumentBuilder(config)
|
| 41 |
+
document = builder.build_document(pdf_provider)
|
| 42 |
+
|
| 43 |
+
bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
|
| 44 |
+
assert len(bad_ocr_results.labels) == 2
|
| 45 |
+
assert all([l == "good" for l in bad_ocr_results.labels])
|
| 46 |
+
|
| 47 |
+
|
tests/conftest.py
CHANGED
|
@@ -11,7 +11,7 @@ from marker.builders.ocr import OcrBuilder
|
|
| 11 |
from marker.converters.pdf import PdfConverter
|
| 12 |
from marker.models import setup_detection_model, setup_layout_model, \
|
| 13 |
setup_recognition_model, setup_table_rec_model, \
|
| 14 |
-
setup_texify_model
|
| 15 |
from marker.schema import BlockTypes
|
| 16 |
from marker.schema.blocks import Block
|
| 17 |
from marker.renderers.markdown import MarkdownRenderer
|
|
@@ -54,6 +54,11 @@ def table_rec_model():
|
|
| 54 |
yield table_rec_m
|
| 55 |
del table_rec_m
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
@pytest.fixture(scope="function")
|
| 59 |
def config(request):
|
|
@@ -87,8 +92,8 @@ def pdf_provider(request, config, temp_pdf):
|
|
| 87 |
|
| 88 |
|
| 89 |
@pytest.fixture(scope="function")
|
| 90 |
-
def pdf_document(request, config, pdf_provider, layout_model, recognition_model, detection_model):
|
| 91 |
-
layout_builder = LayoutBuilder(layout_model, config)
|
| 92 |
ocr_builder = OcrBuilder(detection_model, recognition_model, config)
|
| 93 |
builder = DocumentBuilder(config)
|
| 94 |
document = builder(pdf_provider, layout_builder, ocr_builder)
|
|
@@ -96,13 +101,14 @@ def pdf_document(request, config, pdf_provider, layout_model, recognition_model,
|
|
| 96 |
|
| 97 |
|
| 98 |
@pytest.fixture(scope="function")
|
| 99 |
-
def pdf_converter(request, config, layout_model, texify_model, recognition_model, table_rec_model, detection_model, renderer):
|
| 100 |
model_dict = {
|
| 101 |
"layout_model": layout_model,
|
| 102 |
"texify_model": texify_model,
|
| 103 |
"recognition_model": recognition_model,
|
| 104 |
"table_rec_model": table_rec_model,
|
| 105 |
-
"detection_model": detection_model
|
|
|
|
| 106 |
}
|
| 107 |
yield PdfConverter(
|
| 108 |
artifact_dict=model_dict,
|
|
|
|
| 11 |
from marker.converters.pdf import PdfConverter
|
| 12 |
from marker.models import setup_detection_model, setup_layout_model, \
|
| 13 |
setup_recognition_model, setup_table_rec_model, \
|
| 14 |
+
setup_texify_model, setup_ocr_error_model
|
| 15 |
from marker.schema import BlockTypes
|
| 16 |
from marker.schema.blocks import Block
|
| 17 |
from marker.renderers.markdown import MarkdownRenderer
|
|
|
|
| 54 |
yield table_rec_m
|
| 55 |
del table_rec_m
|
| 56 |
|
| 57 |
+
@pytest.fixture(scope="session")
|
| 58 |
+
def ocr_error_model():
|
| 59 |
+
ocr_error_m = setup_ocr_error_model()
|
| 60 |
+
yield ocr_error_m
|
| 61 |
+
del ocr_error_m
|
| 62 |
|
| 63 |
@pytest.fixture(scope="function")
|
| 64 |
def config(request):
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
@pytest.fixture(scope="function")
|
| 95 |
+
def pdf_document(request, config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model):
|
| 96 |
+
layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
|
| 97 |
ocr_builder = OcrBuilder(detection_model, recognition_model, config)
|
| 98 |
builder = DocumentBuilder(config)
|
| 99 |
document = builder(pdf_provider, layout_builder, ocr_builder)
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
@pytest.fixture(scope="function")
|
| 104 |
+
def pdf_converter(request, config, layout_model, texify_model, recognition_model, table_rec_model, detection_model, ocr_error_model, renderer):
|
| 105 |
model_dict = {
|
| 106 |
"layout_model": layout_model,
|
| 107 |
"texify_model": texify_model,
|
| 108 |
"recognition_model": recognition_model,
|
| 109 |
"table_rec_model": table_rec_model,
|
| 110 |
+
"detection_model": detection_model,
|
| 111 |
+
"ocr_error_model": ocr_error_model
|
| 112 |
}
|
| 113 |
yield PdfConverter(
|
| 114 |
artifact_dict=model_dict,
|