Your Name Claude committed on
Commit 4213e35 · 1 Parent(s): 6997480

Add /search endpoint with 355M perplexity ranking (Option B implementation)


NEW FEATURES:
- New /search endpoint returns structured JSON (no LLM response generation)
- Keeps Query Parser LLM for entity extraction + synonym expansion
- Implements 355M Clinical Trial GPT perplexity-based re-ranking
- Includes before/after benchmarking metrics for 355M impact
- Returns trials ranked by clinical relevance (70% hybrid + 30% perplexity)
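A minimal sketch of the scoring arithmetic described above, as implemented in rank_trials_with_355m_perplexity() below (the input values here are illustrative):

```python
# Illustrative numbers only: how a raw perplexity becomes a 0-1 score
# and is blended 70/30 with the hybrid search score.
hybrid_score = 0.82                                    # from hybrid RAG search (0-1)
perplexity = 45.0                                      # from the 355M model; lower = better

perplexity_score = 1.0 / (1.0 + perplexity / 100)      # -> ~0.69
combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score   # -> ~0.78
```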

IMPLEMENTATION:
- rank_trials_with_355m_perplexity(): Uses perplexity scoring (not generation) to avoid hallucinations
- parse_trial_text_to_dict(): Parses trial text into structured fields
- process_query_structured(): Main function for /search endpoint
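A hedged sketch of calling the engine in-process (it assumes foundation_engine is importable, as app.py does; the query string is only an example):

```python
import foundation_engine

# parse_query_with_llm -> hybrid RAG search -> 355M perplexity re-rank,
# all wrapped inside process_query_structured().
result = foundation_engine.process_query_structured(
    "pembrolizumab trials in non-small cell lung cancer", top_k=5
)
for trial in result["trials"]:
    print(trial["nct_id"],
          trial["scoring"]["rank_after_355m"],
          trial["scoring"]["relevance_score"])
```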

BENCHMARKING:
- rank_before_355m: Original hybrid search ranking
- rank_after_355m: Final ranking after 355M perplexity adjustment (see the rank-change sketch after this list)
- perplexity: Raw perplexity score (lower = more relevant)
- perplexity_score: Normalized 0-1 score
- Processing time breakdown for each stage
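The rank-change metrics are plain differences between the two rankings; a small illustrative sketch (ranks made up):

```python
# (rank_before_355m, rank_after_355m) per 355M-scored trial; values are made up.
before_after = [(1, 2), (2, 1), (3, 3)]
rank_changes = [before - after for before, after in before_after]   # [-1, 1, 0]
average_rank_change = sum(rank_changes) / len(rank_changes)         # 0.0
max_rank_improvement = max(rank_changes)                            # 1
```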

API Response includes:
- query_analysis: Extracted entities and optimized search terms
- results: Total found, returned count, top relevance score
- trials[]: Structured trial data with scoring metadata
- benchmarking: Performance metrics and 355M ranking impact
- metadata: Model versions and database info
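An abbreviated, illustrative response body (field names follow process_query_structured(); every value below is made up):

```python
example_response = {
    "query": "pembrolizumab trials in NSCLC",
    "processing_time": 4.2,
    "query_analysis": {
        "extracted_entities": {"drugs": ["pembrolizumab"], "diseases": ["NSCLC"],
                               "companies": [], "endpoints": []},
        "optimized_search": "pembrolizumab NSCLC",
        "parsing_time": 0.8,
    },
    "results": {"total_found": 30, "returned": 10, "top_relevance_score": 0.80},
    "trials": [
        {
            "nct_id": "NCT00000000",          # placeholder ID
            "title": "...",
            "scoring": {"relevance_score": 0.80, "hybrid_score": 0.84,
                        "perplexity": 42.0, "perplexity_score": 0.70,
                        "rank_before_355m": 2, "rank_after_355m": 1,
                        "ranking_method": "355m_perplexity"},
            "url": "https://clinicaltrials.gov/study/NCT00000000",
        }
    ],
    "benchmarking": {"rag_search_time": 1.1, "355m_ranking_time": 2.9,
                     "trials_ranked_by_355m": 10, "average_rank_change": 0.4},
    "metadata": {"api_version": "2.0.0", "database_version": "2025-01-06"},
}
```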

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2)
  1. app.py +70 -5
  2. foundation_engine.py +475 -0
app.py CHANGED
@@ -40,6 +40,10 @@ class QueryResponse(BaseModel):
     summary: str
     processing_time: float
 
+class SearchRequest(BaseModel):
+    query: str
+    top_k: int = 10
+
 class HealthResponse(BaseModel):
     status: str
     trials_loaded: int
@@ -64,18 +68,22 @@ async def root():
     """API information"""
     return {
         "service": "Clinical Trial API",
-        "version": "1.0.0",
-        "description": "Production REST API for Foundation 1.2",
+        "version": "2.0.0",
+        "description": "Production REST API for Foundation 1.2 with 355M perplexity ranking",
         "status": "healthy",
         "endpoints": {
-            "POST /query": "Query clinical trials and get AI-generated summary",
+            "POST /search": "[NEW] Search trials with structured JSON output (includes 355M ranking)",
+            "POST /query": "Query clinical trials and get AI-generated summary (legacy)",
             "GET /health": "Health check",
             "GET /docs": "Interactive API documentation (Swagger UI)",
             "GET /redoc": "Alternative API documentation (ReDoc)"
         },
         "features": [
-            "Drug Scoring",
-            "355M foundation model"
+            "LLM Query Parser (entity extraction + synonyms)",
+            "Hybrid RAG Search (BM25 + semantic + inverted index)",
+            "355M Clinical Trial GPT perplexity-based ranking",
+            "Structured JSON output",
+            "Benchmarking metrics (before/after 355M scores)"
         ]
     }
 
@@ -123,6 +131,63 @@ async def query_trials(request: QueryRequest):
         logger.error(f"Error processing query: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
 
+@app.post("/search")
+async def search_trials(request: SearchRequest):
+    """
+    Search clinical trials and get structured JSON results (NEW API v2.0)
+
+    This endpoint provides:
+    - Query parsing with LLM (entity extraction + synonym expansion)
+    - Hybrid RAG search (BM25 + semantic embeddings + inverted index)
+    - 355M Clinical Trial GPT perplexity-based re-ranking
+    - Structured JSON output with benchmarking data
+
+    **No response generation** - returns raw trial data for client-side processing
+
+    Args:
+    - **query**: Your question about clinical trials
+    - **top_k**: Number of trials to return (default: 10, max: 50)
+
+    Returns:
+    - Structured JSON with trials ranked by clinical relevance
+    - Includes before/after 355M ranking scores for benchmarking
+    - Processing time breakdown (query parsing, RAG search, 355M ranking)
+    """
+    try:
+        logger.info(f"[SEARCH API] Query received: {request.query[:100]}...")
+
+        # Validate top_k
+        if request.top_k > 50:
+            logger.warning(f"[SEARCH API] top_k={request.top_k} exceeds maximum 50, capping")
+            request.top_k = 50
+        elif request.top_k < 1:
+            logger.warning(f"[SEARCH API] top_k={request.top_k} is invalid, using default 10")
+            request.top_k = 10
+
+        start_time = time.time()
+
+        # Call the structured query processor
+        result = foundation_engine.process_query_structured(request.query, top_k=request.top_k)
+
+        processing_time = time.time() - start_time
+        logger.info(f"[SEARCH API] Query completed in {processing_time:.2f}s")
+
+        # Ensure processing_time is set
+        if 'processing_time' not in result or result['processing_time'] == 0:
+            result['processing_time'] = processing_time
+
+        return result
+
+    except Exception as e:
+        logger.error(f"[SEARCH API] Error processing query: {str(e)}")
+        import traceback
+        return {
+            "error": str(e),
+            "traceback": traceback.format_exc(),
+            "query": request.query,
+            "processing_time": time.time() - start_time if 'start_time' in locals() else 0
+        }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
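A hedged client-side example for the new endpoint; the host and port are assumed from the uvicorn.run() call above, and the query and timeout values are arbitrary:

```python
import requests

# POST /search with the SearchRequest schema (query, top_k).
resp = requests.post(
    "http://localhost:7860/search",
    json={"query": "CAR-T trials for relapsed lymphoma", "top_k": 5},
    timeout=120,
)
data = resp.json()
print(data["results"]["returned"], "trials returned")
for trial in data["trials"]:
    print(trial["scoring"]["rank_after_355m"], trial["nct_id"], trial["url"])
```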
foundation_engine.py CHANGED
@@ -1814,6 +1814,481 @@ Query Type Distribution:
     return report
 
 
+# ============================================================================
+# 355M PERPLEXITY-BASED RANKING (FOR STRUCTURED JSON API)
+# ============================================================================
+
+def rank_trials_with_355m_perplexity(query, trials_list, hf_token=None):
+    """
+    Rank trials using 355M Clinical Trial GPT perplexity scoring
+
+    This uses the model for SCORING not GENERATION to avoid hallucinations
+    Lower perplexity = more relevant trial
+
+    Args:
+        query: User query
+        trials_list: List of (score, trial_text) tuples from hybrid search
+        hf_token: Not needed (model runs locally)
+
+    Returns:
+        List of dicts with trial data and perplexity scores
+    """
+    import time
+    import re
+    import torch
+    from transformers import GPT2LMHeadModel, GPT2TokenizerFast
+
+    start_time = time.time()
+
+    # Only rank top 10 trials (balance between accuracy and speed)
+    top_10 = trials_list[:10]
+
+    logger.info(f"[355M PERPLEXITY] Ranking {len(top_10)} trials with CT2 model...")
+
+    try:
+        # Load 355M model
+        tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2")
+        model = GPT2LMHeadModel.from_pretrained(
+            "gmkdigitalmedia/CT2",
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+        model.eval()
+        tokenizer.pad_token = tokenizer.eos_token
+
+        ranked_trials = []
+
+        for idx, (hybrid_score, trial_text) in enumerate(top_10):
+            # Extract NCT ID
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            # Format test text for perplexity calculation
+            # The model calculates: "How natural is this query-trial pairing?"
+            test_text = f"""Query: {query}
+
+Relevant Clinical Trial:
+{trial_text[:800]}
+
+This trial is highly relevant because"""
+
+            # Calculate perplexity (lower = more relevant)
+            inputs = tokenizer(
+                test_text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding=True
+            ).to(model.device)
+
+            with torch.no_grad():
+                outputs = model(**inputs, labels=inputs.input_ids)
+                perplexity = torch.exp(outputs.loss).item()
+
+            # Convert perplexity to relevance score (0-1)
+            # Typical range: 10-1000, lower is better
+            perplexity_score = 1.0 / (1.0 + perplexity / 100)
+
+            # Combine with hybrid score (70% hybrid, 30% perplexity)
+            combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score
+
+            logger.info(f"[355M] {nct_id}: Hybrid={hybrid_score:.3f}, "
+                        f"Perplexity={perplexity:.1f}, "
+                        f"Perplexity_Score={perplexity_score:.3f}, "
+                        f"Combined={combined_score:.3f}")
+
+            ranked_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': float(perplexity),
+                'perplexity_score': float(perplexity_score),
+                'combined_score': float(combined_score),
+                'rank_before_355m': idx + 1
+            })
+
+        # Sort by combined score (descending)
+        ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True)
+
+        # Add final rank
+        for idx, trial in enumerate(ranked_trials):
+            trial['rank_after_355m'] = idx + 1
+
+        elapsed = time.time() - start_time
+        logger.info(f"[355M PERPLEXITY] ✓ Ranking complete in {elapsed:.1f}s")
+
+        # Add remaining trials (beyond top 10) without 355M scoring
+        for idx, (hybrid_score, trial_text) in enumerate(trials_list[10:], start=10):
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            ranked_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': None,
+                'perplexity_score': None,
+                'combined_score': float(hybrid_score),
+                'rank_before_355m': idx + 1,
+                'rank_after_355m': len(ranked_trials) + 1
+            })
+
+        return ranked_trials
+
+    except Exception as e:
+        logger.error(f"[355M PERPLEXITY] Error: {e}")
+        logger.warning("[355M PERPLEXITY] Falling back to hybrid scores only")
+
+        # Fallback: return trials with hybrid scores only
+        fallback_trials = []
+        for idx, (hybrid_score, trial_text) in enumerate(trials_list):
+            nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
+            nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
+
+            fallback_trials.append({
+                'nct_id': nct_id,
+                'trial_text': trial_text,
+                'hybrid_score': float(hybrid_score),
+                'perplexity': None,
+                'perplexity_score': None,
+                'combined_score': float(hybrid_score),
+                'rank_before_355m': idx + 1,
+                'rank_after_355m': idx + 1
+            })
+
+        return fallback_trials
+
+
+def parse_trial_text_to_dict(trial_text, nct_id):
+    """
+    Parse trial text into structured dictionary
+
+    Args:
+        trial_text: Raw trial text
+        nct_id: NCT ID
+
+    Returns:
+        Dict with parsed trial fields
+    """
+    import re
+
+    trial_dict = {
+        'nct_id': nct_id,
+        'title': '',
+        'sponsor': '',
+        'collaborators': [],
+        'phase': '',
+        'status': '',
+        'enrollment': None,
+        'conditions': [],
+        'interventions': [],
+        'primary_outcome': '',
+        'results_summary': '',
+        'start_date': '',
+        'completion_date': '',
+        'last_update': '',
+        'locations': []
+    }
+
+    # Extract fields using regex patterns
+    try:
+        # Title
+        title_match = re.search(r'TITLE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if title_match:
+            trial_dict['title'] = title_match.group(1).strip()
+
+        # Sponsor
+        sponsor_match = re.search(r'SPONSOR:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if sponsor_match:
+            trial_dict['sponsor'] = sponsor_match.group(1).strip()
+
+        # Collaborators
+        collab_match = re.search(r'COLLABORATOR[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if collab_match:
+            collabs = collab_match.group(1).strip().split(',')
+            trial_dict['collaborators'] = [c.strip() for c in collabs if c.strip()]
+
+        # Phase
+        phase_match = re.search(r'PHASE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if phase_match:
+            trial_dict['phase'] = phase_match.group(1).strip()
+
+        # Status
+        status_match = re.search(r'STATUS:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if status_match:
+            trial_dict['status'] = status_match.group(1).strip()
+
+        # Enrollment
+        enrollment_match = re.search(r'ENROLLMENT:\s*(\d+)', trial_text, re.IGNORECASE)
+        if enrollment_match:
+            trial_dict['enrollment'] = int(enrollment_match.group(1))
+
+        # Conditions
+        condition_match = re.search(r'CONDITION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if condition_match:
+            conditions = condition_match.group(1).strip().split(',')
+            trial_dict['conditions'] = [c.strip() for c in conditions if c.strip()]
+
+        # Interventions
+        intervention_match = re.search(r'INTERVENTION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if intervention_match:
+            interventions = intervention_match.group(1).strip().split(',')
+            trial_dict['interventions'] = [i.strip() for i in interventions if i.strip()]
+
+        # Primary outcome
+        outcome_match = re.search(r'PRIMARY[_ ]OUTCOME:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if outcome_match:
+            trial_dict['primary_outcome'] = outcome_match.group(1).strip()
+
+        # Results summary
+        results_match = re.search(r'RESULTS:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if results_match:
+            trial_dict['results_summary'] = results_match.group(1).strip()
+
+        # Dates
+        start_match = re.search(r'START[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if start_match:
+            trial_dict['start_date'] = start_match.group(1).strip()
+
+        completion_match = re.search(r'COMPLETION[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if completion_match:
+            trial_dict['completion_date'] = completion_match.group(1).strip()
+
+        # Locations
+        location_match = re.search(r'LOCATION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE)
+        if location_match:
+            locations = location_match.group(1).strip().split(',')
+            trial_dict['locations'] = [l.strip() for l in locations if l.strip()]
+
+    except Exception as e:
+        logger.warning(f"Error parsing trial {nct_id}: {e}")
+
+    return trial_dict
+
+
+def process_query_structured(query, top_k=10):
+    """
+    Process query and return structured JSON (no LLM response generation)
+
+    This is the new API endpoint that:
+    1. Uses LLM for query parsing/entity extraction
+    2. Performs hybrid RAG search
+    3. Ranks with 355M perplexity scoring
+    4. Returns structured JSON
+
+    Args:
+        query: User query
+        top_k: Number of trials to return
+
+    Returns:
+        Dict with structured response
+    """
+    import time
+
+    start_time = time.time()
+
+    result = {
+        'query': query,
+        'processing_time': 0,
+        'query_analysis': {},
+        'results': {},
+        'trials': [],
+        'benchmarking': {},
+        'metadata': {}
+    }
+
+    try:
+        # Step 1: Parse query with LLM
+        step1_start = time.time()
+        logger.info("[STRUCTURED API] Step 1: Parsing query with LLM...")
+
+        try:
+            parsed_query = parse_query_with_llm(query, hf_token=hf_token)
+            search_query = parsed_query['search_terms']
+
+            result['query_analysis'] = {
+                'extracted_entities': {
+                    'drugs': parsed_query.get('drugs', []),
+                    'diseases': parsed_query.get('diseases', []),
+                    'companies': parsed_query.get('companies', []),
+                    'endpoints': parsed_query.get('endpoints', [])
+                },
+                'optimized_search': search_query,
+                'parsing_time': time.time() - step1_start
+            }
+            logger.info(f"[STRUCTURED API] Query parsed in {time.time() - step1_start:.1f}s")
+
+        except Exception as e:
+            logger.warning(f"[STRUCTURED API] Query parsing failed: {e}, using original query")
+            search_query = query
+            parsed_query = {'drugs': [], 'diseases': [], 'companies': [], 'endpoints': []}
+            result['query_analysis'] = {
+                'extracted_entities': parsed_query,
+                'optimized_search': search_query,
+                'parsing_time': time.time() - step1_start,
+                'error': str(e)
+            }
+
+        # Step 2: Hybrid RAG search
+        step2_start = time.time()
+        logger.info("[STRUCTURED API] Step 2: Hybrid RAG search...")
+
+        # Get more candidates for 355M ranking
+        candidate_k = top_k * 3
+
+        # We need to get the candidate trials with scores
+        # Re-implement the key parts of retrieve_context_with_embeddings to get structured data
+        from collections import Counter
+        global doc_chunks, doc_embeddings, embedder, inverted_index
+
+        if doc_embeddings is None or len(doc_chunks) == 0:
+            raise Exception("Embeddings not loaded!")
+
+        # Extract keywords
+        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
+                      'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
+                      'about', 'that', 'this', 'there', 'it'}
+        query_lower = search_query.lower()
+        import re
+        words = re.findall(r'\b\w+\b', query_lower)
+        query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
+
+        # Keyword scoring with inverted index
+        keyword_scores = {}
+        if inverted_index is not None:
+            inv_index_candidates = set()
+            for term in query_terms:
+                if term in inverted_index:
+                    inv_index_candidates.update(inverted_index[term])
+
+            if inv_index_candidates:
+                drug_specific_terms = set()
+                for term in query_terms:
+                    if term in inverted_index and len(inverted_index[term]) < 100:
+                        drug_specific_terms.add(term)
+
+                for idx in inv_index_candidates:
+                    chunk_data = doc_chunks[idx]
+                    chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
+                    chunk_lower = chunk_text.lower()
+
+                    has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms)
+
+                    if has_drug_match:
+                        keyword_scores[idx] = 1000.0
+                    else:
+                        keyword_scores[idx] = 1.0
+
+        # Semantic scoring
+        load_embedder()
+        query_embedding = embedder.encode([search_query])[0]
+        semantic_similarities = np.dot(doc_embeddings, query_embedding)
+
+        # Normalize and combine scores
+        if keyword_scores:
+            max_kw = max(keyword_scores.values())
+            keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
+        else:
+            keyword_scores_norm = {}
+
+        max_sem = semantic_similarities.max()
+        min_sem = semantic_similarities.min()
+        semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
+
+        # Combined scores
+        combined_scores = np.zeros(len(doc_chunks))
+        for idx in range(len(doc_chunks)):
+            kw_score = keyword_scores_norm.get(idx, 0.0)
+            sem_score = semantic_scores_norm[idx]
+            combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score if kw_score > 0 else sem_score
+
+        # Get top candidates
+        top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
+
+        # Format as (score, text) tuples
+        candidate_trials = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
+                            for i in top_indices]
+
+        rag_time = time.time() - step2_start
+        logger.info(f"[STRUCTURED API] RAG search complete in {rag_time:.1f}s, found {len(candidate_trials)} candidates")
+
+        # Step 3: Rank with 355M perplexity
+        step3_start = time.time()
+        logger.info("[STRUCTURED API] Step 3: Ranking with 355M perplexity...")
+
+        ranked_trials = rank_trials_with_355m_perplexity(query, candidate_trials, hf_token=hf_token)
+
+        ranking_time = time.time() - step3_start
+        logger.info(f"[STRUCTURED API] 355M ranking complete in {ranking_time:.1f}s")
+
+        # Format results
+        result['results'] = {
+            'total_found': len(candidate_trials),
+            'returned': min(top_k, len(ranked_trials)),
+            'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0
+        }
+
+        # Parse trials and add to results
+        for trial_data in ranked_trials[:top_k]:
+            trial_dict = parse_trial_text_to_dict(trial_data['trial_text'], trial_data['nct_id'])
+            trial_dict['scoring'] = {
+                'relevance_score': trial_data['combined_score'],
+                'hybrid_score': trial_data['hybrid_score'],
+                'perplexity': trial_data['perplexity'],
+                'perplexity_score': trial_data['perplexity_score'],
+                'rank_before_355m': trial_data['rank_before_355m'],
+                'rank_after_355m': trial_data['rank_after_355m'],
+                'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only'
+            }
+            trial_dict['url'] = f"https://clinicaltrials.gov/study/{trial_data['nct_id']}"
+            result['trials'].append(trial_dict)
+
+        # Benchmarking data
+        if ranked_trials:
+            # Calculate how much 355M changed the ranking
+            rank_changes = []
+            for trial in ranked_trials[:top_k]:
+                if trial['perplexity'] is not None:
+                    rank_change = trial['rank_before_355m'] - trial['rank_after_355m']
+                    rank_changes.append(rank_change)
+
+            result['benchmarking'] = {
+                'rag_search_time': rag_time,
+                '355m_ranking_time': ranking_time,
+                'total_processing_time': time.time() - start_time,
+                'trials_ranked_by_355m': len([t for t in ranked_trials if t['perplexity'] is not None]),
+                'average_rank_change': sum(rank_changes) / len(rank_changes) if rank_changes else 0,
+                'max_rank_improvement': max(rank_changes) if rank_changes else 0,
+                'top_3_perplexity_scores': [t['perplexity'] for t in ranked_trials[:3] if t['perplexity'] is not None]
+            }
+
+        # Metadata
+        result['metadata'] = {
+            'database_version': '2025-01-06',
+            'total_trials_searched': len(doc_chunks),
+            'api_version': '2.0.0',
+            'model_info': {
+                'query_parser': 'Llama-3.1-70B-Instruct',
+                'ranking_model': 'gmkdigitalmedia/CT2-355M',
+                'embedding_model': 'all-MiniLM-L6-v2'
+            }
+        }
+
+        result['processing_time'] = time.time() - start_time
+
+        logger.info(f"[STRUCTURED API] ✓ Complete in {result['processing_time']:.1f}s")
+
+        return result
+
+    except Exception as e:
+        logger.error(f"[STRUCTURED API] Error: {e}")
+        import traceback
+        result['error'] = str(e)
+        result['traceback'] = traceback.format_exc()
+        result['processing_time'] = time.time() - start_time
+        return result
+
+
 # ============================================================================
 # GRADIO INTERFACE
 # ============================================================================
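A small, hedged example of the flat `FIELD: value` layout that parse_trial_text_to_dict() expects; the sample trial text is invented purely to exercise the regex patterns above:

```python
sample_text = """NCT_ID: NCT01234567
TITLE: A Phase 2 Study of Example Drug in Example Disease
SPONSOR: Example Pharma
PHASE: Phase 2
STATUS: Completed
ENROLLMENT: 120
CONDITIONS: Example Disease
INTERVENTIONS: Example Drug, Placebo
"""

trial = parse_trial_text_to_dict(sample_text, "NCT01234567")
# trial["phase"] == "Phase 2"; trial["enrollment"] == 120
# trial["interventions"] == ["Example Drug", "Placebo"]
```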