import os import re import json from fastapi import FastAPI, Header, HTTPException, Request from fastapi.responses import HTMLResponse from pydantic import BaseModel from llmlingua import PromptCompressor # ---- Force CPU (avoid CUDA on CPU-only hosts) os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") # ---- Config via env (tweak without code changes) FALLBACK_MODEL = "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank" MODEL_NAME = os.environ.get("LLMLINGUA_MODEL", FALLBACK_MODEL) API_KEY = os.environ.get("LLMLINGUA_API_KEY") # optional # For /privacy (edit by env if you like) SERVICE_NAME = os.environ.get("SERVICE_NAME", "llmlingua-gpts") SERVICE_OWNER = os.environ.get("SERVICE_OWNER", "clancylin") PRIVACY_EFFECTIVE = os.environ.get("PRIVACY_EFFECTIVE", "2025-11-02") app = FastAPI(title="LLMLingua Wrapper", version="1.0.0") def _build_compressor(model_name: str) -> PromptCompressor: return PromptCompressor( model_name=model_name, use_llmlingua2=True, device_map="cpu", model_config={"low_cpu_mem_usage": True}, ) # Try desired model; fall back to a public multilingual one if it fails _loaded_model = MODEL_NAME try: compressor = _build_compressor(MODEL_NAME) except Exception: _loaded_model = FALLBACK_MODEL compressor = _build_compressor(FALLBACK_MODEL) # ---- Schemas class CompressOut(BaseModel): compressed_text: str origin_tokens: int | None = None compressed_tokens: int | None = None ratio: str | None = None rate_used: float | None = None # ---- Optional API key check def verify(x_api_key: str | None = None): if API_KEY and x_api_key != API_KEY: raise HTTPException(status_code=401, detail="Invalid API key") # ---- Body parsing: accept JSON / text / form-data / urlencoded def _coerce_numbers(d: dict) -> dict: if "rate" in d and isinstance(d["rate"], str): try: d["rate"] = float(d["rate"]) except: pass if "target_tokens" in d and isinstance(d["target_tokens"], str): try: d["target_tokens"] = int(d["target_tokens"]) except: pass if "force_tokens" in d and isinstance(d["force_tokens"], str): try: d["force_tokens"] = json.loads(d["force_tokens"]) except: pass return d async def _read_any_body(request: Request) -> dict: ct = (request.headers.get("content-type") or "").lower() if "application/json" in ct: raw = await request.body() if not raw: return {} try: return _coerce_numbers(json.loads(raw)) except Exception: # Fallback: treat whole body as plain text return {"text": raw.decode("utf-8", "ignore")} if "text/plain" in ct or not ct: return {"text": (await request.body()).decode("utf-8", "ignore")} if "multipart/form-data" in ct or "application/x-www-form-urlencoded" in ct: form = await request.form() data = {k: v for k, v in form.items()} return _coerce_numbers(data) return {"text": (await request.body()).decode("utf-8", "ignore")} # ---- Heuristic rate when not provided def _auto_rate(text: str, target_tokens: int | None) -> float: n = len(compressor.tokenizer.tokenize(text)) has_code = ("```" in text) or (re.search(r"[{}\[\]]", text) is not None) if target_tokens: return float(min(0.95, max(0.1, target_tokens / max(1, n)))) if has_code: return 0.7 if n >= 1200 else 0.6 if n >= 2000: return 0.4 if n >= 1200: return 0.5 return 0.6 # ---- Routes @app.get("/") def root(): return { "status": "ok", "service": SERVICE_NAME, "owner": SERVICE_OWNER, "requested_model": MODEL_NAME, "loaded_model": _loaded_model, "endpoints": ["/compress", "/healthz", "/privacy"], } @app.get("/healthz") def healthz(): return {"ok": True} @app.post("/compress", response_model=CompressOut) async def compress(request: Request, x_api_key: str | None = Header(default=None)): verify(x_api_key) # remove if you don't use API keys body = await _read_any_body(request) text = body.get("text") if not isinstance(text, str) or not text.strip(): raise HTTPException(status_code=422, detail="`text` is required (JSON, text/plain, or form-data).") rate = body.get("rate", None) target_tokens = body.get("target_tokens", None) force_tokens = body.get("force_tokens", None) kw = {} if isinstance(target_tokens, int) and target_tokens > 0: kw["target_token"] = target_tokens if isinstance(force_tokens, list): kw["force_tokens"] = force_tokens rate_used = float(rate) if rate is not None else _auto_rate(text, target_tokens if isinstance(target_tokens, int) else None) out = compressor.compress_prompt(text, rate=rate_used, **kw) raw = out.get("compressed_prompt", "") or out.get("compressed_text", "") # Detokenize to avoid "char + space" artifacts try: toks = compressor.tokenizer.tokenize(raw) comp_text = compressor.tokenizer.convert_tokens_to_string(toks) comp_text = re.sub(r"[^\S\r\n]+([,.;:!?])", r"\1", comp_text).strip() except Exception: comp_text = re.sub(r"[^\S\r\n]+", " ", raw or "").strip() origin_tokens = len(compressor.tokenizer.tokenize(text)) or 1 compressed_tokens = len(compressor.tokenizer.tokenize(comp_text)) ratio = f"{compressed_tokens / origin_tokens:.2f}x" return { "compressed_text": comp_text, "origin_tokens": origin_tokens, "compressed_tokens": compressed_tokens, "ratio": ratio, "rate_used": rate_used, } @app.get("/privacy", response_class=HTMLResponse) def privacy(): return """
This service compresses text using LLMLingua models hosted on Hugging Face Spaces.
text field and optional parameters (rate, etc.)./compress), and error traces for reliability.