import json
import os
import re

from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from llmlingua import PromptCompressor

# ---- Force CPU (avoid CUDA on CPU-only hosts)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

# ---- Config via env (tweak without code changes)
FALLBACK_MODEL = "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank"
MODEL_NAME = os.environ.get("LLMLINGUA_MODEL", FALLBACK_MODEL)
API_KEY = os.environ.get("LLMLINGUA_API_KEY")  # optional

# For /privacy (override via env if you like)
SERVICE_NAME = os.environ.get("SERVICE_NAME", "llmlingua-gpts")
SERVICE_OWNER = os.environ.get("SERVICE_OWNER", "clancylin")
PRIVACY_EFFECTIVE = os.environ.get("PRIVACY_EFFECTIVE", "2025-11-02")

app = FastAPI(title="LLMLingua Wrapper", version="1.0.0")


def _build_compressor(model_name: str) -> PromptCompressor:
    return PromptCompressor(
        model_name=model_name,
        use_llmlingua2=True,
        device_map="cpu",
        model_config={"low_cpu_mem_usage": True},
    )


# Try the desired model; fall back to a public multilingual one if it fails
_loaded_model = MODEL_NAME
try:
    compressor = _build_compressor(MODEL_NAME)
except Exception:
    _loaded_model = FALLBACK_MODEL
    compressor = _build_compressor(FALLBACK_MODEL)


# ---- Schemas
class CompressOut(BaseModel):
    compressed_text: str
    origin_tokens: int | None = None
    compressed_tokens: int | None = None
    ratio: str | None = None
    rate_used: float | None = None


# ---- Optional API key check
def verify(x_api_key: str | None = None):
    if API_KEY and x_api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")


# ---- Body parsing: accept JSON / text / form-data / urlencoded
def _coerce_numbers(d: dict) -> dict:
    if "rate" in d and isinstance(d["rate"], str):
        try:
            d["rate"] = float(d["rate"])
        except ValueError:
            pass
    if "target_tokens" in d and isinstance(d["target_tokens"], str):
        try:
            d["target_tokens"] = int(d["target_tokens"])
        except ValueError:
            pass
    if "force_tokens" in d and isinstance(d["force_tokens"], str):
        try:
            d["force_tokens"] = json.loads(d["force_tokens"])
        except json.JSONDecodeError:
            pass
    return d


async def _read_any_body(request: Request) -> dict:
    ct = (request.headers.get("content-type") or "").lower()
    if "application/json" in ct:
        raw = await request.body()
        if not raw:
            return {}
        try:
            return _coerce_numbers(json.loads(raw))
        except Exception:
            # Fallback: treat the whole body as plain text
            return {"text": raw.decode("utf-8", "ignore")}
    if "text/plain" in ct or not ct:
        return {"text": (await request.body()).decode("utf-8", "ignore")}
    if "multipart/form-data" in ct or "application/x-www-form-urlencoded" in ct:
        form = await request.form()
        data = {k: v for k, v in form.items()}
        return _coerce_numbers(data)
    return {"text": (await request.body()).decode("utf-8", "ignore")}


# ---- Heuristic rate when none is provided
def _auto_rate(text: str, target_tokens: int | None) -> float:
    n = len(compressor.tokenizer.tokenize(text))
    has_code = ("```" in text) or (re.search(r"[{}\[\]]", text) is not None)
    if target_tokens:
        return float(min(0.95, max(0.1, target_tokens / max(1, n))))
    if has_code:
        return 0.7 if n >= 1200 else 0.6
    if n >= 2000:
        return 0.4
    if n >= 1200:
        return 0.5
    return 0.6


# ---- Routes
@app.get("/")
def root():
    return {
        "status": "ok",
        "service": SERVICE_NAME,
        "owner": SERVICE_OWNER,
        "requested_model": MODEL_NAME,
        "loaded_model": _loaded_model,
        "endpoints": ["/compress", "/healthz", "/privacy"],
    }


@app.get("/healthz")
def healthz():
    return {"ok": True}
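
# The /compress endpoint below accepts several body encodings. A few example
# calls (a sketch: assumes the service is running locally on the default port
# 7860; the prompt text and parameter values are placeholders):
#
#   curl -s -X POST http://localhost:7860/compress \
#        -H 'Content-Type: application/json' \
#        -d '{"text": "very long prompt ...", "rate": 0.5}'
#
#   curl -s -X POST http://localhost:7860/compress \
#        -H 'Content-Type: text/plain' --data-binary @prompt.txt
#
#   curl -s -X POST http://localhost:7860/compress \
#        -F 'text=very long prompt ...' -F 'target_tokens=200'
#
# If LLMLINGUA_API_KEY is set, also pass: -H 'X-API-Key: <your key>'
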
@app.post("/compress", response_model=CompressOut) async def compress(request: Request, x_api_key: str | None = Header(default=None)): verify(x_api_key) # remove if you don't use API keys body = await _read_any_body(request) text = body.get("text") if not isinstance(text, str) or not text.strip(): raise HTTPException(status_code=422, detail="`text` is required (JSON, text/plain, or form-data).") rate = body.get("rate", None) target_tokens = body.get("target_tokens", None) force_tokens = body.get("force_tokens", None) kw = {} if isinstance(target_tokens, int) and target_tokens > 0: kw["target_token"] = target_tokens if isinstance(force_tokens, list): kw["force_tokens"] = force_tokens rate_used = float(rate) if rate is not None else _auto_rate(text, target_tokens if isinstance(target_tokens, int) else None) out = compressor.compress_prompt(text, rate=rate_used, **kw) raw = out.get("compressed_prompt", "") or out.get("compressed_text", "") # Detokenize to avoid "char + space" artifacts try: toks = compressor.tokenizer.tokenize(raw) comp_text = compressor.tokenizer.convert_tokens_to_string(toks) comp_text = re.sub(r"[^\S\r\n]+([,.;:!?])", r"\1", comp_text).strip() except Exception: comp_text = re.sub(r"[^\S\r\n]+", " ", raw or "").strip() origin_tokens = len(compressor.tokenizer.tokenize(text)) or 1 compressed_tokens = len(compressor.tokenizer.tokenize(comp_text)) ratio = f"{compressed_tokens / origin_tokens:.2f}x" return { "compressed_text": comp_text, "origin_tokens": origin_tokens, "compressed_tokens": compressed_tokens, "ratio": ratio, "rate_used": rate_used, } @app.get("/privacy", response_class=HTMLResponse) def privacy(): return """ Privacy Policy

@app.get("/privacy", response_class=HTMLResponse)
def privacy():
    return f"""<!doctype html>
<html>
  <head><title>Privacy Policy</title></head>
  <body>
    <h1>Privacy Policy</h1>
    <p>Service: {SERVICE_NAME} (operated by {SERVICE_OWNER}). Effective: {PRIVACY_EFFECTIVE}.</p>
    <p>This service compresses text using LLMLingua models hosted on Hugging Face Spaces.</p>
    <h2>What we process</h2>
    <h2>How we use data</h2>
    <h2>Retention</h2>
    <h2>Third parties</h2>
    <h2>Security</h2>
    <h2>Contact</h2>
  </body>
</html>
""" if __name__ == "__main__": import uvicorn uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)