Moses Paul R commited on
Commit
321ab9a
·
2 Parent(s): 0b5878f 2dd8e10

Merge remote-tracking branch 'origin/dev' into highquality-processors

Browse files
marker/builders/layout.py CHANGED
@@ -5,6 +5,10 @@ from surya.layout import batch_layout_detection
5
  from surya.schema import LayoutResult
6
  from surya.model.layout.encoderdecoder import SuryaLayoutModel
7
 
 
 
 
 
8
  from marker.settings import settings
9
  from marker.builders import BaseBuilder
10
  from marker.providers import ProviderOutput, ProviderPageLines
@@ -37,15 +41,21 @@ class LayoutBuilder(BaseBuilder):
37
  document_ocr_threshold (float):
38
  The minimum ratio of pages that must pass the layout coverage check
39
  to avoid OCR. Default is 0.8.
 
 
 
 
40
  """
41
  batch_size = None
42
  layout_coverage_min_lines = 1
43
  layout_coverage_threshold = .1
44
  document_ocr_threshold = .8
 
45
  excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
46
 
47
- def __init__(self, layout_model: SuryaLayoutModel, config=None):
48
  self.layout_model = layout_model
 
49
 
50
  super().__init__(config)
51
 
@@ -71,6 +81,31 @@ class LayoutBuilder(BaseBuilder):
71
  )
72
  return layout_results
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
75
  for page, layout_result in zip(pages, layout_results):
76
  layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
@@ -92,16 +127,17 @@ class LayoutBuilder(BaseBuilder):
92
  page.children = []
93
 
94
  def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
 
 
95
  good_pages = []
96
- for document_page in document_pages:
97
  provider_lines = provider_page_lines.get(document_page.page_id, [])
98
- good_pages.append(self.check_layout_coverage(document_page, provider_lines))
99
 
100
  ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
101
  for idx, document_page in enumerate(document_pages):
102
  provider_lines = provider_page_lines.get(document_page.page_id, [])
103
  needs_ocr = not good_pages[idx]
104
-
105
  if needs_ocr and ocr_document:
106
  document_page.text_extraction_method = "surya"
107
  continue
 
5
  from surya.schema import LayoutResult
6
  from surya.model.layout.encoderdecoder import SuryaLayoutModel
7
 
8
+ from surya.ocr_error import batch_ocr_error_detection
9
+ from surya.schema import OCRErrorDetectionResult
10
+ from surya.model.ocr_error.model import DistilBertForSequenceClassification
11
+
12
  from marker.settings import settings
13
  from marker.builders import BaseBuilder
14
  from marker.providers import ProviderOutput, ProviderPageLines
 
41
  document_ocr_threshold (float):
42
  The minimum ratio of pages that must pass the layout coverage check
43
  to avoid OCR. Default is 0.8.
44
+
45
+ error_model_segment_length (int):
46
+ The maximum number of characters to send to the OCR error model.
47
+ Default is 1024.
48
  """
49
  batch_size = None
50
  layout_coverage_min_lines = 1
51
  layout_coverage_threshold = .1
52
  document_ocr_threshold = .8
53
+ error_model_segment_length = 512
54
  excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
55
 
56
+ def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
57
  self.layout_model = layout_model
58
+ self.ocr_error_model = ocr_error_model
59
 
60
  super().__init__(config)
61
 
 
81
  )
82
  return layout_results
83
 
84
+ def surya_ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines) -> OCRErrorDetectionResult:
85
+ page_texts = []
86
+ for document_page in pages:
87
+ page_text = ''
88
+ provider_lines = provider_page_lines.get(document_page.page_id, [])
89
+ for line in provider_lines:
90
+ page_text += ' '.join([s.text for s in line.spans])
91
+
92
+ # Sample text from the middle
93
+ if len(page_text) > 0:
94
+ page_text_middle = len(page_text) // 2
95
+ page_text_start = max(0, page_text_middle - self.error_model_segment_length // 2)
96
+ page_text_end = page_text_start + self.error_model_segment_length
97
+ page_text = page_text[page_text_start:page_text_end]
98
+
99
+ page_texts.append(page_text)
100
+
101
+ ocr_error_detection_results = batch_ocr_error_detection(
102
+ page_texts,
103
+ self.ocr_error_model,
104
+ self.ocr_error_model.tokenizer,
105
+ batch_size=int(self.get_batch_size()) #TODO Better Multiplier
106
+ )
107
+ return ocr_error_detection_results
108
+
109
  def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
110
  for page, layout_result in zip(pages, layout_results):
111
  layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
 
127
  page.children = []
128
 
129
  def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
130
+ ocr_error_detection_labels = self.surya_ocr_error_detection(document_pages, provider_page_lines).labels
131
+
132
  good_pages = []
133
+ for (document_page, ocr_error_detection_label) in zip(document_pages, ocr_error_detection_labels):
134
  provider_lines = provider_page_lines.get(document_page.page_id, [])
135
+ good_pages.append(self.check_layout_coverage(document_page, provider_lines) and (ocr_error_detection_label != "bad"))
136
 
137
  ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
138
  for idx, document_page in enumerate(document_pages):
139
  provider_lines = provider_page_lines.get(document_page.page_id, [])
140
  needs_ocr = not good_pages[idx]
 
141
  if needs_ocr and ocr_document:
142
  document_page.text_extraction_method = "surya"
143
  continue
marker/models.py CHANGED
@@ -12,12 +12,15 @@ from surya.model.recognition.model import load_model as load_recognition_model
12
  from surya.model.recognition.processor import load_processor as load_recognition_processor
13
  from surya.model.table_rec.model import load_model as load_table_model
14
  from surya.model.table_rec.processor import load_processor as load_table_processor
 
 
15
 
16
  from texify.model.model import GenerateVisionEncoderDecoderModel
17
  from surya.model.layout.encoderdecoder import SuryaLayoutModel
18
  from surya.model.detection.model import EfficientViTForSemanticSegmentation
19
  from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
20
  from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
 
21
 
22
 
23
  def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
@@ -64,6 +67,13 @@ def setup_layout_model(device=None, dtype=None) -> SuryaLayoutModel:
64
  model.processor = load_layout_processor()
65
  return model
66
 
 
 
 
 
 
 
 
67
 
68
  def create_model_dict(device=None, dtype=None) -> dict:
69
  return {
@@ -72,4 +82,5 @@ def create_model_dict(device=None, dtype=None) -> dict:
72
  "recognition_model": setup_recognition_model(device, dtype),
73
  "table_rec_model": setup_table_rec_model(device, dtype),
74
  "detection_model": setup_detection_model(device, dtype),
 
75
  }
 
12
  from surya.model.recognition.processor import load_processor as load_recognition_processor
13
  from surya.model.table_rec.model import load_model as load_table_model
14
  from surya.model.table_rec.processor import load_processor as load_table_processor
15
+ from surya.model.ocr_error.model import load_model as load_ocr_error_model
16
+ from surya.model.ocr_error.model import load_tokenizer as load_ocr_error_tokenizer
17
 
18
  from texify.model.model import GenerateVisionEncoderDecoderModel
19
  from surya.model.layout.encoderdecoder import SuryaLayoutModel
20
  from surya.model.detection.model import EfficientViTForSemanticSegmentation
21
  from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
22
  from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
23
+ from surya.model.ocr_error.model import DistilBertForSequenceClassification
24
 
25
 
26
  def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
 
67
  model.processor = load_layout_processor()
68
  return model
69
 
70
+ def setup_ocr_error_model(device=None, dtype=None) -> DistilBertForSequenceClassification:
71
+ if device:
72
+ model = load_ocr_error_model(device=device, dtype=dtype)
73
+ else:
74
+ model = load_ocr_error_model()
75
+ model.tokenizer = load_ocr_error_tokenizer()
76
+ return model
77
 
78
  def create_model_dict(device=None, dtype=None) -> dict:
79
  return {
 
82
  "recognition_model": setup_recognition_model(device, dtype),
83
  "table_rec_model": setup_table_rec_model(device, dtype),
84
  "detection_model": setup_detection_model(device, dtype),
85
+ "ocr_error_model": setup_ocr_error_model(device,dtype)
86
  }
marker/renderers/markdown.py CHANGED
@@ -53,6 +53,15 @@ class Markdownify(MarkdownConverter):
53
  else:
54
  return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n\n"
55
 
 
 
 
 
 
 
 
 
 
56
 
57
  class MarkdownOutput(BaseModel):
58
  markdown: str
 
53
  else:
54
  return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n\n"
55
 
56
+ def convert_td(self, el, text, convert_as_inline):
57
+ text = text.replace("|", " ").replace("\n", " ")
58
+ return super().convert_td(el, text, convert_as_inline)
59
+
60
+ def convert_th(self, el, text, convert_as_inline):
61
+ text = text.replace("|", " ").replace("\n", " ")
62
+ return super().convert_th(el, text, convert_as_inline)
63
+
64
+
65
 
66
  class MarkdownOutput(BaseModel):
67
  markdown: str
poetry.lock CHANGED
@@ -1447,13 +1447,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio
1447
 
1448
  [[package]]
1449
  name = "ipython"
1450
- version = "8.30.0"
1451
  description = "IPython: Productive Interactive Computing"
1452
  optional = false
1453
  python-versions = ">=3.10"
1454
  files = [
1455
- {file = "ipython-8.30.0-py3-none-any.whl", hash = "sha256:85ec56a7e20f6c38fce7727dcca699ae4ffc85985aa7b23635a8008f918ae321"},
1456
- {file = "ipython-8.30.0.tar.gz", hash = "sha256:cb0a405a306d2995a5cbb9901894d240784a9f341394c6ba3f4fe8c6eb89ff6e"},
1457
  ]
1458
 
1459
  [package.dependencies]
@@ -1759,13 +1759,13 @@ jupyter-server = ">=1.1.2"
1759
 
1760
  [[package]]
1761
  name = "jupyter-server"
1762
- version = "2.14.2"
1763
  description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
1764
  optional = false
1765
- python-versions = ">=3.8"
1766
  files = [
1767
- {file = "jupyter_server-2.14.2-py3-none-any.whl", hash = "sha256:47ff506127c2f7851a17bf4713434208fc490955d0e8632e95014a9a9afbeefd"},
1768
- {file = "jupyter_server-2.14.2.tar.gz", hash = "sha256:66095021aa9638ced276c248b1d81862e4c50f292d575920bbe960de1c56b12b"},
1769
  ]
1770
 
1771
  [package.dependencies]
@@ -1774,7 +1774,7 @@ argon2-cffi = ">=21.1"
1774
  jinja2 = ">=3.0.3"
1775
  jupyter-client = ">=7.4.4"
1776
  jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
1777
- jupyter-events = ">=0.9.0"
1778
  jupyter-server-terminals = ">=0.4.4"
1779
  nbconvert = ">=6.4.4"
1780
  nbformat = ">=5.3.0"
@@ -3392,24 +3392,24 @@ diagrams = ["jinja2", "railroad-diagrams"]
3392
 
3393
  [[package]]
3394
  name = "pypdfium2"
3395
- version = "4.30.0"
3396
  description = "Python bindings to PDFium"
3397
  optional = false
3398
  python-versions = ">=3.6"
3399
  files = [
3400
- {file = "pypdfium2-4.30.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b33ceded0b6ff5b2b93bc1fe0ad4b71aa6b7e7bd5875f1ca0cdfb6ba6ac01aab"},
3401
- {file = "pypdfium2-4.30.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4e55689f4b06e2d2406203e771f78789bd4f190731b5d57383d05cf611d829de"},
3402
- {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6e50f5ce7f65a40a33d7c9edc39f23140c57e37144c2d6d9e9262a2a854854"},
3403
- {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3d0dd3ecaffd0b6dbda3da663220e705cb563918249bda26058c6036752ba3a2"},
3404
- {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc3bf29b0db8c76cdfaac1ec1cde8edf211a7de7390fbf8934ad2aa9b4d6dfad"},
3405
- {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1f78d2189e0ddf9ac2b7a9b9bd4f0c66f54d1389ff6c17e9fd9dc034d06eb3f"},
3406
- {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5eda3641a2da7a7a0b2f4dbd71d706401a656fea521b6b6faa0675b15d31a163"},
3407
- {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:0dfa61421b5eb68e1188b0b2231e7ba35735aef2d867d86e48ee6cab6975195e"},
3408
- {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f33bd79e7a09d5f7acca3b0b69ff6c8a488869a7fab48fdf400fec6e20b9c8be"},
3409
- {file = "pypdfium2-4.30.0-py3-none-win32.whl", hash = "sha256:ee2410f15d576d976c2ab2558c93d392a25fb9f6635e8dd0a8a3a5241b275e0e"},
3410
- {file = "pypdfium2-4.30.0-py3-none-win_amd64.whl", hash = "sha256:90dbb2ac07be53219f56be09961eb95cf2473f834d01a42d901d13ccfad64b4c"},
3411
- {file = "pypdfium2-4.30.0-py3-none-win_arm64.whl", hash = "sha256:119b2969a6d6b1e8d55e99caaf05290294f2d0fe49c12a3f17102d01c441bd29"},
3412
- {file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"},
3413
  ]
3414
 
3415
  [[package]]
 
1447
 
1448
  [[package]]
1449
  name = "ipython"
1450
+ version = "8.31.0"
1451
  description = "IPython: Productive Interactive Computing"
1452
  optional = false
1453
  python-versions = ">=3.10"
1454
  files = [
1455
+ {file = "ipython-8.31.0-py3-none-any.whl", hash = "sha256:46ec58f8d3d076a61d128fe517a51eb730e3aaf0c184ea8c17d16e366660c6a6"},
1456
+ {file = "ipython-8.31.0.tar.gz", hash = "sha256:b6a2274606bec6166405ff05e54932ed6e5cfecaca1fc05f2cacde7bb074d70b"},
1457
  ]
1458
 
1459
  [package.dependencies]
 
1759
 
1760
  [[package]]
1761
  name = "jupyter-server"
1762
+ version = "2.15.0"
1763
  description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
1764
  optional = false
1765
+ python-versions = ">=3.9"
1766
  files = [
1767
+ {file = "jupyter_server-2.15.0-py3-none-any.whl", hash = "sha256:872d989becf83517012ee669f09604aa4a28097c0bd90b2f424310156c2cdae3"},
1768
+ {file = "jupyter_server-2.15.0.tar.gz", hash = "sha256:9d446b8697b4f7337a1b7cdcac40778babdd93ba614b6d68ab1c0c918f1c4084"},
1769
  ]
1770
 
1771
  [package.dependencies]
 
1774
  jinja2 = ">=3.0.3"
1775
  jupyter-client = ">=7.4.4"
1776
  jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
1777
+ jupyter-events = ">=0.11.0"
1778
  jupyter-server-terminals = ">=0.4.4"
1779
  nbconvert = ">=6.4.4"
1780
  nbformat = ">=5.3.0"
 
3392
 
3393
  [[package]]
3394
  name = "pypdfium2"
3395
+ version = "4.30.1"
3396
  description = "Python bindings to PDFium"
3397
  optional = false
3398
  python-versions = ">=3.6"
3399
  files = [
3400
+ {file = "pypdfium2-4.30.1-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:e07c47633732cc18d890bb7e965ad28a9c5a932e548acb928596f86be2e5ae37"},
3401
+ {file = "pypdfium2-4.30.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5ea2d44e96d361123b67b00f527017aa9c847c871b5714e013c01c3eb36a79fe"},
3402
+ {file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de7a3a36803171b3f66911131046d65a732f9e7834438191cb58235e6163c4e"},
3403
+ {file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8a4231efb13170354f568c722d6540b8d5b476b08825586d48ef70c40d16e03"},
3404
+ {file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f434a4934e8244aa95343ffcf24e9ad9f120dbb4785f631bb40a88c39292493"},
3405
+ {file = "pypdfium2-4.30.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f454032a0bc7681900170f67d8711b3942824531e765f91c2f5ce7937f999794"},
3406
+ {file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:bbf9130a72370ee9d602e39949b902db669a2a1c24746a91e5586eb829055d9f"},
3407
+ {file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:5cb52884b1583b96e94fd78542c63bb42e06df5e8f9e52f8f31f5ad5a1e53367"},
3408
+ {file = "pypdfium2-4.30.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:1a9e372bd4867ff223cc8c338e33fe11055dad12f22885950fc27646cc8d9122"},
3409
+ {file = "pypdfium2-4.30.1-py3-none-win32.whl", hash = "sha256:421f1cf205e213e07c1f2934905779547f4f4a2ff2f59dde29da3d511d3fc806"},
3410
+ {file = "pypdfium2-4.30.1-py3-none-win_amd64.whl", hash = "sha256:598a7f20264ab5113853cba6d86c4566e4356cad037d7d1f849c8c9021007e05"},
3411
+ {file = "pypdfium2-4.30.1-py3-none-win_arm64.whl", hash = "sha256:c2b6d63f6d425d9416c08d2511822b54b8e3ac38e639fc41164b1d75584b3a8c"},
3412
+ {file = "pypdfium2-4.30.1.tar.gz", hash = "sha256:5f5c7c6d03598e107d974f66b220a49436aceb191da34cda5f692be098a814ce"},
3413
  ]
3414
 
3415
  [[package]]
tests/builders/test_blank_page.py CHANGED
@@ -5,8 +5,8 @@ from marker.builders.layout import LayoutBuilder
5
  from marker.builders.ocr import OcrBuilder
6
 
7
 
8
- def test_blank_page(config, pdf_provider, layout_model, recognition_model, detection_model):
9
- layout_builder = LayoutBuilder(layout_model, config)
10
  builder = DocumentBuilder(config)
11
  document = builder.build_document(pdf_provider)
12
 
 
5
  from marker.builders.ocr import OcrBuilder
6
 
7
 
8
+ def test_blank_page(config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model):
9
+ layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
10
  builder = DocumentBuilder(config)
11
  document = builder.build_document(pdf_provider)
12
 
tests/builders/test_garbled_pdf.py CHANGED
@@ -1,8 +1,9 @@
1
  import pytest
2
- from marker.schema import BlockTypes
3
 
 
 
 
4
 
5
- @pytest.mark.skip(reason="This is failing because we need better garbled text detection")
6
  @pytest.mark.filename("water_damage.pdf")
7
  def test_garbled_pdf(pdf_document):
8
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
@@ -18,3 +19,29 @@ def test_garbled_pdf(pdf_document):
18
  span = pdf_document.pages[0].get_block(table_cell.structure[0])
19
  assert span.block_type == BlockTypes.Span
20
  assert "комплекс" in span.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
 
2
 
3
+ from marker.builders.document import DocumentBuilder
4
+ from marker.builders.layout import LayoutBuilder
5
+ from marker.schema import BlockTypes
6
 
 
7
  @pytest.mark.filename("water_damage.pdf")
8
  def test_garbled_pdf(pdf_document):
9
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
 
19
  span = pdf_document.pages[0].get_block(table_cell.structure[0])
20
  assert span.block_type == BlockTypes.Span
21
  assert "комплекс" in span.text
22
+
23
+
24
+ @pytest.mark.filename("hindi_judgement.pdf")
25
+ @pytest.mark.config({"page_range": [2, 3]})
26
+ def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model):
27
+ layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
28
+ builder = DocumentBuilder(config)
29
+ document = builder.build_document(pdf_provider)
30
+
31
+ bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
32
+ assert len(bad_ocr_results.labels) == 2
33
+ assert all([l == "bad" for l in bad_ocr_results.labels])
34
+
35
+
36
+ @pytest.mark.filename("adversarial.pdf")
37
+ @pytest.mark.config({"page_range": [2, 3]})
38
+ def test_nongarbled_builder(config, pdf_provider, layout_model, ocr_error_model):
39
+ layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
40
+ builder = DocumentBuilder(config)
41
+ document = builder.build_document(pdf_provider)
42
+
43
+ bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
44
+ assert len(bad_ocr_results.labels) == 2
45
+ assert all([l == "good" for l in bad_ocr_results.labels])
46
+
47
+
tests/conftest.py CHANGED
@@ -11,7 +11,7 @@ from marker.builders.ocr import OcrBuilder
11
  from marker.converters.pdf import PdfConverter
12
  from marker.models import setup_detection_model, setup_layout_model, \
13
  setup_recognition_model, setup_table_rec_model, \
14
- setup_texify_model
15
  from marker.schema import BlockTypes
16
  from marker.schema.blocks import Block
17
  from marker.renderers.markdown import MarkdownRenderer
@@ -54,6 +54,11 @@ def table_rec_model():
54
  yield table_rec_m
55
  del table_rec_m
56
 
 
 
 
 
 
57
 
58
  @pytest.fixture(scope="function")
59
  def config(request):
@@ -87,8 +92,8 @@ def pdf_provider(request, config, temp_pdf):
87
 
88
 
89
  @pytest.fixture(scope="function")
90
- def pdf_document(request, config, pdf_provider, layout_model, recognition_model, detection_model):
91
- layout_builder = LayoutBuilder(layout_model, config)
92
  ocr_builder = OcrBuilder(detection_model, recognition_model, config)
93
  builder = DocumentBuilder(config)
94
  document = builder(pdf_provider, layout_builder, ocr_builder)
@@ -96,13 +101,14 @@ def pdf_document(request, config, pdf_provider, layout_model, recognition_model,
96
 
97
 
98
  @pytest.fixture(scope="function")
99
- def pdf_converter(request, config, layout_model, texify_model, recognition_model, table_rec_model, detection_model, renderer):
100
  model_dict = {
101
  "layout_model": layout_model,
102
  "texify_model": texify_model,
103
  "recognition_model": recognition_model,
104
  "table_rec_model": table_rec_model,
105
- "detection_model": detection_model
 
106
  }
107
  yield PdfConverter(
108
  artifact_dict=model_dict,
 
11
  from marker.converters.pdf import PdfConverter
12
  from marker.models import setup_detection_model, setup_layout_model, \
13
  setup_recognition_model, setup_table_rec_model, \
14
+ setup_texify_model, setup_ocr_error_model
15
  from marker.schema import BlockTypes
16
  from marker.schema.blocks import Block
17
  from marker.renderers.markdown import MarkdownRenderer
 
54
  yield table_rec_m
55
  del table_rec_m
56
 
57
+ @pytest.fixture(scope="session")
58
+ def ocr_error_model():
59
+ ocr_error_m = setup_ocr_error_model()
60
+ yield ocr_error_m
61
+ del ocr_error_m
62
 
63
  @pytest.fixture(scope="function")
64
  def config(request):
 
92
 
93
 
94
  @pytest.fixture(scope="function")
95
+ def pdf_document(request, config, pdf_provider, layout_model, ocr_error_model, recognition_model, detection_model):
96
+ layout_builder = LayoutBuilder(layout_model, ocr_error_model, config)
97
  ocr_builder = OcrBuilder(detection_model, recognition_model, config)
98
  builder = DocumentBuilder(config)
99
  document = builder(pdf_provider, layout_builder, ocr_builder)
 
101
 
102
 
103
  @pytest.fixture(scope="function")
104
+ def pdf_converter(request, config, layout_model, texify_model, recognition_model, table_rec_model, detection_model, ocr_error_model, renderer):
105
  model_dict = {
106
  "layout_model": layout_model,
107
  "texify_model": texify_model,
108
  "recognition_model": recognition_model,
109
  "table_rec_model": table_rec_model,
110
+ "detection_model": detection_model,
111
+ "ocr_error_model": ocr_error_model
112
  }
113
  yield PdfConverter(
114
  artifact_dict=model_dict,