# CTapi-raw / foundation_rag_optionB.py
"""
Foundation RAG - Option B: Clean 1-LLM Architecture
====================================================
Pipeline:
1. Query Parser LLM (Llama-70B) → Extract entities + synonyms (3s, $0.001)
2. RAG Search (Inverted Index Keywords + Semantic Embeddings) → Retrieve candidates (2s, free)
3. 355M Perplexity Ranking → Rank by clinical relevance (2-5s, free)
4. Structured JSON Output → Return ranked trials (instant, free)
Total: ~7-10 seconds, $0.001 per query
No response generation - clients handle that with their own LLMs
"""
import os
import time
import logging
import numpy as np
import torch
import re
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from sentence_transformers import SentenceTransformer
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from huggingface_hub import InferenceClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# CONFIGURATION
# ============================================================================
hf_token = os.getenv("HF_TOKEN")
# Data paths (check /tmp first, then local)
DATA_DIR = Path("/tmp/foundation_data")
if not DATA_DIR.exists():
DATA_DIR = Path(__file__).parent
CHUNKS_FILE = DATA_DIR / "dataset_chunks_TRIAL_AWARE.pkl"
EMBEDDINGS_FILE = DATA_DIR / "dataset_embeddings_TRIAL_AWARE_FIXED.npy"
INVERTED_INDEX_FILE = DATA_DIR / "inverted_index_COMPREHENSIVE.pkl"
# Global state
embedder = None
doc_chunks = []
doc_embeddings = None
inverted_index = None
model_355m = None
tokenizer_355m = None
# ============================================================================
# STEP 1: QUERY PARSER LLM (Llama-70B)
# ============================================================================
def parse_query_with_llm(query: str, hf_token: Optional[str] = None) -> Dict:
"""
Use Llama-70B to parse query and extract entities
Cost: $0.001 per query
Time: ~3 seconds
Returns:
{
'drugs': [...],
'diseases': [...],
'companies': [...],
'endpoints': [...],
'search_terms': "optimized search query"
}
"""
try:
logger.info("[QUERY PARSER] Analyzing query with Llama-70B...")
client = InferenceClient(token=hf_token, timeout=30)
parse_prompt = f"""You are an expert in clinical trial terminology. Extract entities from this query.
Query: "{query}"
Extract ALL possible names and synonyms:
DRUGS:
- Brand names, generic names, research codes (e.g., BNT162b2)
- Chemical names, abbreviations
- Company+drug combinations (e.g., Pfizer-BioNTech vaccine)
DISEASES:
- Medical synonyms, ICD-10 terms
- Technical and colloquial terms
- Related conditions
COMPANIES:
- Parent companies, subsidiaries
- Previous names, partnerships
ENDPOINTS:
- Specific outcomes or measures mentioned
SEARCH_TERMS:
- Comprehensive keywords for search
Format EXACTLY as:
DRUGS: [list or "none"]
DISEASES: [list or "none"]
COMPANIES: [list or "none"]
ENDPOINTS: [list or "none"]
SEARCH_TERMS: [comprehensive keyword list]"""
response = client.chat_completion(
model="meta-llama/Llama-3.1-70B-Instruct",
messages=[{"role": "user", "content": parse_prompt}],
max_tokens=500,
temperature=0.3
)
parsed = response.choices[0].message.content.strip()
logger.info(f"[QUERY PARSER] βœ“ Entities extracted")
# Parse response
result = {
'drugs': [],
'diseases': [],
'companies': [],
'endpoints': [],
'search_terms': query
}
for line in parsed.split('\n'):
line = line.strip()
if line.startswith('DRUGS:'):
drugs = line.replace('DRUGS:', '').strip().strip('[]')
if drugs and drugs.lower() != 'none':
result['drugs'] = [d.strip().strip('"\'') for d in drugs.split(',')]
elif line.startswith('DISEASES:'):
diseases = line.replace('DISEASES:', '').strip().strip('[]')
if diseases and diseases.lower() != 'none':
result['diseases'] = [d.strip().strip('"\'') for d in diseases.split(',')]
elif line.startswith('COMPANIES:'):
companies = line.replace('COMPANIES:', '').strip().strip('[]')
if companies and companies.lower() != 'none':
result['companies'] = [c.strip().strip('"\'') for c in companies.split(',')]
elif line.startswith('ENDPOINTS:'):
endpoints = line.replace('ENDPOINTS:', '').strip().strip('[]')
if endpoints and endpoints.lower() != 'none':
result['endpoints'] = [e.strip().strip('"\'') for e in endpoints.split(',')]
elif line.startswith('SEARCH_TERMS:'):
terms = line.replace('SEARCH_TERMS:', '').strip().strip('[]')
if terms:
result['search_terms'] = terms.strip('"\'')
return result
except Exception as e:
logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
return {
'drugs': [],
'diseases': [],
'companies': [],
'endpoints': [],
'search_terms': query,
'error': str(e)
}
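# Illustrative parser I/O (hypothetical example, not a recorded model response):
#   raw LLM text:   "DRUGS: [aspirin, acetylsalicylic acid]\nDISEASES: none\n..."
#   parsed result:  {'drugs': ['aspirin', 'acetylsalicylic acid'], 'diseases': [], ...,
#                    'search_terms': <original query unless SEARCH_TERMS was returned>}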
# ============================================================================
# STEP 2: RAG SEARCH (Hybrid: Inverted Index + Semantic Embeddings)
# ============================================================================
def load_embedder():
"""Load embedding model for semantic search"""
global embedder
if embedder is None:
logger.info("[RAG] Loading MiniLM-L6 embedding model...")
embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
logger.info("[RAG] βœ“ Embedder loaded")
def hybrid_rag_search(search_query: str, top_k: int = 30) -> List[Tuple[float, str]]:
"""
Hybrid RAG search combining:
1. Inverted index (O(1) keyword lookup)
2. Semantic embeddings (MiniLM-L6)
3. Smart scoring (drugs get 1000x boost)
Time: ~2 seconds
Cost: $0 (all local)
Returns:
List of (score, trial_text) tuples
"""
global doc_chunks, doc_embeddings, embedder, inverted_index
if doc_embeddings is None or len(doc_chunks) == 0:
raise Exception("Embeddings not loaded!")
logger.info(f"[RAG] Searching {len(doc_chunks):,} trials...")
# Extract keywords
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
'for', 'of', 'with', 'is', 'are', 'was', 'were', 'be', 'been'}
words = re.findall(r'\b\w+\b', search_query.lower())
query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
# Keyword scoring with inverted index
keyword_scores = {}
if inverted_index is not None:
inv_index_candidates = set()
for term in query_terms:
if term in inverted_index:
inv_index_candidates.update(inverted_index[term])
if inv_index_candidates:
# Identify drug-specific terms (rare = specific)
drug_specific_terms = {term for term in query_terms
if term in inverted_index and len(inverted_index[term]) < 100}
for idx in inv_index_candidates:
chunk_text = doc_chunks[idx][1] if isinstance(doc_chunks[idx], tuple) else doc_chunks[idx]
chunk_lower = chunk_text.lower()
# Drug match gets 1000x boost (critical for pharma queries)
has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms)
keyword_scores[idx] = 1000.0 if has_drug_match else 1.0
# Semantic scoring
load_embedder()
query_embedding = embedder.encode([search_query])[0]
semantic_similarities = np.dot(doc_embeddings, query_embedding)
# Normalize scores
if keyword_scores:
max_kw = max(keyword_scores.values())
keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
else:
keyword_scores_norm = {}
max_sem = semantic_similarities.max()
min_sem = semantic_similarities.min()
semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
# Combine: 50% keyword, 50% semantic (keyword-matched trials prioritized)
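    # Worked example (illustrative numbers): a drug-matched chunk with kw_norm = 1.0
    # and semantic = 0.6 blends to 0.5*1.0 + 0.5*0.6 = 0.80, while a chunk with no
    # keyword hit keeps only its raw semantic score (e.g. 0.70), so drug matches
    # still tend to come out on top.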
combined_scores = np.zeros(len(doc_chunks))
for idx in range(len(doc_chunks)):
kw_score = keyword_scores_norm.get(idx, 0.0)
sem_score = semantic_scores_norm[idx]
        combined_scores[idx] = (0.5 * kw_score + 0.5 * sem_score) if kw_score > 0 else sem_score
# Get top candidates
top_indices = np.argsort(combined_scores)[-top_k:][::-1]
results = [
(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
for i in top_indices
]
logger.info(f"[RAG] βœ“ Found {len(results)} candidates (top score: {results[0][0]:.3f})")
return results
# ============================================================================
# STEP 3: 355M PERPLEXITY RANKING
# ============================================================================
def load_355m_model():
"""Load 355M Clinical Trial GPT model (cached)"""
global model_355m, tokenizer_355m
if model_355m is None:
logger.info("[355M] Loading CT2 model for perplexity ranking...")
tokenizer_355m = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2")
model_355m = GPT2LMHeadModel.from_pretrained(
"gmkdigitalmedia/CT2",
torch_dtype=torch.float16,
device_map="auto"
)
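        # Note: device_map="auto" requires the `accelerate` package; without a GPU the
        # model loads on CPU and the 2-5 s ranking estimate becomes optimistic.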
model_355m.eval()
tokenizer_355m.pad_token = tokenizer_355m.eos_token
logger.info("[355M] βœ“ Model loaded")
def rank_with_355m_perplexity(query: str, candidates: List[Tuple[float, str]]) -> List[Dict]:
"""
Rank trials using 355M model's perplexity scores
Perplexity = "How natural does this query-trial pairing seem?"
Lower perplexity = more relevant
Time: ~2-5 seconds (depends on GPU)
Cost: $0 (local model)
Returns:
List of dicts with trial data and scores
"""
load_355m_model()
# Only rank top 10 (balance accuracy vs speed)
top_10 = candidates[:10]
logger.info(f"[355M] Ranking {len(top_10)} trials with perplexity...")
ranked_trials = []
for idx, (hybrid_score, trial_text) in enumerate(top_10):
# Extract NCT ID
nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
# Format test text
test_text = f"""Query: {query}
Relevant Clinical Trial:
{trial_text[:800]}
This trial is highly relevant because"""
# Calculate perplexity
inputs = tokenizer_355m(
test_text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
).to(model_355m.device)
with torch.no_grad():
outputs = model_355m(**inputs, labels=inputs.input_ids)
perplexity = torch.exp(outputs.loss).item()
# Convert perplexity to 0-1 score
perplexity_score = 1.0 / (1.0 + perplexity / 100)
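        # Illustrative mapping: perplexity 20 -> 1/(1+0.2) ≈ 0.83, perplexity 100 -> 0.50,
        # perplexity 400 -> 0.20 (lower perplexity maps to a higher relevance score)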
# Combine: 70% hybrid search, 30% perplexity
combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score
logger.info(f"[355M] {nct_id}: Perplexity={perplexity:.1f}, Combined={combined_score:.3f}")
ranked_trials.append({
'nct_id': nct_id,
'trial_text': trial_text,
'hybrid_score': float(hybrid_score),
'perplexity': float(perplexity),
'perplexity_score': float(perplexity_score),
'combined_score': float(combined_score),
'rank_before_355m': idx + 1
})
# Sort by combined score
ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True)
# Add final ranks
for idx, trial in enumerate(ranked_trials):
trial['rank_after_355m'] = idx + 1
logger.info(f"[355M] βœ“ Ranking complete")
# Add remaining trials (without 355M scoring)
for idx, (hybrid_score, trial_text) in enumerate(candidates[10:], start=10):
nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
ranked_trials.append({
'nct_id': nct_id,
'trial_text': trial_text,
'hybrid_score': float(hybrid_score),
'perplexity': None,
'perplexity_score': None,
'combined_score': float(hybrid_score),
'rank_before_355m': idx + 1,
'rank_after_355m': len(ranked_trials) + 1
})
return ranked_trials
# ============================================================================
# STEP 4: STRUCTURED JSON OUTPUT
# ============================================================================
def parse_trial_to_dict(trial_text: str, nct_id: str) -> Dict:
"""
Parse trial text into structured fields
Extracts:
- title, status, phase, conditions, interventions
- sponsor, enrollment, dates
- description, outcomes
"""
trial = {'nct_id': nct_id, 'url': f"https://clinicaltrials.gov/study/{nct_id}"}
# Extract fields using regex
fields = {
'title': r'TITLE:\s*([^\n]+)',
'status': r'STATUS:\s*([^\n]+)',
'phase': r'PHASE:\s*([^\n]+)',
'conditions': r'CONDITIONS:\s*([^\n]+)',
'interventions': r'INTERVENTION:\s*([^\n]+)',
'sponsor': r'SPONSOR:\s*([^\n]+)',
'enrollment': r'ENROLLMENT:\s*([^\n]+)',
'primary_outcome': r'PRIMARY OUTCOME:\s*([^\n]+)',
'description': r'DESCRIPTION:\s*([^\n]+)'
}
for field, pattern in fields.items():
match = re.search(pattern, trial_text, re.IGNORECASE)
trial[field] = match.group(1).strip() if match else None
return trial
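# Illustrative input/output (hypothetical chunk text, for format reference only):
#   "NCT_ID: NCT00000000\nTITLE: A Phase 3 Study of ...\nSTATUS: Completed\n..."
#   -> {'nct_id': 'NCT00000000', 'title': 'A Phase 3 Study of ...', 'status': 'Completed',
#       'url': 'https://clinicaltrials.gov/study/NCT00000000', ...}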
def process_query_option_b(query: str, top_k: int = 10) -> Dict:
"""
Complete Option B pipeline
1. Parse query with LLM
2. RAG search
3. 355M perplexity ranking
4. Return structured JSON
Total time: ~7-10 seconds
Total cost: $0.001 per query
Returns:
{
'query': str,
'processing_time': float,
'query_analysis': {
'extracted_entities': {...},
'optimized_search': str,
'parsing_time': float
},
'results': {
'total_found': int,
'returned': int,
'top_relevance_score': float
},
'trials': [
{
'nct_id': str,
'title': str,
'status': str,
...
'scoring': {
'relevance_score': float,
'perplexity': float,
'rank_before_355m': int,
'rank_after_355m': int
},
'url': str
}
],
'benchmarking': {
'query_parsing_time': float,
'rag_search_time': float,
'355m_ranking_time': float,
'total_processing_time': float
}
}
"""
start_time = time.time()
result = {
'query': query,
'processing_time': 0,
'query_analysis': {},
'results': {},
'trials': [],
'benchmarking': {}
}
try:
# Step 1: Parse query with LLM
step1_start = time.time()
parsed_query = parse_query_with_llm(query, hf_token=hf_token)
search_query = parsed_query['search_terms']
result['query_analysis'] = {
'extracted_entities': {
'drugs': parsed_query.get('drugs', []),
'diseases': parsed_query.get('diseases', []),
'companies': parsed_query.get('companies', []),
'endpoints': parsed_query.get('endpoints', [])
},
'optimized_search': search_query,
'parsing_time': time.time() - step1_start
}
# Step 2: RAG search
step2_start = time.time()
candidates = hybrid_rag_search(search_query, top_k=top_k * 3)
rag_time = time.time() - step2_start
# Step 3: 355M perplexity ranking
step3_start = time.time()
ranked_trials = rank_with_355m_perplexity(query, candidates)
ranking_time = time.time() - step3_start
# Step 4: Format structured output
result['results'] = {
'total_found': len(candidates),
'returned': min(top_k, len(ranked_trials)),
'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0
}
# Parse trials
for trial_data in ranked_trials[:top_k]:
trial_dict = parse_trial_to_dict(trial_data['trial_text'], trial_data['nct_id'])
trial_dict['scoring'] = {
'relevance_score': trial_data['combined_score'],
'hybrid_score': trial_data['hybrid_score'],
'perplexity': trial_data['perplexity'],
'perplexity_score': trial_data['perplexity_score'],
'rank_before_355m': trial_data['rank_before_355m'],
'rank_after_355m': trial_data['rank_after_355m'],
'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only'
}
result['trials'].append(trial_dict)
# Benchmarking
result['benchmarking'] = {
'query_parsing_time': result['query_analysis']['parsing_time'],
'rag_search_time': rag_time,
'355m_ranking_time': ranking_time,
'total_processing_time': time.time() - start_time
}
result['processing_time'] = time.time() - start_time
logger.info(f"[OPTION B] βœ“ Complete in {result['processing_time']:.1f}s")
return result
except Exception as e:
logger.error(f"[OPTION B] Error: {e}")
import traceback
result['error'] = str(e)
result['traceback'] = traceback.format_exc()
result['processing_time'] = time.time() - start_time
return result
# ============================================================================
# INITIALIZATION
# ============================================================================
def load_all_data():
"""Load embeddings, chunks, and inverted index at startup"""
global doc_chunks, doc_embeddings, inverted_index
import pickle
logger.info("=" * 60)
logger.info("LOADING FOUNDATION RAG - OPTION B")
logger.info("=" * 60)
# Load chunks
if CHUNKS_FILE.exists():
logger.info(f"Loading chunks from {CHUNKS_FILE}...")
with open(CHUNKS_FILE, 'rb') as f:
doc_chunks = pickle.load(f)
logger.info(f"βœ“ Loaded {len(doc_chunks):,} trial chunks")
# Load embeddings
if EMBEDDINGS_FILE.exists():
logger.info(f"Loading embeddings from {EMBEDDINGS_FILE}...")
doc_embeddings = np.load(EMBEDDINGS_FILE)
logger.info(f"βœ“ Loaded embeddings: {doc_embeddings.shape}")
# Load inverted index
if INVERTED_INDEX_FILE.exists():
logger.info(f"Loading inverted index from {INVERTED_INDEX_FILE}...")
with open(INVERTED_INDEX_FILE, 'rb') as f:
inverted_index = pickle.load(f)
logger.info(f"βœ“ Loaded inverted index: {len(inverted_index):,} terms")
logger.info("=" * 60)
logger.info("READY - Option B Pipeline Active")
logger.info("=" * 60)
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
# Load data
load_all_data()
# Test query
test_query = "What are the results for ianalumab in Sjogren's syndrome?"
print(f"\nProcessing: {test_query}\n")
result = process_query_option_b(test_query, top_k=5)
print(f"\n{'='*60}")
print("RESULTS")
print(f"{'='*60}\n")
print(f"Processing Time: {result['processing_time']:.1f}s")
print(f"Query Parsing: {result['query_analysis']['parsing_time']:.1f}s")
print(f"RAG Search: {result['benchmarking']['rag_search_time']:.1f}s")
print(f"355M Ranking: {result['benchmarking']['355m_ranking_time']:.1f}s\n")
print(f"Extracted Entities:")
for entity_type, values in result['query_analysis']['extracted_entities'].items():
print(f" {entity_type}: {values}")
print(f"\nTop {len(result['trials'])} Trials:\n")
for i, trial in enumerate(result['trials'], 1):
print(f"{i}. {trial['nct_id']}: {trial.get('title', 'No title')}")
print(f" Relevance: {trial['scoring']['relevance_score']:.3f}")
print(f" Perplexity: {trial['scoring']['perplexity']:.1f if trial['scoring']['perplexity'] else 'N/A'}")
print(f" Rank change: {trial['scoring']['rank_before_355m']} β†’ {trial['scoring']['rank_after_355m']}")
print()