Add /search endpoint with 355M perplexity ranking (Option B implementation)
NEW FEATURES:
- New /search endpoint returns structured JSON (no LLM response generation)
- Keeps Query Parser LLM for entity extraction + synonym expansion
- Implements 355M Clinical Trial GPT perplexity-based re-ranking
- Includes before/after benchmarking metrics for 355M impact
- Returns trials ranked by clinical relevance (70% hybrid + 30% perplexity)
IMPLEMENTATION:
- rank_trials_with_355m_perplexity(): Uses perplexity scoring (not generation) to avoid hallucinations
- parse_trial_text_to_dict(): Parses trial text into structured fields
- process_query_structured(): Main function for /search endpoint
BENCHMARKING:
- rank_before_355m: Original hybrid search ranking
- rank_after_355m: Final ranking after 355M perplexity adjustment
- perplexity: Raw perplexity score (lower = more relevant)
- perplexity_score: Normalized 0-1 score
- Processing time breakdown for each stage
API Response includes:
- query_analysis: Extracted entities and optimized search terms
- results: Total found, returned count, top relevance score
- trials[]: Structured trial data with scoring metadata
- benchmarking: Performance metrics and 355M ranking impact
- metadata: Model versions and database info
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
Files changed:
- app.py (+70, -5)
- foundation_engine.py (+475, -0)
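
For orientation before the diffs, a minimal client call against the new endpoint might look like the sketch below. The host and the query string are illustrative; the payload fields, the response keys, and port 7860 come from the code in this commit.

```python
# Minimal sketch of calling the new POST /search endpoint.
# Base URL and query are hypothetical; "query"/"top_k" and the response
# structure (trials[], scoring, etc.) are the ones added in this commit.
import requests

resp = requests.post(
    "http://localhost:7860/search",  # app binds to port 7860 in app.py
    json={"query": "pembrolizumab trials for melanoma", "top_k": 5},
    timeout=120,
)
data = resp.json()
for trial in data["trials"]:
    print(trial["nct_id"], trial["scoring"]["relevance_score"])
```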
app.py
```diff
@@ -40,6 +40,10 @@ class QueryResponse(BaseModel):
     summary: str
     processing_time: float
 
+class SearchRequest(BaseModel):
+    query: str
+    top_k: int = 10
+
 class HealthResponse(BaseModel):
     status: str
     trials_loaded: int
@@ -64,18 +68,22 @@ async def root():
     """API information"""
     return {
         "service": "Clinical Trial API",
-        "version": "
-        "description": "Production REST API for Foundation 1.2",
+        "version": "2.0.0",
+        "description": "Production REST API for Foundation 1.2 with 355M perplexity ranking",
         "status": "healthy",
         "endpoints": {
-            "POST /
+            "POST /search": "[NEW] Search trials with structured JSON output (includes 355M ranking)",
+            "POST /query": "Query clinical trials and get AI-generated summary (legacy)",
             "GET /health": "Health check",
             "GET /docs": "Interactive API documentation (Swagger UI)",
             "GET /redoc": "Alternative API documentation (ReDoc)"
         },
         "features": [
-            "
-            "
+            "LLM Query Parser (entity extraction + synonyms)",
+            "Hybrid RAG Search (BM25 + semantic + inverted index)",
+            "355M Clinical Trial GPT perplexity-based ranking",
+            "Structured JSON output",
+            "Benchmarking metrics (before/after 355M scores)"
         ]
     }
@@ -123,6 +131,63 @@ async def query_trials(request: QueryRequest):
         logger.error(f"Error processing query: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
 
+@app.post("/search")
+async def search_trials(request: SearchRequest):
+    """
+    Search clinical trials and get structured JSON results (NEW API v2.0)
+
+    This endpoint provides:
+    - Query parsing with LLM (entity extraction + synonym expansion)
+    - Hybrid RAG search (BM25 + semantic embeddings + inverted index)
+    - 355M Clinical Trial GPT perplexity-based re-ranking
+    - Structured JSON output with benchmarking data
+
+    **No response generation** - returns raw trial data for client-side processing
+
+    Args:
+    - **query**: Your question about clinical trials
+    - **top_k**: Number of trials to return (default: 10, max: 50)
+
+    Returns:
+    - Structured JSON with trials ranked by clinical relevance
+    - Includes before/after 355M ranking scores for benchmarking
+    - Processing time breakdown (query parsing, RAG search, 355M ranking)
+    """
+    try:
+        logger.info(f"[SEARCH API] Query received: {request.query[:100]}...")
+
+        # Validate top_k
+        if request.top_k > 50:
+            logger.warning(f"[SEARCH API] top_k={request.top_k} exceeds maximum 50, capping")
+            request.top_k = 50
+        elif request.top_k < 1:
+            logger.warning(f"[SEARCH API] top_k={request.top_k} is invalid, using default 10")
+            request.top_k = 10
+
+        start_time = time.time()
+
+        # Call the structured query processor
+        result = foundation_engine.process_query_structured(request.query, top_k=request.top_k)
+
+        processing_time = time.time() - start_time
+        logger.info(f"[SEARCH API] Query completed in {processing_time:.2f}s")
+
+        # Ensure processing_time is set
+        if 'processing_time' not in result or result['processing_time'] == 0:
+            result['processing_time'] = processing_time
+
+        return result
+
+    except Exception as e:
+        logger.error(f"[SEARCH API] Error processing query: {str(e)}")
+        import traceback
+        return {
+            "error": str(e),
+            "traceback": traceback.format_exc(),
+            "query": request.query,
+            "processing_time": time.time() - start_time if 'start_time' in locals() else 0
+        }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
```
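
The handler above delegates the real work to foundation_engine.process_query_structured, added below. Judging from the keys that function populates, a trimmed /search response might look like this sketch; all values and the NCT ID are made up for illustration, and the blended score follows the 70/30 formula in the commit (0.7 * 0.88 + 0.3 * 0.704 = 0.827).

```python
# Illustrative shape of a /search response, trimmed to one trial.
# Keys mirror those built in process_query_structured below; values are invented.
example_response = {
    "query": "pembrolizumab trials for melanoma",
    "processing_time": 4.2,
    "query_analysis": {
        "extracted_entities": {"drugs": ["pembrolizumab"], "diseases": ["melanoma"],
                               "companies": [], "endpoints": []},
        "optimized_search": "pembrolizumab melanoma",
        "parsing_time": 0.8,
    },
    "results": {"total_found": 30, "returned": 5, "top_relevance_score": 0.827},
    "trials": [{
        "nct_id": "NCT00000000",  # placeholder ID
        "title": "...",
        "scoring": {
            "relevance_score": 0.827, "hybrid_score": 0.88,
            "perplexity": 42.0, "perplexity_score": 0.704,
            "rank_before_355m": 2, "rank_after_355m": 1,
            "ranking_method": "355m_perplexity",
        },
        "url": "https://clinicaltrials.gov/study/NCT00000000",
    }],
    "benchmarking": {"rag_search_time": 1.1, "355m_ranking_time": 2.3,
                     "trials_ranked_by_355m": 10, "average_rank_change": 0.4},
    "metadata": {"api_version": "2.0.0"},
}
```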
foundation_engine.py
```diff
@@ -1814,6 +1814,481 @@ Query Type Distribution:
     return report
 
 
+# ============================================================================
+# 355M PERPLEXITY-BASED RANKING (FOR STRUCTURED JSON API)
+# ============================================================================
+
+def rank_trials_with_355m_perplexity(query, trials_list, hf_token=None):
+    """
+    Rank trials using 355M Clinical Trial GPT perplexity scoring
+
+    This uses the model for SCORING not GENERATION to avoid hallucinations
+    Lower perplexity = more relevant trial
+
+    Args:
+        query: User query
+        trials_list: List of (score, trial_text) tuples from hybrid search
+        hf_token: Not needed (model runs locally)
+
+    Returns:
+        List of dicts with trial data and perplexity scores
+    """
+    import time
+    import re
+    import torch
+    from transformers import GPT2LMHeadModel, GPT2TokenizerFast
+
+    start_time = time.time()
+
+    # Only rank top 10 trials (balance between accuracy and speed)
+    top_10 = trials_list[:10]
+
+    logger.info(f"[355M PERPLEXITY] Ranking {len(top_10)} trials with CT2 model...")
+
+    try:
+        # Load 355M model
+        tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2")
+        model = GPT2LMHeadModel.from_pretrained(
+            "gmkdigitalmedia/CT2",
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+        model.eval()
+        tokenizer.pad_token = tokenizer.eos_token
+
+        ranked_trials = []
+
+        for idx, (hybrid_score, trial_text) in enumerate(top_10):
+            # Extract NCT ID
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            # Format test text for perplexity calculation
+            # The model calculates: "How natural is this query-trial pairing?"
+            test_text = f"""Query: {query}
+
+Relevant Clinical Trial:
+{trial_text[:800]}
+
+This trial is highly relevant because"""
+
+            # Calculate perplexity (lower = more relevant)
+            inputs = tokenizer(
+                test_text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding=True
+            ).to(model.device)
+
+            with torch.no_grad():
+                outputs = model(**inputs, labels=inputs.input_ids)
+                perplexity = torch.exp(outputs.loss).item()
+
+            # Convert perplexity to relevance score (0-1)
+            # Typical range: 10-1000, lower is better
+            perplexity_score = 1.0 / (1.0 + perplexity / 100)
+
+            # Combine with hybrid score (70% hybrid, 30% perplexity)
+            combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score
+
+            logger.info(f"[355M] {nct_id}: Hybrid={hybrid_score:.3f}, "
+                        f"Perplexity={perplexity:.1f}, "
+                        f"Perplexity_Score={perplexity_score:.3f}, "
+                        f"Combined={combined_score:.3f}")
+
+            ranked_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': float(perplexity),
+                'perplexity_score': float(perplexity_score),
+                'combined_score': float(combined_score),
+                'rank_before_355m': idx + 1
+            })
+
+        # Sort by combined score (descending)
+        ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True)
+
+        # Add final rank
+        for idx, trial in enumerate(ranked_trials):
+            trial['rank_after_355m'] = idx + 1
+
+        elapsed = time.time() - start_time
+        logger.info(f"[355M PERPLEXITY] ✓ Ranking complete in {elapsed:.1f}s")
+
+        # Add remaining trials (beyond top 10) without 355M scoring
+        for idx, (hybrid_score, trial_text) in enumerate(trials_list[10:], start=10):
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            ranked_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': None,
+                'perplexity_score': None,
+                'combined_score': float(hybrid_score),
+                'rank_before_355m': idx + 1,
+                'rank_after_355m': len(ranked_trials) + 1
+            })
+
+        return ranked_trials
+
+    except Exception as e:
+        logger.error(f"[355M PERPLEXITY] Error: {e}")
+        logger.warning("[355M PERPLEXITY] Falling back to hybrid scores only")
+
+        # Fallback: return trials with hybrid scores only
+        fallback_trials = []
+        for idx, (hybrid_score, trial_text) in enumerate(trials_list):
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            fallback_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': None,
+                'perplexity_score': None,
+                'combined_score': float(hybrid_score),
+                'rank_before_355m': idx + 1,
+                'rank_after_355m': idx + 1
+            })
+
+        return fallback_trials
+
+
+def parse_trial_text_to_dict(trial_text, nct_id):
+    """
+    Parse trial text into structured dictionary
+
+    Args:
+        trial_text: Raw trial text
+        nct_id: NCT ID
+
+    Returns:
+        Dict with parsed trial fields
+    """
+    import re
+
+    trial_dict = {
+        'nct_id': nct_id,
+        'title': '',
+        'sponsor': '',
+        'collaborators': [],
+        'phase': '',
+        'status': '',
+        'enrollment': None,
+        'conditions': [],
+        'interventions': [],
+        'primary_outcome': '',
+        'results_summary': '',
+        'start_date': '',
+        'completion_date': '',
+        'last_update': '',
+        'locations': []
+    }
+
+    # Extract fields using regex patterns
+    try:
+        # Title
+        title_match = re.search(r'TITLE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if title_match:
+            trial_dict['title'] = title_match.group(1).strip()
+
+        # Sponsor
+        sponsor_match = re.search(r'SPONSOR:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if sponsor_match:
+            trial_dict['sponsor'] = sponsor_match.group(1).strip()
+
+        # Collaborators
+        collab_match = re.search(r'COLLABORATOR[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if collab_match:
+            collabs = collab_match.group(1).strip().split(',')
+            trial_dict['collaborators'] = [c.strip() for c in collabs if c.strip()]
+
+        # Phase
+        phase_match = re.search(r'PHASE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if phase_match:
+            trial_dict['phase'] = phase_match.group(1).strip()
+
+        # Status
+        status_match = re.search(r'STATUS:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if status_match:
+            trial_dict['status'] = status_match.group(1).strip()
+
+        # Enrollment
+        enrollment_match = re.search(r'ENROLLMENT:\s*(\d+)', trial_text, re.IGNORECASE)
+        if enrollment_match:
+            trial_dict['enrollment'] = int(enrollment_match.group(1))
+
+        # Conditions
+        condition_match = re.search(r'CONDITION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if condition_match:
+            conditions = condition_match.group(1).strip().split(',')
+            trial_dict['conditions'] = [c.strip() for c in conditions if c.strip()]
+
+        # Interventions
+        intervention_match = re.search(r'INTERVENTION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if intervention_match:
+            interventions = intervention_match.group(1).strip().split(',')
+            trial_dict['interventions'] = [i.strip() for i in interventions if i.strip()]
+
+        # Primary outcome
+        outcome_match = re.search(r'PRIMARY[_ ]OUTCOME:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if outcome_match:
+            trial_dict['primary_outcome'] = outcome_match.group(1).strip()
+
+        # Results summary
+        results_match = re.search(r'RESULTS:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if results_match:
+            trial_dict['results_summary'] = results_match.group(1).strip()
+
+        # Dates
+        start_match = re.search(r'START[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if start_match:
+            trial_dict['start_date'] = start_match.group(1).strip()
+
+        completion_match = re.search(r'COMPLETION[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if completion_match:
+            trial_dict['completion_date'] = completion_match.group(1).strip()
+
+        # Locations
+        location_match = re.search(r'LOCATION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if location_match:
+            locations = location_match.group(1).strip().split(',')
+            trial_dict['locations'] = [l.strip() for l in locations if l.strip()]
+
+    except Exception as e:
+        logger.warning(f"Error parsing trial {nct_id}: {e}")
+
+    return trial_dict
+
+
+def process_query_structured(query, top_k=10):
+    """
+    Process query and return structured JSON (no LLM response generation)
+
+    This is the new API endpoint that:
+    1. Uses LLM for query parsing/entity extraction
+    2. Performs hybrid RAG search
+    3. Ranks with 355M perplexity scoring
+    4. Returns structured JSON
+
+    Args:
+        query: User query
+        top_k: Number of trials to return
+
+    Returns:
+        Dict with structured response
+    """
+    import time
+
+    start_time = time.time()
+
+    result = {
+        'query': query,
+        'processing_time': 0,
+        'query_analysis': {},
+        'results': {},
+        'trials': [],
+        'benchmarking': {},
+        'metadata': {}
+    }
+
+    try:
+        # Step 1: Parse query with LLM
+        step1_start = time.time()
+        logger.info("[STRUCTURED API] Step 1: Parsing query with LLM...")
+
+        try:
+            parsed_query = parse_query_with_llm(query, hf_token=hf_token)
+            search_query = parsed_query['search_terms']
+
+            result['query_analysis'] = {
+                'extracted_entities': {
+                    'drugs': parsed_query.get('drugs', []),
+                    'diseases': parsed_query.get('diseases', []),
+                    'companies': parsed_query.get('companies', []),
+                    'endpoints': parsed_query.get('endpoints', [])
+                },
+                'optimized_search': search_query,
+                'parsing_time': time.time() - step1_start
+            }
+            logger.info(f"[STRUCTURED API] Query parsed in {time.time() - step1_start:.1f}s")
+
+        except Exception as e:
+            logger.warning(f"[STRUCTURED API] Query parsing failed: {e}, using original query")
+            search_query = query
+            parsed_query = {'drugs': [], 'diseases': [], 'companies': [], 'endpoints': []}
+            result['query_analysis'] = {
+                'extracted_entities': parsed_query,
+                'optimized_search': search_query,
+                'parsing_time': time.time() - step1_start,
+                'error': str(e)
+            }
+
+        # Step 2: Hybrid RAG search
+        step2_start = time.time()
+        logger.info("[STRUCTURED API] Step 2: Hybrid RAG search...")
+
+        # Get more candidates for 355M ranking
+        candidate_k = top_k * 3
+
+        # We need to get the candidate trials with scores
+        # Re-implement the key parts of retrieve_context_with_embeddings to get structured data
+        from collections import Counter
+        global doc_chunks, doc_embeddings, embedder, inverted_index
+
+        if doc_embeddings is None or len(doc_chunks) == 0:
+            raise Exception("Embeddings not loaded!")
+
+        # Extract keywords
+        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
+                      'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
+                      'about', 'that', 'this', 'there', 'it'}
+        query_lower = search_query.lower()
+        import re
+        words = re.findall(r'\b\w+\b', query_lower)
+        query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
+
+        # Keyword scoring with inverted index
+        keyword_scores = {}
+        if inverted_index is not None:
+            inv_index_candidates = set()
+            for term in query_terms:
+                if term in inverted_index:
+                    inv_index_candidates.update(inverted_index[term])
+
+            if inv_index_candidates:
+                drug_specific_terms = set()
+                for term in query_terms:
+                    if term in inverted_index and len(inverted_index[term]) < 100:
+                        drug_specific_terms.add(term)
+
+                for idx in inv_index_candidates:
+                    chunk_data = doc_chunks[idx]
+                    chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
+                    chunk_lower = chunk_text.lower()
+
+                    has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms)
+
+                    if has_drug_match:
+                        keyword_scores[idx] = 1000.0
+                    else:
+                        keyword_scores[idx] = 1.0
+
+        # Semantic scoring
+        load_embedder()
+        query_embedding = embedder.encode([search_query])[0]
+        semantic_similarities = np.dot(doc_embeddings, query_embedding)
+
+        # Normalize and combine scores
+        if keyword_scores:
+            max_kw = max(keyword_scores.values())
+            keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
+        else:
+            keyword_scores_norm = {}
+
+        max_sem = semantic_similarities.max()
+        min_sem = semantic_similarities.min()
+        semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
+
+        # Combined scores
+        combined_scores = np.zeros(len(doc_chunks))
+        for idx in range(len(doc_chunks)):
+            kw_score = keyword_scores_norm.get(idx, 0.0)
+            sem_score = semantic_scores_norm[idx]
+            combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score if kw_score > 0 else sem_score
+
+        # Get top candidates
+        top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
+
+        # Format as (score, text) tuples
+        candidate_trials = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
+                            for i in top_indices]
+
+        rag_time = time.time() - step2_start
+        logger.info(f"[STRUCTURED API] RAG search complete in {rag_time:.1f}s, found {len(candidate_trials)} candidates")
+
+        # Step 3: Rank with 355M perplexity
+        step3_start = time.time()
+        logger.info("[STRUCTURED API] Step 3: Ranking with 355M perplexity...")
+
+        ranked_trials = rank_trials_with_355m_perplexity(query, candidate_trials, hf_token=hf_token)
+
+        ranking_time = time.time() - step3_start
+        logger.info(f"[STRUCTURED API] 355M ranking complete in {ranking_time:.1f}s")
+
+        # Format results
+        result['results'] = {
+            'total_found': len(candidate_trials),
+            'returned': min(top_k, len(ranked_trials)),
+            'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0
+        }
+
+        # Parse trials and add to results
+        for trial_data in ranked_trials[:top_k]:
+            trial_dict = parse_trial_text_to_dict(trial_data['trial_text'], trial_data['nct_id'])
+            trial_dict['scoring'] = {
+                'relevance_score': trial_data['combined_score'],
+                'hybrid_score': trial_data['hybrid_score'],
+                'perplexity': trial_data['perplexity'],
+                'perplexity_score': trial_data['perplexity_score'],
+                'rank_before_355m': trial_data['rank_before_355m'],
+                'rank_after_355m': trial_data['rank_after_355m'],
+                'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only'
+            }
+            trial_dict['url'] = f"https://clinicaltrials.gov/study/{trial_data['nct_id']}"
+            result['trials'].append(trial_dict)
+
+        # Benchmarking data
+        if ranked_trials:
+            # Calculate how much 355M changed the ranking
+            rank_changes = []
+            for trial in ranked_trials[:top_k]:
+                if trial['perplexity'] is not None:
+                    rank_change = trial['rank_before_355m'] - trial['rank_after_355m']
+                    rank_changes.append(rank_change)
+
+            result['benchmarking'] = {
+                'rag_search_time': rag_time,
+                '355m_ranking_time': ranking_time,
+                'total_processing_time': time.time() - start_time,
+                'trials_ranked_by_355m': len([t for t in ranked_trials if t['perplexity'] is not None]),
+                'average_rank_change': sum(rank_changes) / len(rank_changes) if rank_changes else 0,
+                'max_rank_improvement': max(rank_changes) if rank_changes else 0,
+                'top_3_perplexity_scores': [t['perplexity'] for t in ranked_trials[:3] if t['perplexity'] is not None]
+            }
+
+        # Metadata
+        result['metadata'] = {
+            'database_version': '2025-01-06',
+            'total_trials_searched': len(doc_chunks),
+            'api_version': '2.0.0',
+            'model_info': {
+                'query_parser': 'Llama-3.1-70B-Instruct',
+                'ranking_model': 'gmkdigitalmedia/CT2-355M',
+                'embedding_model': 'all-MiniLM-L6-v2'
+            }
+        }
+
+        result['processing_time'] = time.time() - start_time
+
+        logger.info(f"[STRUCTURED API] ✓ Complete in {result['processing_time']:.1f}s")
+
+        return result
+
+    except Exception as e:
+        logger.error(f"[STRUCTURED API] Error: {e}")
+        import traceback
+        result['error'] = str(e)
+        result['traceback'] = traceback.format_exc()
+        result['processing_time'] = time.time() - start_time
+        return result
+
+
 # ============================================================================
 # GRADIO INTERFACE
 # ============================================================================
```
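
As a sanity check on the scoring math in rank_trials_with_355m_perplexity above, here is a standalone sketch of the perplexity-to-score conversion and the 70/30 blend, with illustrative numbers rather than values from a real run:

```python
import math

# Sketch of the commit's scoring path, outside torch, with made-up numbers.
# In the real code: perplexity = torch.exp(outputs.loss).item()
loss = 3.7                    # illustrative mean cross-entropy per token
perplexity = math.exp(loss)   # about 40.4; lower means a more natural query-trial pairing

# Normalize perplexity into (0, 1] via 1 / (1 + ppl/100), as in the commit
perplexity_score = 1.0 / (1.0 + perplexity / 100)   # about 0.712

# Blend 70% hybrid search score with 30% perplexity score
hybrid_score = 0.80
combined = 0.7 * hybrid_score + 0.3 * perplexity_score   # about 0.774

# A much worse pairing (perplexity 200) drops the blend to 0.56 + 0.3/3 = 0.66,
# which is how the 355M model can reorder the hybrid top-10.
print(round(perplexity, 1), round(perplexity_score, 3), round(combined, 3))
```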