"""
Docker Model Runner - FastAPI application with named endpoints
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
import torch
from transformers import pipeline, AutoTokenizer, AutoModel
import os
from datetime import datetime
app = FastAPI(
    title="Docker Model Runner",
    description="HuggingFace Space with named endpoints for model inference",
    version="1.0.0"
)

# Model configurations
MODEL_NAME = os.getenv("MODEL_NAME", "distilbert-base-uncased")
GENERATOR_MODEL = os.getenv("GENERATOR_MODEL", "gpt2")
# Lazy-loaded pipelines
_classifier = None
_generator = None
_embedder = None
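
# Each get_* helper below loads its model on first use rather than at import
# time, so the app starts quickly and only pays the load cost for the
# endpoints that are actually called.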
def get_classifier():
    global _classifier
    if _classifier is None:
        _classifier = pipeline("text-classification", model=MODEL_NAME)
    return _classifier

def get_generator():
    global _generator
    if _generator is None:
        _generator = pipeline("text-generation", model=GENERATOR_MODEL)
    return _generator

def get_embedder():
    global _embedder
    if _embedder is None:
        _embedder = {
            "tokenizer": AutoTokenizer.from_pretrained(MODEL_NAME),
            "model": AutoModel.from_pretrained(MODEL_NAME)
        }
    return _embedder

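# Note: get_embedder() returns a raw tokenizer/model pair rather than a
# pipeline; the /embed endpoint below runs the encoder and mean-pools the
# hidden states itself.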
# Request/Response Models
class PredictRequest(BaseModel):
    text: str
    top_k: Optional[int] = 1


class PredictResponse(BaseModel):
    predictions: List[dict]
    model: str
    latency_ms: float


class GenerateRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 50
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 1.0


class GenerateResponse(BaseModel):
    generated_text: List[str]
    model: str
    latency_ms: float


class EmbedRequest(BaseModel):
    texts: List[str]


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimensions: int
    latency_ms: float


class HealthResponse(BaseModel):
    status: str
    timestamp: str
    gpu_available: bool


class InfoResponse(BaseModel):
    name: str
    version: str
    models: dict
    endpoints: List[str]

# Named Endpoints
@app.get("/")
async def root():
"""Welcome endpoint"""
return {
"message": "Docker Model Runner API",
"docs": "/docs",
"endpoints": ["/health", "/info", "/predict", "/generate", "/embed"]
}
@app.get("/health", response_model=HealthResponse)
async def health():
"""Health check endpoint"""
return HealthResponse(
status="healthy",
timestamp=datetime.utcnow().isoformat(),
gpu_available=torch.cuda.is_available()
)
@app.get("/info", response_model=InfoResponse)
async def info():
"""Model and API information"""
return InfoResponse(
name="Docker Model Runner",
version="1.0.0",
models={
"classifier": MODEL_NAME,
"generator": GENERATOR_MODEL,
"embedder": MODEL_NAME
},
endpoints=["/", "/health", "/info", "/predict", "/generate", "/embed"]
)
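# Illustrative call (field names come from PredictRequest; host/port assume a
# local run of this file):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "This movie was great!", "top_k": 2}'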
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
"""
Run text classification prediction
- **text**: Input text to classify
- **top_k**: Number of top predictions to return
"""
try:
start_time = datetime.now()
classifier = get_classifier()
results = classifier(request.text, top_k=request.top_k)
latency = (datetime.now() - start_time).total_seconds() * 1000
return PredictResponse(
predictions=results,
model=MODEL_NAME,
latency_ms=round(latency, 2)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
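# Illustrative call (field names come from GenerateRequest):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_length": 50, "temperature": 0.8}'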
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""
Generate text from a prompt
- **prompt**: Input prompt for generation
- **max_length**: Maximum length of generated text
- **num_return_sequences**: Number of sequences to generate
- **temperature**: Sampling temperature
"""
try:
start_time = datetime.now()
generator = get_generator()
results = generator(
request.prompt,
max_length=request.max_length,
num_return_sequences=request.num_return_sequences,
temperature=request.temperature,
do_sample=True
)
latency = (datetime.now() - start_time).total_seconds() * 1000
generated_texts = [r["generated_text"] for r in results]
return GenerateResponse(
generated_text=generated_texts,
model=GENERATOR_MODEL,
latency_ms=round(latency, 2)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
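# Illustrative call (field names come from EmbedRequest):
#   curl -X POST http://localhost:7860/embed \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["first sentence", "second sentence"]}'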
@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest):
"""
Get text embeddings
- **texts**: List of texts to embed
"""
try:
start_time = datetime.now()
embedder = get_embedder()
# Tokenize and get embeddings
inputs = embedder["tokenizer"](
request.texts,
padding=True,
truncation=True,
return_tensors="pt"
)
with torch.no_grad():
outputs = embedder["model"](**inputs)
# Use mean pooling
embeddings = outputs.last_hidden_state.mean(dim=1)
latency = (datetime.now() - start_time).total_seconds() * 1000
return EmbedResponse(
embeddings=embeddings.tolist(),
model=MODEL_NAME,
dimensions=embeddings.shape[1],
latency_ms=round(latency, 2)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
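
# Port 7860 is the default port a HuggingFace Space expects the app to listen
# on. Illustrative check once the server is running:
#   curl http://localhost:7860/health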