Vik Paruchuri commited on
Commit
5c982c9
·
1 Parent(s): 94b8583

Add tests for extraction converter

Browse files
extraction_app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from marker.scripts.run_streamlit_app import extraction_app_cli
2
+
3
+ if __name__ == "__main__":
4
+ extraction_app_cli()
marker/converters/extraction.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
 
3
  from marker.builders.document import DocumentBuilder
@@ -8,7 +9,7 @@ from marker.converters.pdf import PdfConverter
8
  from marker.extractors.page import PageExtractor, json_schema_to_base_model
9
  from marker.providers.registry import provider_from_filepath
10
 
11
- from marker.renderers.extraction import ExtractionMerger
12
  from marker.renderers.markdown import MarkdownRenderer
13
 
14
  from marker.logger import get_logger
@@ -36,13 +37,13 @@ class ExtractionConverter(PdfConverter):
36
 
37
  return document, provider
38
 
39
- def __call__(self, filepath: str) -> str:
40
  self.config["paginate_output"] = True # Ensure we can split the output properly
41
  self.config["output_format"] = (
42
  "markdown" # Output must be markdown for extraction
43
  )
44
  try:
45
- json_schema_to_base_model(self.config["page_schema"])
46
  except Exception as e:
47
  logger.error(f"Could not parse page schema: {e}")
48
  raise ValueError(
 
1
+ import json
2
  import re
3
 
4
  from marker.builders.document import DocumentBuilder
 
9
  from marker.extractors.page import PageExtractor, json_schema_to_base_model
10
  from marker.providers.registry import provider_from_filepath
11
 
12
+ from marker.renderers.extraction import ExtractionMerger, ExtractionOutput
13
  from marker.renderers.markdown import MarkdownRenderer
14
 
15
  from marker.logger import get_logger
 
37
 
38
  return document, provider
39
 
40
+ def __call__(self, filepath: str) -> ExtractionOutput:
41
  self.config["paginate_output"] = True # Ensure we can split the output properly
42
  self.config["output_format"] = (
43
  "markdown" # Output must be markdown for extraction
44
  )
45
  try:
46
+ json_schema_to_base_model(json.loads(self.config["page_schema"]))
47
  except Exception as e:
48
  logger.error(f"Could not parse page schema: {e}")
49
  raise ValueError(
marker/renderers/extraction.py CHANGED
@@ -45,7 +45,7 @@ class ExtractionMerger:
45
  def __init__(self):
46
  pass
47
 
48
- def __call__(self, outputs: Dict[int, ExtractionResult]):
49
  pnums = sorted(list(outputs.keys()))
50
  merged_result = outputs[pnums[0]].extracted_data.copy()
51
  confidence_exists = outputs[pnums[0]].existence_confidence
 
45
  def __init__(self):
46
  pass
47
 
48
+ def __call__(self, outputs: Dict[int, ExtractionResult]) -> ExtractionOutput:
49
  pnums = sorted(list(outputs.keys()))
50
  merged_result = outputs[pnums[0]].extracted_data.copy()
51
  confidence_exists = outputs[pnums[0]].existence_confidence
pyproject.toml CHANGED
@@ -70,6 +70,7 @@ marker = "marker.scripts.convert:convert_cli"
70
  marker_single = "marker.scripts.convert_single:convert_single_cli"
71
  marker_chunk_convert = "marker.scripts.chunk_convert:chunk_convert_cli"
72
  marker_gui = "marker.scripts.run_streamlit_app:streamlit_app_cli"
 
73
  marker_server = "marker.scripts.server:server_cli"
74
 
75
  [build-system]
 
70
  marker_single = "marker.scripts.convert_single:convert_single_cli"
71
  marker_chunk_convert = "marker.scripts.chunk_convert:chunk_convert_cli"
72
  marker_gui = "marker.scripts.run_streamlit_app:streamlit_app_cli"
73
+ marker_extract = "marker.scripts.run_streamlit_app:extraction_app_cli"
74
  marker_server = "marker.scripts.server:server_cli"
75
 
76
  [build-system]
tests/converters/test_extraction_converter.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pytest
3
+
4
+ from marker.converters.extraction import ExtractionConverter
5
+ from marker.extractors.page import PageExtractionSchema
6
+ from marker.services import BaseService
7
+
8
+
9
+ class MockLLMService(BaseService):
10
+ def __call__(self, prompt, image=None, page=None, response_schema=None, **kwargs):
11
+ assert response_schema == PageExtractionSchema
12
+ return {
13
+ "description": "Mock extraction description",
14
+ "extracted_json": json.dumps({"test_key": "test_value"}),
15
+ "existence_confidence": 5,
16
+ "value_confidence": 5,
17
+ }
18
+
19
+
20
+ @pytest.fixture
21
+ def mock_llm_service():
22
+ return MockLLMService
23
+
24
+
25
+ @pytest.fixture
26
+ def extraction_converter(config, model_dict, mock_llm_service):
27
+ test_schema = {
28
+ "title": "TestSchema",
29
+ "type": "object",
30
+ "properties": {"test_key": {"title": "Test Key", "type": "string"}},
31
+ "required": ["test_key"],
32
+ }
33
+
34
+ config["page_schema"] = json.dumps(test_schema)
35
+ config["output_format"] = "markdown"
36
+ model_dict["llm_service"] = mock_llm_service
37
+
38
+ converter = ExtractionConverter(
39
+ artifact_dict=model_dict, processor_list=None, config=config
40
+ )
41
+ converter.default_llm_service = MockLLMService
42
+ return converter
43
+
44
+
45
+ @pytest.mark.config({"page_range": [0]})
46
+ def test_extraction_converter_invalid_schema(
47
+ config, model_dict, mock_llm_service, temp_doc
48
+ ):
49
+ config["page_schema"] = "invalid json"
50
+
51
+ model_dict["llm_service"] = mock_llm_service
52
+ converter = ExtractionConverter(
53
+ artifact_dict=model_dict, processor_list=None, config=config
54
+ )
55
+
56
+ with pytest.raises(ValueError):
57
+ converter(temp_doc.name)
58
+
59
+
60
+ @pytest.mark.config({"page_range": [0, 1]})
61
+ def test_extraction_converter_multiple_pages(extraction_converter, temp_doc):
62
+ result = extraction_converter(temp_doc.name)
63
+
64
+ assert result is not None
65
+ assert result.json is not None
66
+ assert result.json == {"test_key": "test_value"}