# app.py
import io, os, threading
from typing import Optional, Any, Dict

from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

# ── CORS ───────────────────────────────────────────────────────────────────────
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
ALLOW_CREDENTIALS = not (len(ALLOWED_ORIGINS) == 1 and ALLOWED_ORIGINS[0] == "*")

app = FastAPI(title="Image Captioning API (BLIP2 / CogVLM)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Simple Bearer auth via a secret token in the Authorization header ──────────
API_TOKEN = os.environ.get("VITE_API_TOKEN")  # set in HF Space Settings → Secrets

def check_auth(auth_header: Optional[str]):
    if API_TOKEN:
        if not auth_header or not auth_header.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Missing Bearer token")
        token = auth_header.split(" ", 1)[1]
        if token != API_TOKEN:
            raise HTTPException(status_code=403, detail="Invalid token")

# ── Model registry (thread-safe lazy loading) ─────────────────────────────────
MODELS: Dict[str, Any] = {"blip2": None, "cogvlm": None}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache/huggingface")

# Thread-safe locks for model loading
_loading_locks = {"blip2": threading.Lock(), "cogvlm": threading.Lock()}

def remove_prompt_from_output(full_text: str, original_prompt: str) -> str:
    """Remove the input prompt from the model output, trying progressively looser matches."""
    if not original_prompt:
        return full_text

    # Try an exact prefix match
    if full_text.startswith(original_prompt):
        return full_text[len(original_prompt):].strip()

    # Try a case-insensitive prefix match
    if full_text.lower().startswith(original_prompt.lower()):
        return full_text[len(original_prompt):].strip()

    # Try again with normalized whitespace
    prompt_normalized = " ".join(original_prompt.split())
    text_normalized = " ".join(full_text.split())
    if text_normalized.startswith(prompt_normalized):
        # Drop the prompt's words from the original text
        words_to_remove = len(original_prompt.split())
        remaining_words = full_text.split()[words_to_remove:]
        return " ".join(remaining_words).strip()

    # Fallback: return the full text
    return full_text

def load_blip2():
    try:
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        name = os.environ.get("BLIP2_NAME", "Salesforce/blip2-opt-2.7b")
        processor = Blip2Processor.from_pretrained(name, cache_dir=CACHE_DIR)
        model = Blip2ForConditionalGeneration.from_pretrained(
            name,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            cache_dir=CACHE_DIR
        ).to(DEVICE)
        return {"name": name, "processor": processor, "model": model}
    except Exception as e:
        # Clean up on failure
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load BLIP2: {str(e)}")

def caption_blip2(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Thread-safe lazy model loading (double-checked locking)
    if MODELS["blip2"] is None:
        with _loading_locks["blip2"]:
            if MODELS["blip2"] is None:  # double-check inside the lock
                MODELS["blip2"] = load_blip2()

    entry = MODELS["blip2"]
    processor = entry["processor"]
    model = entry["model"]
    text = prompt or "Describe this image."
    try:
        inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # BLIP2 returns the full sequence including the input prompt
        full_text = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Strip the input prompt from the result
        caption = remove_prompt_from_output(full_text, text)
        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"BLIP2 generation failed: {str(e)}")

def load_cogvlm():
    try:
        from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
        name = os.environ.get("COGVLM_NAME", "THUDM/cogvlm2-llama3-captioner")
        processor = AutoProcessor.from_pretrained(name, trust_remote_code=True, cache_dir=CACHE_DIR)
        tokenizer = AutoTokenizer.from_pretrained(
            name,
            trust_remote_code=True,
            use_fast=False,  # CogVLM can have issues with the fast tokenizer
            cache_dir=CACHE_DIR
        )
        # Consistent device handling: use either DEVICE or device_map="auto", not both
        if DEVICE == "cuda" and torch.cuda.device_count() > 1:
            # Multi-GPU setup
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            )
        else:
            # Single-GPU / CPU setup
            model = AutoModelForCausalLM.from_pretrained(
                name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            ).to(DEVICE)
        return {"name": name, "processor": processor, "tokenizer": tokenizer, "model": model}
    except Exception as e:
        # Clean up on failure
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        raise HTTPException(500, detail=f"Failed to load CogVLM: {str(e)}")

def caption_cogvlm(image: Image.Image, prompt: Optional[str], max_new_tokens: int):
    # Thread-safe lazy model loading (double-checked locking)
    if MODELS["cogvlm"] is None:
        with _loading_locks["cogvlm"]:
            if MODELS["cogvlm"] is None:  # double-check inside the lock
                MODELS["cogvlm"] = load_cogvlm()

    entry = MODELS["cogvlm"]
    processor = entry["processor"]
    tokenizer = entry["tokenizer"]
    model = entry["model"]
    text = prompt or "Describe this image."
    try:
        # Use a consistent target device (device_map="auto" may place the model itself)
        target_device = model.device if hasattr(model, 'device') else DEVICE
        inputs = processor(images=image, text=text, return_tensors="pt").to(target_device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # CogVLM can also return the full sequence including the prompt
        full_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
        # Strip the input prompt from the result
        caption = remove_prompt_from_output(full_text, text)
        return caption if caption else full_text
    except Exception as e:
        raise HTTPException(500, detail=f"CogVLM generation failed: {str(e)}")

# ── Routes ──────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {
        "message": "Image Captioning API (BLIP2 / CogVLM)",
        "endpoints": ["/health", "/caption"],
        "device": DEVICE,
        "models": list(MODELS.keys()),
        "cuda_available": torch.cuda.is_available(),
        "cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }

@app.get("/health")
def health():
    return {
        "status": "ok",
        "device": DEVICE,
        "cuda": torch.cuda.is_available(),
        "loaded_models": [k for k, v in MODELS.items() if v is not None]
    }

@app.get("/caption")
def caption_info():
    return {
        "method": "POST",
        "description": "Upload image and get caption",
        "parameters": {
            "file": "image file (required)",
            "model": "blip2 or cogvlm (default: blip2)",
            "prompt": "custom prompt (optional)",
            "max_new_tokens": "max tokens to generate (default: 128)"
        },
        "auth": "Bearer token in Authorization header (if API_TOKEN is set)"
    }

@app.post("/caption")
async def caption(
    file: UploadFile = File(...),
    model: str = Form("blip2"),
    prompt: Optional[str] = Form(None),
    max_new_tokens: int = Form(128),
    authorization: Optional[str] = Header(None)
):
    check_auth(authorization)

    if model not in ("blip2", "cogvlm"):
        raise HTTPException(400, detail="model must be 'blip2' or 'cogvlm'")
    if max_new_tokens < 1 or max_new_tokens > 512:
        raise HTTPException(400, detail="max_new_tokens must be between 1 and 512")

    try:
        content = await file.read()
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, detail=f"Invalid image file: {str(e)}")

    try:
        if model == "blip2":
            caption_text = caption_blip2(image, prompt, max_new_tokens)
        else:
            caption_text = caption_cogvlm(image, prompt, max_new_tokens)
        return JSONResponse({
            "model": model,
            "caption": caption_text,
            "prompt_used": prompt or "Describe this image.",
            "max_new_tokens": max_new_tokens
        })
    except HTTPException:
        # Re-raise HTTP exceptions (they already carry the correct status code)
        raise
    except Exception as e:
        # Catch-all for unexpected errors
        raise HTTPException(500, detail=f"Caption generation failed: {str(e)}")
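
# ── Local run (optional) ────────────────────────────────────────────────────────
# Minimal sketch for running the app directly; assumes `uvicorn` is installed and
# that port 7860 (the Hugging Face Spaces default) is acceptable — the PORT
# environment variable and the example request below are illustrative assumptions.
#
# Example request (hypothetical token and file name):
#   curl -X POST http://localhost:7860/caption \
#        -H "Authorization: Bearer <your-token>" \
#        -F "file=@example.jpg" -F "model=blip2" -F "max_new_tokens=128"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))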