denisshepelin committed on
Commit
a6f6ed1
·
1 Parent(s): 23e68a6

Add OpenAI-like service

Browse files
Files changed (2) hide show
  1. marker/services/openai.py +113 -0
  2. pyproject.toml +1 -0
marker/services/openai.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import time
4
+ from io import BytesIO
5
+ from typing import Annotated, List, Union
6
+
7
+ import openai
8
+ import PIL
9
+ from openai import APITimeoutError, RateLimitError
10
+ from PIL import Image
11
+ from pydantic import BaseModel
12
+
13
+ from marker.schema.blocks import Block
14
+ from marker.services import BaseService
15
+
16
+
17
class OpenAIService(BaseService):
    """LLM service that talks to an OpenAI-compatible chat completions API.

    Images are sent inline as base64-encoded WEBP data URIs together with the
    text prompt, and the model is asked to answer in the shape of a pydantic
    ``response_schema`` (structured output).
    """

    # Endpoint root of the OpenAI-compatible API (defaults to OpenRouter).
    openai_base_url: Annotated[
        str, "The base url to use for OpenAI-like models. No trailing slash."
    ] = "https://openrouter.ai/api/v1"
    # Model identifier passed verbatim as ``model=`` on every request.
    openai_model: Annotated[str, "The model name to use for OpenAI-like model."] = (
        "openai/gpt-4o-mini"
    )
    # API key handed to the client; ``None`` until configured.
    openai_key: Annotated[str, "The API key to use for the OpenAI-like service."] = None

    def image_to_base64(self, image: PIL.Image.Image):
        """Encode ``image`` as WEBP and return the bytes as a base64 string."""
        image_bytes = BytesIO()
        image.save(image_bytes, format="WEBP")
        return base64.b64encode(image_bytes.getvalue()).decode("utf-8")

    def prepare_images(
        self, images: Union[Image.Image, List[Image.Image]]
    ) -> List[dict]:
        """Convert one image or a list of images into chat ``image_url`` parts.

        Args:
            images: A single PIL image or a list of PIL images.

        Returns:
            One content-part dict per image, each carrying the image as a
            ``data:image/webp;base64,...`` URI.
        """
        if isinstance(images, Image.Image):
            images = [images]

        return [
            {
                "type": "image_url",
                # The chat completions schema expects an object with a "url"
                # key here, not a bare string.
                "image_url": {
                    "url": "data:image/webp;base64,{}".format(
                        self.image_to_base64(img)
                    ),
                },
            }
            for img in images
        ]

    def __call__(
        self,
        prompt: str,
        image: PIL.Image.Image | List[PIL.Image.Image],
        block: Block,
        response_schema: type[BaseModel],
        max_retries: int | None = None,
        timeout: int | None = None,
    ):
        """Send ``prompt`` plus image(s) to the model and return the parsed JSON.

        Args:
            prompt: Text instruction appended after the image parts.
            image: A single PIL image or a list of PIL images.
            block: Block whose metadata is updated with token usage and the
                request count after a successful call.
            response_schema: Pydantic model passed as ``response_format`` so
                the model replies with matching JSON.
            max_retries: Maximum number of attempts; falls back to
                ``self.max_retries`` when ``None`` (presumably defined on
                ``BaseService`` - confirm).
            timeout: Per-request timeout; falls back to ``self.timeout`` when
                ``None`` (presumably defined on ``BaseService`` - confirm).

        Returns:
            The decoded JSON response as a dict, or ``{}`` when every retry
            failed or a non-retryable error occurred.
        """
        if max_retries is None:
            max_retries = self.max_retries

        if timeout is None:
            timeout = self.timeout

        if not isinstance(image, list):
            image = [image]

        client = self.get_client()
        image_data = self.prepare_images(image)

        messages = [
            {
                "role": "user",
                "content": [
                    *image_data,
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        tries = 0
        while tries < max_retries:
            try:
                response = client.beta.chat.completions.parse(
                    # NOTE(review): these look like OpenRouter's app
                    # attribution headers (OpenRouter is the default base
                    # URL) - confirm other providers ignore them.
                    extra_headers={
                        "X-Title": "Marker",
                        "HTTP-Referer": "https://github.com/VikParuchuri/marker",
                    },
                    model=self.openai_model,
                    messages=messages,
                    timeout=timeout,
                    response_format=response_schema,
                )
                # JSON string; message.parsed holds the pydantic object.
                response_text = response.choices[0].message.content
                # ``usage`` is optional in the SDK response type. A provider
                # that omits it must not cause a valid completion to be
                # thrown away by the generic handler below.
                usage = response.usage
                total_tokens = usage.total_tokens if usage is not None else 0
                block.update_metadata(llm_tokens_used=total_tokens, llm_request_count=1)
                return json.loads(response_text)
            except (APITimeoutError, RateLimitError) as e:
                # Transient failure (timeout or rate limit): back off
                # linearly (3s, 6s, 9s, ...) and try again.
                tries += 1
                wait_time = tries * 3
                print(
                    f"Rate limit error: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{max_retries})"
                )
                time.sleep(wait_time)
            except Exception as e:
                # Any other error is treated as non-retryable: report it and
                # fall through to the empty result.
                print(e)
                break

        return {}

    def get_client(self) -> openai.OpenAI:
        """Build an OpenAI SDK client for the configured key and base URL."""
        return openai.OpenAI(api_key=self.openai_key, base_url=self.openai_base_url)
pyproject.toml CHANGED
@@ -43,6 +43,7 @@ openpyxl = {version = "^3.1.5", optional = true}
43
  python-pptx = {version = "^1.0.2", optional = true}
44
  ebooklib = {version = "^0.18", optional = true}
45
  weasyprint = {version = "^63.1", optional = true}
 
46
 
47
  [tool.poetry.group.dev.dependencies]
48
  jupyter = "^1.0.0"
 
43
  python-pptx = {version = "^1.0.2", optional = true}
44
  ebooklib = {version = "^0.18", optional = true}
45
  weasyprint = {version = "^63.1", optional = true}
46
+ openai = "^1.65.2"
47
 
48
  [tool.poetry.group.dev.dependencies]
49
  jupyter = "^1.0.0"