likhonsheikhdev committed on
Commit 51159ea · verified · 1 Parent(s): ab0cf4f

Upload folder using huggingface_hub

Files changed (2)
  1. README.md +64 -16
  2. main.py +276 -95
README.md CHANGED
@@ -11,35 +11,83 @@ pinned: false
  # Docker Model Runner
 
- A CPU-optimized Docker Space with named API endpoints for model inference.
 
  ## Hardware
  - **CPU Basic**: 2 vCPU · 16 GB RAM
 
  ## Endpoints
 
  | Endpoint | Method | Description |
  |----------|--------|-------------|
- | `/` | GET | Welcome message |
  | `/health` | GET | Health check |
- | `/info` | GET | Model information |
  | `/predict` | POST | Text classification |
- | `/generate` | POST | Text generation |
  | `/embed` | POST | Text embeddings |
 
- ## Usage
 
- ```bash
- # Health Check
- curl https://likhonsheikhdev-docker-model-runner.hf.space/health
 
- # Prediction
- curl -X POST https://likhonsheikhdev-docker-model-runner.hf.space/predict \
-   -H "Content-Type: application/json" \
-   -d '{"text": "I love this product!"}'
 
- # Text Generation
- curl -X POST https://likhonsheikhdev-docker-model-runner.hf.space/generate \
-   -H "Content-Type: application/json" \
-   -d '{"prompt": "Once upon a time", "max_length": 50}'
  ```
 
  # Docker Model Runner
 
+ Anthropic & OpenAI API compatible Docker Space with named endpoints.
 
  ## Hardware
  - **CPU Basic**: 2 vCPU · 16 GB RAM
 
+ ## API Compatibility
+
+ ### Anthropic Messages API
+ ```bash
+ curl -X POST https://likhonsheikhdev-docker-model-runner.hf.space/v1/messages \
+   -H "Content-Type: application/json" \
+   -H "x-api-key: your-key" \
+   -d '{
+     "model": "distilgpt2",
+     "max_tokens": 256,
+     "messages": [
+       {"role": "user", "content": "Hello, how are you?"}
+     ]
+   }'
+ ```
+
+ ### OpenAI Chat Completions API
+ ```bash
+ curl -X POST https://likhonsheikhdev-docker-model-runner.hf.space/v1/chat/completions \
+   -H "Content-Type: application/json" \
+   -H "Authorization: Bearer your-key" \
+   -d '{
+     "model": "distilgpt2",
+     "messages": [
+       {"role": "user", "content": "Hello, how are you?"}
+     ]
+   }'
+ ```
+
  ## Endpoints
 
  | Endpoint | Method | Description |
  |----------|--------|-------------|
+ | `/v1/messages` | POST | Anthropic Messages API |
+ | `/v1/chat/completions` | POST | OpenAI Chat API |
+ | `/v1/models` | GET | List available models |
  | `/health` | GET | Health check |
+ | `/info` | GET | API information |
  | `/predict` | POST | Text classification |
  | `/embed` | POST | Text embeddings |
 
+ ## Python SDK Usage
 
+ ### With Anthropic SDK
+ ```python
+ from anthropic import Anthropic
 
+ client = Anthropic(
+     api_key="any-key",
+     base_url="https://likhonsheikhdev-docker-model-runner.hf.space"
+ )
 
+ message = client.messages.create(
+     model="distilgpt2",
+     max_tokens=256,
+     messages=[{"role": "user", "content": "Hello!"}]
+ )
+ print(message.content[0].text)
+ ```
+
+ ### With OpenAI SDK
+ ```python
+ from openai import OpenAI
+
+ client = OpenAI(
+     api_key="any-key",
+     base_url="https://likhonsheikhdev-docker-model-runner.hf.space/v1"
+ )
+
+ response = client.chat.completions.create(
+     model="distilgpt2",
+     messages=[{"role": "user", "content": "Hello!"}]
+ )
+ print(response.choices[0].message.content)
  ```
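
Not part of the commit: the updated README drops the earlier curl examples for the utility endpoints it still lists, so here is a minimal sketch of calling `/predict` and `/embed`, assuming the request fields (`text`, `top_k`, `texts`) defined by `PredictRequest` and `EmbedRequest` in main.py below.

```python
# Sketch only: field names mirror the Pydantic request models in main.py; per
# PredictResponse, the /predict reply carries "predictions", "model", "latency_ms".
import requests

BASE = "https://likhonsheikhdev-docker-model-runner.hf.space"

pred = requests.post(f"{BASE}/predict",
                     json={"text": "I love this product!", "top_k": 2})
print(pred.json())

emb = requests.post(f"{BASE}/embed",
                    json={"texts": ["hello world", "docker model runner"]})
print(emb.json())
```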
main.py CHANGED
@@ -1,15 +1,18 @@
  """
  Docker Model Runner - CPU-Optimized FastAPI application
  Optimized for: 2 vCPU, 16GB RAM
  """
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- from typing import Optional, List
  import torch
- from transformers import pipeline, AutoTokenizer, AutoModel
  import os
  from datetime import datetime
  from contextlib import asynccontextmanager

  # CPU-optimized lightweight models
  MODEL_NAME = os.getenv("MODEL_NAME", "distilbert-base-uncased-finetuned-sst-2-english")
@@ -28,25 +31,27 @@ def load_models():
      global models
      print("Loading models for CPU inference...")

-     # Use smaller, faster models optimized for CPU
      models["classifier"] = pipeline(
          "text-classification",
          model=MODEL_NAME,
-         device=-1,  # CPU
-         torch_dtype=torch.float32
-     )
-
-     models["generator"] = pipeline(
-         "text-generation",
-         model=GENERATOR_MODEL,
          device=-1,
          torch_dtype=torch.float32
      )

-     # Lightweight embedding model
-     models["tokenizer"] = AutoTokenizer.from_pretrained(EMBED_MODEL)
-     models["embedder"] = AutoModel.from_pretrained(EMBED_MODEL)
-     models["embedder"].eval()

      print("✅ All models loaded successfully!")
@@ -60,33 +65,91 @@ async def lifespan(app: FastAPI):
  app = FastAPI(
      title="Docker Model Runner",
-     description="CPU-Optimized HuggingFace Space with named endpoints",
      version="1.0.0",
      lifespan=lifespan
  )


- # Request/Response Models
- class PredictRequest(BaseModel):
      text: str
-     top_k: Optional[int] = 1


- class PredictResponse(BaseModel):
-     predictions: List[dict]
      model: str
-     latency_ms: float


- class GenerateRequest(BaseModel):
-     prompt: str
-     max_length: Optional[int] = 50
-     num_return_sequences: Optional[int] = 1
      temperature: Optional[float] = 0.7


- class GenerateResponse(BaseModel):
-     generated_text: List[str]
      model: str
      latency_ms: float
@@ -109,23 +172,178 @@ class HealthResponse(BaseModel):
      models_loaded: bool


- class InfoResponse(BaseModel):
-     name: str
-     version: str
-     hardware: str
-     models: dict
-     endpoints: List[str]


- # Named Endpoints
  @app.get("/")
  async def root():
      """Welcome endpoint"""
      return {
-         "message": "Docker Model Runner API (CPU Optimized)",
          "hardware": "CPU Basic: 2 vCPU · 16 GB RAM",
          "docs": "/docs",
-         "endpoints": ["/health", "/info", "/predict", "/generate", "/embed"]
      }
@@ -140,30 +358,32 @@ async def health():
      )


- @app.get("/info", response_model=InfoResponse)
  async def info():
      """Model and API information"""
-     return InfoResponse(
-         name="Docker Model Runner",
-         version="1.0.0",
-         hardware="CPU Basic: 2 vCPU · 16 GB RAM",
-         models={
              "classifier": MODEL_NAME,
-             "generator": GENERATOR_MODEL,
              "embedder": EMBED_MODEL
          },
-         endpoints=["/", "/health", "/info", "/predict", "/generate", "/embed"]
-     )


  @app.post("/predict", response_model=PredictResponse)
  async def predict(request: PredictRequest):
-     """
-     Run text classification (sentiment analysis)
-
-     - **text**: Input text to classify
-     - **top_k**: Number of top predictions to return
-     """
      try:
          start_time = datetime.now()
          results = models["classifier"](request.text, top_k=request.top_k)
@@ -178,50 +398,13 @@ async def predict(request: PredictRequest):
          raise HTTPException(status_code=500, detail=str(e))


- @app.post("/generate", response_model=GenerateResponse)
- async def generate(request: GenerateRequest):
-     """
-     Generate text from a prompt
-
-     - **prompt**: Input prompt for generation
-     - **max_length**: Maximum length of generated text (default: 50)
-     - **temperature**: Sampling temperature (default: 0.7)
-     """
-     try:
-         start_time = datetime.now()
-         results = models["generator"](
-             request.prompt,
-             max_length=request.max_length,
-             num_return_sequences=request.num_return_sequences,
-             temperature=request.temperature,
-             do_sample=True,
-             pad_token_id=50256  # GPT2 pad token
-         )
-         latency = (datetime.now() - start_time).total_seconds() * 1000
-
-         generated_texts = [r["generated_text"] for r in results]
-
-         return GenerateResponse(
-             generated_text=generated_texts,
-             model=GENERATOR_MODEL,
-             latency_ms=round(latency, 2)
-         )
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
-
  @app.post("/embed", response_model=EmbedResponse)
  async def embed(request: EmbedRequest):
-     """
-     Get text embeddings using MiniLM (384 dimensions)
-
-     - **texts**: List of texts to embed
-     """
      try:
          start_time = datetime.now()

-         # Tokenize
-         inputs = models["tokenizer"](
              request.texts,
              padding=True,
              truncation=True,
@@ -229,10 +412,8 @@ async def embed(request: EmbedRequest):
              return_tensors="pt"
          )

-         # Get embeddings
          with torch.no_grad():
-             outputs = models["embedder"](**inputs)
-             # Mean pooling
              attention_mask = inputs["attention_mask"]
              token_embeddings = outputs.last_hidden_state
              input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
  """
  Docker Model Runner - CPU-Optimized FastAPI application
+ Compatible with Anthropic API format
  Optimized for: 2 vCPU, 16GB RAM
  """
+ from fastapi import FastAPI, HTTPException, Header
+ from pydantic import BaseModel, Field
+ from typing import Optional, List, Union, Literal
  import torch
+ from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForCausalLM
  import os
  from datetime import datetime
  from contextlib import asynccontextmanager
+ import uuid
+ import time

  # CPU-optimized lightweight models
  MODEL_NAME = os.getenv("MODEL_NAME", "distilbert-base-uncased-finetuned-sst-2-english")

      global models
      print("Loading models for CPU inference...")

+     # Classifier
      models["classifier"] = pipeline(
          "text-classification",
          model=MODEL_NAME,
          device=-1,
          torch_dtype=torch.float32
      )

+     # Generator with tokenizer for chat
+     models["generator_tokenizer"] = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
+     models["generator_model"] = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)
+     models["generator_model"].eval()
+
+     # Set pad token
+     if models["generator_tokenizer"].pad_token is None:
+         models["generator_tokenizer"].pad_token = models["generator_tokenizer"].eos_token
+
+     # Embedding model
+     models["embed_tokenizer"] = AutoTokenizer.from_pretrained(EMBED_MODEL)
+     models["embed_model"] = AutoModel.from_pretrained(EMBED_MODEL)
+     models["embed_model"].eval()

      print("✅ All models loaded successfully!")

  app = FastAPI(
      title="Docker Model Runner",
+     description="Anthropic API Compatible - CPU-Optimized HuggingFace Space",
      version="1.0.0",
      lifespan=lifespan
  )


+ # ============== Anthropic API Models ==============
+
+ class ContentBlock(BaseModel):
+     type: Literal["text"] = "text"
      text: str


+ class MessageContent(BaseModel):
+     role: Literal["user", "assistant"]
+     content: Union[str, List[ContentBlock]]
+
+
+ class AnthropicRequest(BaseModel):
+     model: str = "distilgpt2"
+     messages: List[MessageContent]
+     max_tokens: int = 1024
+     temperature: Optional[float] = 0.7
+     top_p: Optional[float] = 1.0
+     stop_sequences: Optional[List[str]] = None
+     stream: Optional[bool] = False
+     system: Optional[str] = None
+
+
+ class Usage(BaseModel):
+     input_tokens: int
+     output_tokens: int
+
+
+ class AnthropicResponse(BaseModel):
+     id: str
+     type: Literal["message"] = "message"
+     role: Literal["assistant"] = "assistant"
+     content: List[ContentBlock]
      model: str
+     stop_reason: Literal["end_turn", "max_tokens", "stop_sequence"] = "end_turn"
+     stop_sequence: Optional[str] = None
+     usage: Usage
+
+
+ # ============== OpenAI Compatible Models ==============
+
+ class ChatMessage(BaseModel):
+     role: str
+     content: str


+ class ChatCompletionRequest(BaseModel):
+     model: str = "distilgpt2"
+     messages: List[ChatMessage]
+     max_tokens: Optional[int] = 1024
      temperature: Optional[float] = 0.7
+     top_p: Optional[float] = 1.0
+     stream: Optional[bool] = False


+ class ChatChoice(BaseModel):
+     index: int = 0
+     message: ChatMessage
+     finish_reason: str = "stop"
+
+
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int
+     model: str
+     choices: List[ChatChoice]
+     usage: dict
+
+
+ # ============== Other Request/Response Models ==============
+
+ class PredictRequest(BaseModel):
+     text: str
+     top_k: Optional[int] = 1
+
+
+ class PredictResponse(BaseModel):
+     predictions: List[dict]
      model: str
      latency_ms: float

      models_loaded: bool


+ class ModelInfo(BaseModel):
+     id: str
+     object: str = "model"
+     created: int
+     owned_by: str = "local"
+
+
+ class ModelsResponse(BaseModel):
+     object: str = "list"
+     data: List[ModelInfo]
+
+
+ # ============== Helper Functions ==============
+
+ def generate_text(prompt: str, max_tokens: int, temperature: float, top_p: float) -> tuple:
+     """Generate text and return (text, input_tokens, output_tokens)"""
+     tokenizer = models["generator_tokenizer"]
+     model = models["generator_model"]
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+     input_tokens = inputs["input_ids"].shape[1]
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             temperature=temperature if temperature > 0 else 1.0,
+             top_p=top_p,
+             do_sample=temperature > 0,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+     generated_tokens = outputs[0][input_tokens:]
+     output_tokens = len(generated_tokens)
+     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+     return generated_text.strip(), input_tokens, output_tokens
+
+
+ def format_messages_to_prompt(messages: List, system: Optional[str] = None) -> str:
+     """Convert chat messages to a single prompt string"""
+     prompt_parts = []
+
+     if system:
+         prompt_parts.append(f"System: {system}\n")
+
+     for msg in messages:
+         role = msg.role if hasattr(msg, 'role') else msg.get('role', 'user')
+         content = msg.content if hasattr(msg, 'content') else msg.get('content', '')
+
+         # Handle content that might be a list of blocks
+         if isinstance(content, list):
+             content = " ".join([block.text if hasattr(block, 'text') else block.get('text', '') for block in content])
+
+         if role == "user":
+             prompt_parts.append(f"Human: {content}\n")
+         elif role == "assistant":
+             prompt_parts.append(f"Assistant: {content}\n")
+
+     prompt_parts.append("Assistant:")
+     return "".join(prompt_parts)
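
Not part of the commit: a small worked example of what `format_messages_to_prompt` returns for a one-turn request, traced from the helper above.

```python
# Worked example: given the helper defined above, a single user turn plus a
# system prompt collapses to one "Human:/Assistant:" framed string.
msgs = [{"role": "user", "content": "Hello, how are you?"}]
prompt = format_messages_to_prompt(msgs, system="Be brief")
# prompt == "System: Be brief\nHuman: Hello, how are you?\nAssistant:"
# Note: distilgpt2 is a plain causal LM, so it continues this prompt rather than
# following it as chat instructions; replies are best-effort continuations.
```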
+ # ============== Anthropic API Endpoints ==============
+
+ @app.post("/v1/messages", response_model=AnthropicResponse)
+ async def create_message(
+     request: AnthropicRequest,
+     x_api_key: Optional[str] = Header(None, alias="x-api-key"),
+     authorization: Optional[str] = Header(None)
+ ):
+     """
+     Anthropic Messages API compatible endpoint
+
+     POST /v1/messages
+     """
+     try:
+         # Format messages to prompt
+         prompt = format_messages_to_prompt(request.messages, request.system)
+
+         # Generate response
+         generated_text, input_tokens, output_tokens = generate_text(
+             prompt=prompt,
+             max_tokens=request.max_tokens,
+             temperature=request.temperature or 0.7,
+             top_p=request.top_p or 1.0
+         )
+
+         return AnthropicResponse(
+             id=f"msg_{uuid.uuid4().hex[:24]}",
+             content=[ContentBlock(type="text", text=generated_text)],
+             model=GENERATOR_MODEL,
+             stop_reason="end_turn",
+             usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens)
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
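
Not part of the commit: a sketch of the response shape implied by the `AnthropicResponse` and `Usage` models above; the values are placeholders, not output captured from the Space.

```python
# Sketch only: approximate JSON body returned by POST /v1/messages, inferred from
# the Pydantic models above. All values below are illustrative placeholders.
example_response = {
    "id": "msg_<24 hex chars>",
    "type": "message",
    "role": "assistant",
    "content": [{"type": "text", "text": "<generated continuation>"}],
    "model": "distilgpt2",          # GENERATOR_MODEL at runtime
    "stop_reason": "end_turn",
    "stop_sequence": None,
    "usage": {"input_tokens": 12, "output_tokens": 40},
}
```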
+ # ============== OpenAI Compatible Endpoints ==============
+
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+ async def chat_completions(
+     request: ChatCompletionRequest,
+     authorization: Optional[str] = Header(None)
+ ):
+     """
+     OpenAI Chat Completions API compatible endpoint
+
+     POST /v1/chat/completions
+     """
+     try:
+         # Format messages to prompt
+         prompt = format_messages_to_prompt(request.messages)
+
+         # Generate response
+         generated_text, input_tokens, output_tokens = generate_text(
+             prompt=prompt,
+             max_tokens=request.max_tokens or 1024,
+             temperature=request.temperature or 0.7,
+             top_p=request.top_p or 1.0
+         )
+
+         return ChatCompletionResponse(
+             id=f"chatcmpl-{uuid.uuid4().hex[:24]}",
+             created=int(time.time()),
+             model=GENERATOR_MODEL,
+             choices=[
+                 ChatChoice(
+                     index=0,
+                     message=ChatMessage(role="assistant", content=generated_text),
+                     finish_reason="stop"
+                 )
+             ],
+             usage={
+                 "prompt_tokens": input_tokens,
+                 "completion_tokens": output_tokens,
+                 "total_tokens": input_tokens + output_tokens
+             }
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.get("/v1/models", response_model=ModelsResponse)
+ async def list_models():
+     """List available models (OpenAI compatible)"""
+     return ModelsResponse(
+         data=[
+             ModelInfo(id=GENERATOR_MODEL, created=int(time.time())),
+             ModelInfo(id=MODEL_NAME, created=int(time.time())),
+             ModelInfo(id=EMBED_MODEL, created=int(time.time()))
+         ]
+     )
+
+
+ # ============== Original Endpoints ==============

  @app.get("/")
  async def root():
      """Welcome endpoint"""
      return {
+         "message": "Docker Model Runner API (Anthropic Compatible)",
          "hardware": "CPU Basic: 2 vCPU · 16 GB RAM",
          "docs": "/docs",
+         "api_endpoints": {
+             "anthropic": "/v1/messages",
+             "openai": "/v1/chat/completions",
+             "models": "/v1/models"
+         },
+         "utility_endpoints": ["/health", "/info", "/predict", "/embed"]
      }

      )


+ @app.get("/info")
  async def info():
      """Model and API information"""
+     return {
+         "name": "Docker Model Runner",
+         "version": "1.0.0",
+         "api_compatibility": ["anthropic", "openai"],
+         "hardware": "CPU Basic: 2 vCPU · 16 GB RAM",
+         "models": {
+             "chat": GENERATOR_MODEL,
              "classifier": MODEL_NAME,
              "embedder": EMBED_MODEL
          },
+         "endpoints": {
+             "anthropic_messages": "POST /v1/messages",
+             "openai_chat": "POST /v1/chat/completions",
+             "models": "GET /v1/models",
+             "predict": "POST /predict",
+             "embed": "POST /embed"
+         }
+     }


  @app.post("/predict", response_model=PredictResponse)
  async def predict(request: PredictRequest):
+     """Text classification (sentiment analysis)"""
      try:
          start_time = datetime.now()
          results = models["classifier"](request.text, top_k=request.top_k)

          raise HTTPException(status_code=500, detail=str(e))


  @app.post("/embed", response_model=EmbedResponse)
  async def embed(request: EmbedRequest):
+     """Get text embeddings"""
      try:
          start_time = datetime.now()

+         inputs = models["embed_tokenizer"](
              request.texts,
              padding=True,
              truncation=True,

              return_tensors="pt"
          )

          with torch.no_grad():
+             outputs = models["embed_model"](**inputs)
              attention_mask = inputs["attention_mask"]
              token_embeddings = outputs.last_hidden_state
              input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
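
The final hunk is cut off here. Not part of the commit: a hedged sketch of the standard masked mean-pooling step these lines appear to set up; variable roles follow the code above, and the commit's actual continuation is not shown.

```python
# Sketch of masked mean pooling: sum of masked token embeddings divided by the
# mask count, with clamp guarding against division by zero.
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts  # one vector per input text (e.g. 384-dim for MiniLM)
```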