Spaces:
Paused
Paused
| """ | |
| FastAPI REST API for Foundation 1.2 Clinical Trial System | |
| Production-ready Docker space with proper REST endpoints | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import time | |
| import logging | |
| # Import the foundation engine | |
| import foundation_engine | |
# Configure root logging at INFO so the startup and per-request progress
# messages emitted by this module are visible, then grab a module logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# ASGI application object; title/description/version feed the generated
# OpenAPI schema shown at /docs and /redoc.
app = FastAPI(
    title="Clinical Trial API",
    description="Production REST API for clinical trial analysis powered by Foundation 1.2 pipeline",
    # Kept in sync with the "version" reported by the root endpoint ("2.0.0");
    # the structured /search endpoint is documented as API v2.0.
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Add CORS middleware: any origin may call this public API.
# Credentials (cookies / HTTP auth) are NOT allowed cross-origin. FastAPI and
# Starlette document that allow_origins, allow_methods and allow_headers must
# not be ["*"] when allow_credentials=True, and none of the endpoints in this
# module read cookies or auth headers, so credentials add risk without benefit.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/Response models
# Request body for query_trials (the legacy summary endpoint).
class QueryRequest(BaseModel):
    # Free-text clinical-trial question, passed to foundation_engine.process_query.
    query: str
# Response body returned by query_trials.
class QueryResponse(BaseModel):
    # Result of foundation_engine.process_query for the submitted query.
    summary: str
    # Elapsed seconds measured around the engine call in query_trials.
    processing_time: float
# Request body for search_trials (the structured search endpoint).
class SearchRequest(BaseModel):
    # Free-text clinical-trial question.
    query: str
    # Number of trials to return. Not range-validated by the model itself;
    # search_trials caps values above 50 at 50 and replaces values below 1 with 10.
    top_k: int = 10
# Response body returned by health_check.
class HealthResponse(BaseModel):
    # Service status string reported to the caller.
    status: str
    # Number of entries in foundation_engine.doc_chunks — i.e. document chunks,
    # not necessarily distinct trials (NOTE(review): name kept for the schema).
    trials_loaded: int
    # True when foundation_engine.doc_embeddings is not None.
    embeddings_loaded: bool
@app.on_event("startup")
async def startup_event():
    """Initialize the foundation engine on startup.

    Calls ``foundation_engine.load_embeddings()`` once. A failure is logged
    (with traceback) but is deliberately non-fatal: the API still starts in
    "degraded mode", and queries are expected to fail until embeddings load.
    """
    logger.info("=== API Startup ===")
    logger.info("Loading Foundation 1.2 engine...")
    try:
        # Load the embeddings eagerly, before the first request is served.
        foundation_engine.load_embeddings()
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("!!! Failed to load embeddings: %s", e)
        logger.error("!!! API will start but queries will fail until embeddings are loaded")
        logger.info("=== API Ready - Degraded Mode ===")
    else:
        logger.info("=== API Ready - Embeddings Loaded ===")
@app.get("/")
async def root():
    """API information.

    Returns a fixed dictionary describing the service, its version, the
    available endpoints and the pipeline features. It does not inspect engine
    state; use ``GET /health`` for that.
    """
    return {
        "service": "Clinical Trial API",
        "version": "2.0.0",
        "description": "Production REST API for Foundation 1.2 with 355M perplexity ranking",
        "status": "healthy",
        "endpoints": {
            "POST /search": "[NEW] Search trials with structured JSON output (includes 355M ranking)",
            "POST /query": "Query clinical trials and get AI-generated summary (legacy)",
            "GET /health": "Health check",
            "GET /docs": "Interactive API documentation (Swagger UI)",
            "GET /redoc": "Alternative API documentation (ReDoc)",
        },
        "features": [
            "LLM Query Parser (entity extraction + synonyms)",
            "Hybrid RAG Search (BM25 + semantic + inverted index)",
            "355M Clinical Trial GPT perplexity-based ranking",
            "Structured JSON output",
            "Benchmarking metrics (before/after 355M scores)",
        ],
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports whether the engine's embeddings and document chunks are loaded.
    ``status`` is ``"healthy"`` when ``foundation_engine.doc_embeddings`` is
    set and ``"degraded"`` otherwise. The startup hook lets the API run without
    embeddings (queries then fail), so reporting "healthy" in that state would
    hide the problem.
    """
    embeddings_loaded = foundation_engine.doc_embeddings is not None
    chunks = foundation_engine.doc_chunks
    # Explicit None check instead of truthiness: bool() is ambiguous for
    # array-like containers (NOTE(review): the type of doc_chunks isn't visible here).
    chunks_loaded = len(chunks) if chunks is not None else 0
    return HealthResponse(
        status="healthy" if embeddings_loaded else "degraded",
        # Holds the chunk count; the field name is kept for schema compatibility.
        trials_loaded=chunks_loaded,
        embeddings_loaded=embeddings_loaded,
    )
@app.post("/query", response_model=QueryResponse)
async def query_trials(request: QueryRequest):
    """
    Query clinical trials and get AI-generated summary

    - **query**: Your question about clinical trials (e.g., "What trials exist for Dekavil?")

    Returns a structured medical analysis with:
    - Drug/Intervention background
    - Clinical trial results and data
    - Treatment considerations
    - NCT trial IDs and references

    Any failure is reported as HTTP 500 with the error message in `detail`.
    """
    # NOTE(review): process_query is a synchronous call inside an async
    # endpoint, so it blocks the event loop (and every other request,
    # including /health) for its whole duration. Offloading it to a thread
    # pool would fix that, but only if foundation_engine is safe to call
    # concurrently — confirm before changing.
    start_time = time.perf_counter()  # monotonic clock for measuring durations
    try:
        logger.info("API Query received: %s...", request.query[:100])
        result = foundation_engine.process_query(request.query)
        processing_time = time.perf_counter() - start_time
        logger.info("Query completed in %.2fs", processing_time)
        # Built inside the try so a response-validation error is also reported
        # as a 500 with detail, as before.
        return QueryResponse(
            summary=result,
            processing_time=processing_time,
        )
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}") from e
@app.post("/search")
async def search_trials(request: SearchRequest):
    """
    Search clinical trials and get structured JSON results (NEW API v2.0)

    This endpoint provides:
    - Query parsing with LLM (entity extraction + synonym expansion)
    - Hybrid RAG search (BM25 + semantic embeddings + inverted index)
    - 355M Clinical Trial GPT perplexity-based re-ranking
    - Structured JSON output with benchmarking data

    **No response generation** - returns raw trial data for client-side processing

    Args:
    - **query**: Your question about clinical trials
    - **top_k**: Number of trials to return (default: 10, max: 50).
      Values above 50 are capped at 50; values below 1 fall back to 10.

    Returns:
    - Structured JSON with trials ranked by clinical relevance
    - Includes before/after 355M ranking scores for benchmarking
    - Processing time breakdown (query parsing, RAG search, 355M ranking)

    Errors:
    - HTTP 500 with a JSON body containing `error`, `query` and `processing_time`.
      The traceback is written to the server log, not sent to the client.
    """
    from fastapi.responses import JSONResponse

    logger.info("[SEARCH API] Query received: %s...", request.query[:100])

    # Clamp into a local variable instead of mutating the validated request model.
    top_k = request.top_k
    if top_k > 50:
        logger.warning("[SEARCH API] top_k=%s exceeds maximum 50, capping", top_k)
        top_k = 50
    elif top_k < 1:
        logger.warning("[SEARCH API] top_k=%s is invalid, using default 10", top_k)
        top_k = 10

    # NOTE(review): process_query_structured is a synchronous call inside an
    # async endpoint and blocks the event loop while it runs; offloading it
    # needs confirmation that foundation_engine is safe to call concurrently.
    # Started before the try so the error path can always report elapsed time.
    start_time = time.perf_counter()
    try:
        result = foundation_engine.process_query_structured(request.query, top_k=top_k)
        processing_time = time.perf_counter() - start_time
        logger.info("[SEARCH API] Query completed in %.2fs", processing_time)
        # Fill in the total time only when the engine did not report one
        # (missing key, or a zero/None value). Assumes result is a dict.
        if not result.get("processing_time"):
            result["processing_time"] = processing_time
        return result
    except Exception as e:
        # Log the traceback server-side; exposing it to clients leaks internals.
        logger.exception("[SEARCH API] Error processing query: %s", e)
        return JSONResponse(
            status_code=500,
            content={
                "error": str(e),
                "query": request.query,
                "processing_time": time.perf_counter() - start_time,
            },
        )
if __name__ == "__main__":
    import os

    import uvicorn

    # Listen on all interfaces. The port defaults to 7860 (the previously
    # hard-coded value) but can be overridden via the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))