File size: 4,867 Bytes
383b4a6
4d0b118
383b4a6
78f3a66
 
383b4a6
 
 
2c69783
 
c39ff12
2c69783
3e38f0f
2c69783
969ff96
18eefbb
f960d11
 
2c69783
 
 
 
 
f2f1a27
 
383b4a6
969ff96
 
 
 
 
 
383b4a6
 
969ff96
 
6ebcce8
 
 
969ff96
 
6ebcce8
 
 
969ff96
 
6ebcce8
 
 
969ff96
 
 
383b4a6
1186a50
969ff96
 
 
ed65502
e40eddb
00bb02a
 
 
 
 
 
1adf0e5
00bb02a
 
 
ed65502
44826c8
 
 
00bb02a
ed65502
00bb02a
18eefbb
aee20f6
 
 
ed65502
18eefbb
383b4a6
18eefbb
ed65502
383b4a6
4d0b118
 
 
 
18eefbb
 
 
5380a39
ed65502
5380a39
ed65502
 
 
 
 
 
 
 
 
44826c8
870d666
c39ff12
72d2813
3e38f0f
18eefbb
3e38f0f
4d0b118
 
 
 
f2f1a27
 
 
4d0b118
433156c
5038c82
433156c
f2f1a27
ed65502
4d0b118
 
 
 
c0d162d
 
 
 
433156c
c0d162d
433156c
f960d11
 
 
 
c0d162d
 
 
433156c
78f3a66
5471d0c
 
f2f1a27
 
 
 
 
 
5471d0c
 
78f3a66
 
 
 
8627bc6
78f3a66
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import tempfile
from typing import Dict, Type

from PIL import Image, ImageDraw

import datasets
import pytest

from marker.builders.document import DocumentBuilder
from marker.builders.layout import LayoutBuilder
from marker.builders.line import LineBuilder
from marker.builders.ocr import OcrBuilder
from marker.builders.structure import StructureBuilder
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.providers.registry import provider_from_filepath
from marker.renderers.chunk import ChunkRenderer
from marker.renderers.html import HTMLRenderer
from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.renderers.markdown import MarkdownRenderer
from marker.renderers.json import JSONRenderer
from marker.schema.registry import register_block_class
from marker.util import classes_to_strings, strings_to_classes


@pytest.fixture(scope="session")
def model_dict():
    """Create the model dictionary once per test session and yield it.

    The dictionary comes from ``create_model_dict()``; the local reference
    is dropped once the session is torn down.
    """
    models = create_model_dict()
    yield models
    # Teardown: drop this fixture's own reference to the dictionary.
    del models


@pytest.fixture(scope="session")
def layout_model(model_dict):
    """Yield the ``"layout_model"`` entry of the session model dictionary."""
    layout = model_dict["layout_model"]
    yield layout


@pytest.fixture(scope="session")
def detection_model(model_dict):
    """Yield the ``"detection_model"`` entry of the session model dictionary."""
    detection = model_dict["detection_model"]
    yield detection


@pytest.fixture(scope="session")
def recognition_model(model_dict):
    """Yield the ``"recognition_model"`` entry of the session model dictionary."""
    recognition = model_dict["recognition_model"]
    yield recognition


@pytest.fixture(scope="session")
def table_rec_model(model_dict):
    """Yield the ``"table_rec_model"`` entry of the session model dictionary."""
    table_rec = model_dict["table_rec_model"]
    yield table_rec


@pytest.fixture(scope="session")
def ocr_error_model(model_dict):
    """Yield the ``"ocr_error_model"`` entry of the session model dictionary."""
    ocr_error = model_dict["ocr_error_model"]
    yield ocr_error


@pytest.fixture(scope="function")
def config(request):
    """Return the configuration dict for the current test.

    The dict is the first positional argument of the closest ``config``
    marker on the test, or an empty dict when the test carries no such
    marker. If the dict has an ``"override_map"`` entry, every
    ``block type -> block class`` pair in it is passed to
    ``register_block_class`` before the dict is returned.
    """
    marker = request.node.get_closest_marker("config")
    if marker:
        config = marker.args[0]
    else:
        config = {}

    # NOTE(review): nothing in this fixture reverts these registrations
    # after the test finishes.
    override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {})
    for original_type, replacement_cls in override_map.items():
        register_block_class(original_type, replacement_cls)

    return config


@pytest.fixture(scope="session")
def pdf_dataset():
    """Load the ``train`` split of the ``datalab-to/pdfs`` dataset once per session."""
    dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
    return dataset


@pytest.fixture(scope="function")
def temp_doc(request, pdf_dataset):
    """Yield a named temporary file holding one document from ``pdf_dataset``.

    The document is selected by the first positional argument of the test's
    closest ``filename`` marker, defaulting to ``"adversarial.pdf"``. The
    temporary file keeps the extension of the selected filename.

    The file is closed (and therefore removed) as soon as the test using the
    fixture finishes, instead of lingering until garbage collection.

    Raises:
        ValueError: if the requested filename is not in the dataset's
            ``"filename"`` column.
    """
    filename_mark = request.node.get_closest_marker("filename")
    filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"

    idx = pdf_dataset["filename"].index(filename)
    suffix = filename.split(".")[-1]

    # The context manager guarantees the handle is closed and the file
    # deleted during fixture teardown, even if the test fails.
    with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_pdf:
        # Index by row first so only the selected record is read, rather
        # than materializing the whole "pdf" column to pick one element.
        temp_pdf.write(pdf_dataset[idx]["pdf"])
        # Flush so code that opens the file by name sees the full content.
        temp_pdf.flush()
        yield temp_pdf


@pytest.fixture(scope="function")
def doc_provider(request, config, temp_doc):
    """Yield a provider instance for the temporary document.

    The provider class is chosen by ``provider_from_filepath`` from the
    temporary file's path and instantiated with that path and the test
    configuration.
    """
    doc_path = temp_doc.name
    provider_cls = provider_from_filepath(doc_path)
    provider = provider_cls(doc_path, config)
    yield provider


@pytest.fixture(scope="function")
def pdf_document(
    request,
    config,
    doc_provider,
    layout_model,
    ocr_error_model,
    recognition_model,
    detection_model,
):
    """Yield a document built from ``doc_provider`` with all builders applied.

    A ``DocumentBuilder`` is called with the provider plus layout, line and
    OCR builders, and the resulting document is then passed through a
    ``StructureBuilder``. Every builder receives the test configuration.
    """
    # Builders are instantiated in the same order as before; only the
    # local names differ.
    layout = LayoutBuilder(layout_model, config)
    lines = LineBuilder(detection_model, ocr_error_model, config)
    ocr = OcrBuilder(recognition_model, config)
    build_document = DocumentBuilder(config)
    apply_structure = StructureBuilder(config)

    built = build_document(doc_provider, layout, lines, ocr)
    # The structure builder is invoked on the document; its return value
    # is not used here.
    apply_structure(built)
    yield built


@pytest.fixture(scope="function")
def pdf_converter(request, config, model_dict, renderer, llm_service):
    """Yield a ``PdfConverter`` configured for the current test.

    The renderer class, and the LLM service class when one is set, are
    converted to strings with ``classes_to_strings`` before being handed to
    the converter. A falsy ``llm_service`` is passed through unchanged.
    """
    service_ref = (
        classes_to_strings([llm_service])[0] if llm_service else llm_service
    )
    renderer_ref = classes_to_strings([renderer])[0]
    converter = PdfConverter(
        artifact_dict=model_dict,
        processor_list=None,
        renderer=renderer_ref,
        config=config,
        llm_service=service_ref,
    )
    yield converter


@pytest.fixture(scope="function")
def renderer(request, config):
    """Return the renderer class selected by the ``output_format`` marker.

    Without the marker, ``MarkdownRenderer`` is returned. With it, the
    marker's first positional argument picks the class: ``"markdown"``,
    ``"json"``, ``"html"`` or ``"chunks"``.

    Raises:
        ValueError: if the marker names any other output format.
    """
    format_mark = request.node.get_closest_marker("output_format")
    if not format_mark:
        return MarkdownRenderer

    output_format = format_mark.args[0]
    # Ordered (name, class) pairs compared with ``==``, matching the
    # comparisons an if/elif chain would perform.
    known_renderers = (
        ("markdown", MarkdownRenderer),
        ("json", JSONRenderer),
        ("html", HTMLRenderer),
        ("chunks", ChunkRenderer),
    )
    for format_name, renderer_cls in known_renderers:
        if output_format == format_name:
            return renderer_cls

    raise ValueError(f"Unknown output format: {output_format}")


@pytest.fixture(scope="function")
def llm_service(request, config):
    """Yield the LLM service class named in the test configuration, if any.

    Reads the ``"llm_service"`` key of ``config``. A missing or falsy value
    yields ``None``; otherwise the value is resolved to a class with
    ``strings_to_classes``.
    """
    service_name = config.get("llm_service")
    if service_name:
        yield strings_to_classes([service_name])[0]
    else:
        yield None


@pytest.fixture(scope="function")
def temp_image():
    """Yield a named temporary ``.png`` file containing a small test image.

    The image is a 512x512 white RGB canvas with the text "Hello, World!"
    drawn in black at (200, 200). The temporary file is removed when the
    test using the fixture finishes.
    """
    img = Image.new("RGB", (512, 512), color="white")
    draw = ImageDraw.Draw(img)
    draw.text((200, 200), "Hello, World!", fill="black", font_size=36)
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        # Write through the already-open handle instead of reopening the
        # file by name: reopening an open NamedTemporaryFile is not portable
        # (it fails on Windows) and made the flush below a no-op.
        img.save(f, format="PNG")
        # Push the data to disk so consumers opening ``f.name`` see it all.
        f.flush()
        # Rewind so reading from the yielded handle starts at the beginning.
        f.seek(0)
        yield f