"""
Docker Model Runner - FastAPI application with named endpoints.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
import torch
from transformers import pipeline, AutoTokenizer, AutoModel
import os
from datetime import datetime, timezone

app = FastAPI(
    title="Docker Model Runner",
    description="HuggingFace Space with named endpoints for model inference",
    version="1.0.0"
)

MODEL_NAME = os.getenv("MODEL_NAME", "distilbert-base-uncased")
GENERATOR_MODEL = os.getenv("GENERATOR_MODEL", "gpt2")
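
# Both model names can be overridden at container start. A minimal sketch,
# assuming the image has been built locally under the hypothetical tag
# docker-model-runner:
#
#     docker run -p 7860:7860 \
#         -e MODEL_NAME=distilbert-base-uncased-finetuned-sst-2-english \
#         -e GENERATOR_MODEL=gpt2 \
#         docker-model-runner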

# Pipelines and models are loaded lazily and cached in module-level globals,
# so each model is loaded at most once per process.
_classifier = None
_generator = None
_embedder = None


def get_classifier():
    global _classifier
    if _classifier is None:
        _classifier = pipeline("text-classification", model=MODEL_NAME)
    return _classifier


def get_generator():
    global _generator
    if _generator is None:
        _generator = pipeline("text-generation", model=GENERATOR_MODEL)
    return _generator


def get_embedder():
    global _embedder
    if _embedder is None:
        _embedder = {
            "tokenizer": AutoTokenizer.from_pretrained(MODEL_NAME),
            "model": AutoModel.from_pretrained(MODEL_NAME)
        }
    return _embedder
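
# The lazy loaders above mean the first request to each endpoint pays the
# model download/load cost. If that cold start is undesirable, the models
# could be warmed when the app boots -- a minimal sketch, assuming FastAPI's
# startup event hook is acceptable here (not part of the original app):
#
#     @app.on_event("startup")
#     async def warm_models():
#         get_classifier()
#         get_generator()
#         get_embedder()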

class PredictRequest(BaseModel):
    text: str
    top_k: Optional[int] = 1


class PredictResponse(BaseModel):
    predictions: List[dict]
    model: str
    latency_ms: float


class GenerateRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0


class GenerateResponse(BaseModel):
    generated_text: List[str]
    model: str
    latency_ms: float


class EmbedRequest(BaseModel):
    texts: List[str]


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimensions: int
    latency_ms: float


class HealthResponse(BaseModel):
    status: str
    timestamp: str
    gpu_available: bool


class InfoResponse(BaseModel):
    name: str
    version: str
    models: dict
    endpoints: List[str]
@app.get("/") |
|
|
async def root(): |
|
|
"""Welcome endpoint""" |
|
|
return { |
|
|
"message": "Docker Model Runner API", |
|
|
"docs": "/docs", |
|
|
"endpoints": ["/health", "/info", "/predict", "/generate", "/embed"] |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse) |
|
|
async def health(): |
|
|
"""Health check endpoint""" |
|
|
return HealthResponse( |
|
|
status="healthy", |
|
|
timestamp=datetime.utcnow().isoformat(), |
|
|
gpu_available=torch.cuda.is_available() |
|
|
) |
|
|
|
|
|
|
|
|
@app.get("/info", response_model=InfoResponse) |
|
|
async def info(): |
|
|
"""Model and API information""" |
|
|
return InfoResponse( |
|
|
name="Docker Model Runner", |
|
|
version="1.0.0", |
|
|
models={ |
|
|
"classifier": MODEL_NAME, |
|
|
"generator": GENERATOR_MODEL, |
|
|
"embedder": MODEL_NAME |
|
|
}, |
|
|
endpoints=["/", "/health", "/info", "/predict", "/generate", "/embed"] |
|
|
) |
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictResponse) |
|
|
async def predict(request: PredictRequest): |
|
|
""" |
|
|
Run text classification prediction |
|
|
|
|
|
- **text**: Input text to classify |
|
|
- **top_k**: Number of top predictions to return |
|
|
""" |
|
|
try: |
|
|
start_time = datetime.now() |
|
|
classifier = get_classifier() |
|
|
results = classifier(request.text, top_k=request.top_k) |
|
|
latency = (datetime.now() - start_time).total_seconds() * 1000 |
|
|
|
|
|
return PredictResponse( |
|
|
predictions=results, |
|
|
model=MODEL_NAME, |
|
|
latency_ms=round(latency, 2) |
|
|
) |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
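
# Example request against a locally running instance (port 7860 matches the
# uvicorn call at the bottom of this file); the text and top_k values are
# illustrative only:
#
#     curl -X POST http://localhost:7860/predict \
#         -H "Content-Type: application/json" \
#         -d '{"text": "This space is great!", "top_k": 2}'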

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """
    Generate text from a prompt

    - **prompt**: Input prompt for generation
    - **max_length**: Maximum length of generated text (prompt tokens included)
    - **num_return_sequences**: Number of sequences to generate
    - **temperature**: Sampling temperature
    """
    try:
        start_time = datetime.now()
        generator = get_generator()
        results = generator(
            request.prompt,
            max_length=request.max_length,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=True
        )
        latency = (datetime.now() - start_time).total_seconds() * 1000

        generated_texts = [r["generated_text"] for r in results]

        return GenerateResponse(
            generated_text=generated_texts,
            model=GENERATOR_MODEL,
            latency_ms=round(latency, 2)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
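
# Example request; prompt and sampling parameters are illustrative:
#
#     curl -X POST http://localhost:7860/generate \
#         -H "Content-Type: application/json" \
#         -d '{"prompt": "Once upon a time", "max_length": 40, "temperature": 0.8}'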

@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest):
    """
    Get text embeddings

    - **texts**: List of texts to embed
    """
    try:
        start_time = datetime.now()
        embedder = get_embedder()

        inputs = embedder["tokenizer"](
            request.texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        with torch.no_grad():
            outputs = embedder["model"](**inputs)

        # Mean-pool token embeddings, masking out padding so that texts of
        # different lengths in the same batch are pooled over real tokens only
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        embeddings = summed / mask.sum(dim=1).clamp(min=1e-9)

        latency = (datetime.now() - start_time).total_seconds() * 1000

        return EmbedResponse(
            embeddings=embeddings.tolist(),
            model=MODEL_NAME,
            dimensions=embeddings.shape[1],
            latency_ms=round(latency, 2)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
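
# Example request; the response contains one vector per input text with
# `dimensions` entries (768 for the default distilbert-base-uncased):
#
#     curl -X POST http://localhost:7860/embed \
#         -H "Content-Type: application/json" \
#         -d '{"texts": ["hello world", "docker model runner"]}'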

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
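
# Equivalent CLI invocation (assuming this module is saved as app.py, which is
# not stated in the file itself):
#
#     uvicorn app:app --host 0.0.0.0 --port 7860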