""" Foundation RAG - Option B: Clean 1-LLM Architecture ==================================================== Pipeline: 1. Query Parser LLM (Llama-70B) → Extract entities + synonyms (3s, $0.001) 2. RAG Search (BM25 + Semantic + Inverted Index) → Retrieve candidates (2s, free) 3. 355M Perplexity Ranking → Rank by clinical relevance (2-5s, free) 4. Structured JSON Output → Return ranked trials (instant, free) Total: ~7-10 seconds, $0.001 per query No response generation - clients handle that with their own LLMs """ import os import time import logging import numpy as np import torch import re from pathlib import Path from typing import List, Dict, Tuple, Optional from sentence_transformers import SentenceTransformer from transformers import GPT2LMHeadModel, GPT2TokenizerFast from huggingface_hub import InferenceClient logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ============================================================================ # CONFIGURATION # ============================================================================ hf_token = os.getenv("HF_TOKEN") # Data paths (check /tmp first, then local) DATA_DIR = Path("/tmp/foundation_data") if not DATA_DIR.exists(): DATA_DIR = Path(__file__).parent CHUNKS_FILE = DATA_DIR / "dataset_chunks_TRIAL_AWARE.pkl" EMBEDDINGS_FILE = DATA_DIR / "dataset_embeddings_TRIAL_AWARE_FIXED.npy" INVERTED_INDEX_FILE = DATA_DIR / "inverted_index_COMPREHENSIVE.pkl" # Global state embedder = None doc_chunks = [] doc_embeddings = None inverted_index = None model_355m = None tokenizer_355m = None # ============================================================================ # STEP 1: QUERY PARSER LLM (Llama-70B) # ============================================================================ def parse_query_with_llm(query: str, hf_token: str = None) -> Dict: """ Use Llama-70B to parse query and extract entities Cost: $0.001 per query Time: ~3 seconds Returns: { 'drugs': [...], 'diseases': [...], 'companies': [...], 'endpoints': [...], 'search_terms': "optimized search query" } """ try: logger.info("[QUERY PARSER] Analyzing query with Llama-70B...") client = InferenceClient(token=hf_token, timeout=30) parse_prompt = f"""You are an expert in clinical trial terminology. Extract entities from this query. Query: "{query}" Extract ALL possible names and synonyms: DRUGS: - Brand names, generic names, research codes (e.g., BNT162b2) - Chemical names, abbreviations - Company+drug combinations (e.g., Pfizer-BioNTech vaccine) DISEASES: - Medical synonyms, ICD-10 terms - Technical and colloquial terms - Related conditions COMPANIES: - Parent companies, subsidiaries - Previous names, partnerships ENDPOINTS: - Specific outcomes or measures mentioned SEARCH_TERMS: - Comprehensive keywords for search Format EXACTLY as: DRUGS: [list or "none"] DISEASES: [list or "none"] COMPANIES: [list or "none"] ENDPOINTS: [list or "none"] SEARCH_TERMS: [comprehensive keyword list]""" response = client.chat_completion( model="meta-llama/Llama-3.1-70B-Instruct", messages=[{"role": "user", "content": parse_prompt}], max_tokens=500, temperature=0.3 ) parsed = response.choices[0].message.content.strip() logger.info(f"[QUERY PARSER] ✓ Entities extracted") # Parse response result = { 'drugs': [], 'diseases': [], 'companies': [], 'endpoints': [], 'search_terms': query } for line in parsed.split('\n'): line = line.strip() if line.startswith('DRUGS:'): drugs = line.replace('DRUGS:', '').strip().strip('[]') if drugs and drugs.lower() != 'none': result['drugs'] = [d.strip().strip('"\'') for d in drugs.split(',')] elif line.startswith('DISEASES:'): diseases = line.replace('DISEASES:', '').strip().strip('[]') if diseases and diseases.lower() != 'none': result['diseases'] = [d.strip().strip('"\'') for d in diseases.split(',')] elif line.startswith('COMPANIES:'): companies = line.replace('COMPANIES:', '').strip().strip('[]') if companies and companies.lower() != 'none': result['companies'] = [c.strip().strip('"\'') for c in companies.split(',')] elif line.startswith('ENDPOINTS:'): endpoints = line.replace('ENDPOINTS:', '').strip().strip('[]') if endpoints and endpoints.lower() != 'none': result['endpoints'] = [e.strip().strip('"\'') for e in endpoints.split(',')] elif line.startswith('SEARCH_TERMS:'): terms = line.replace('SEARCH_TERMS:', '').strip().strip('[]') if terms: result['search_terms'] = terms.strip('"\'') return result except Exception as e: logger.warning(f"[QUERY PARSER] Failed: {e}, using original query") return { 'drugs': [], 'diseases': [], 'companies': [], 'endpoints': [], 'search_terms': query, 'error': str(e) } # ============================================================================ # STEP 2: RAG SEARCH (Hybrid: BM25 + Semantic + Inverted Index) # ============================================================================ def load_embedder(): """Load embedding model for semantic search""" global embedder if embedder is None: logger.info("[RAG] Loading MiniLM-L6 embedding model...") embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu') logger.info("[RAG] ✓ Embedder loaded") def hybrid_rag_search(search_query: str, top_k: int = 30) -> List[Tuple[float, str]]: """ Hybrid RAG search combining: 1. Inverted index (O(1) keyword lookup) 2. Semantic embeddings (MiniLM-L6) 3. Smart scoring (drugs get 1000x boost) Time: ~2 seconds Cost: $0 (all local) Returns: List of (score, trial_text) tuples """ global doc_chunks, doc_embeddings, embedder, inverted_index if doc_embeddings is None or len(doc_chunks) == 0: raise Exception("Embeddings not loaded!") logger.info(f"[RAG] Searching {len(doc_chunks):,} trials...") # Extract keywords stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'is', 'are', 'was', 'were', 'be', 'been'} words = re.findall(r'\b\w+\b', search_query.lower()) query_terms = [w for w in words if len(w) > 2 and w not in stop_words] # Keyword scoring with inverted index keyword_scores = {} if inverted_index is not None: inv_index_candidates = set() for term in query_terms: if term in inverted_index: inv_index_candidates.update(inverted_index[term]) if inv_index_candidates: # Identify drug-specific terms (rare = specific) drug_specific_terms = {term for term in query_terms if term in inverted_index and len(inverted_index[term]) < 100} for idx in inv_index_candidates: chunk_text = doc_chunks[idx][1] if isinstance(doc_chunks[idx], tuple) else doc_chunks[idx] chunk_lower = chunk_text.lower() # Drug match gets 1000x boost (critical for pharma queries) has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms) keyword_scores[idx] = 1000.0 if has_drug_match else 1.0 # Semantic scoring load_embedder() query_embedding = embedder.encode([search_query])[0] semantic_similarities = np.dot(doc_embeddings, query_embedding) # Normalize scores if keyword_scores: max_kw = max(keyword_scores.values()) keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()} else: keyword_scores_norm = {} max_sem = semantic_similarities.max() min_sem = semantic_similarities.min() semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10) # Combine: 50% keyword, 50% semantic (keyword-matched trials prioritized) combined_scores = np.zeros(len(doc_chunks)) for idx in range(len(doc_chunks)): kw_score = keyword_scores_norm.get(idx, 0.0) sem_score = semantic_scores_norm[idx] combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score if kw_score > 0 else sem_score # Get top candidates top_indices = np.argsort(combined_scores)[-top_k:][::-1] results = [ (combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices ] logger.info(f"[RAG] ✓ Found {len(results)} candidates (top score: {results[0][0]:.3f})") return results # ============================================================================ # STEP 3: 355M PERPLEXITY RANKING # ============================================================================ def load_355m_model(): """Load 355M Clinical Trial GPT model (cached)""" global model_355m, tokenizer_355m if model_355m is None: logger.info("[355M] Loading CT2 model for perplexity ranking...") tokenizer_355m = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2") model_355m = GPT2LMHeadModel.from_pretrained( "gmkdigitalmedia/CT2", torch_dtype=torch.float16, device_map="auto" ) model_355m.eval() tokenizer_355m.pad_token = tokenizer_355m.eos_token logger.info("[355M] ✓ Model loaded") def rank_with_355m_perplexity(query: str, candidates: List[Tuple[float, str]]) -> List[Dict]: """ Rank trials using 355M model's perplexity scores Perplexity = "How natural does this query-trial pairing seem?" Lower perplexity = more relevant Time: ~2-5 seconds (depends on GPU) Cost: $0 (local model) Returns: List of dicts with trial data and scores """ load_355m_model() # Only rank top 10 (balance accuracy vs speed) top_10 = candidates[:10] logger.info(f"[355M] Ranking {len(top_10)} trials with perplexity...") ranked_trials = [] for idx, (hybrid_score, trial_text) in enumerate(top_10): # Extract NCT ID nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" # Format test text test_text = f"""Query: {query} Relevant Clinical Trial: {trial_text[:800]} This trial is highly relevant because""" # Calculate perplexity inputs = tokenizer_355m( test_text, return_tensors="pt", truncation=True, max_length=512, padding=True ).to(model_355m.device) with torch.no_grad(): outputs = model_355m(**inputs, labels=inputs.input_ids) perplexity = torch.exp(outputs.loss).item() # Convert perplexity to 0-1 score perplexity_score = 1.0 / (1.0 + perplexity / 100) # Combine: 70% hybrid search, 30% perplexity combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score logger.info(f"[355M] {nct_id}: Perplexity={perplexity:.1f}, Combined={combined_score:.3f}") ranked_trials.append({ 'nct_id': nct_id, 'trial_text': trial_text, 'hybrid_score': float(hybrid_score), 'perplexity': float(perplexity), 'perplexity_score': float(perplexity_score), 'combined_score': float(combined_score), 'rank_before_355m': idx + 1 }) # Sort by combined score ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True) # Add final ranks for idx, trial in enumerate(ranked_trials): trial['rank_after_355m'] = idx + 1 logger.info(f"[355M] ✓ Ranking complete") # Add remaining trials (without 355M scoring) for idx, (hybrid_score, trial_text) in enumerate(candidates[10:], start=10): nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" ranked_trials.append({ 'nct_id': nct_id, 'trial_text': trial_text, 'hybrid_score': float(hybrid_score), 'perplexity': None, 'perplexity_score': None, 'combined_score': float(hybrid_score), 'rank_before_355m': idx + 1, 'rank_after_355m': len(ranked_trials) + 1 }) return ranked_trials # ============================================================================ # STEP 4: STRUCTURED JSON OUTPUT # ============================================================================ def parse_trial_to_dict(trial_text: str, nct_id: str) -> Dict: """ Parse trial text into structured fields Extracts: - title, status, phase, conditions, interventions - sponsor, enrollment, dates - description, outcomes """ trial = {'nct_id': nct_id, 'url': f"https://clinicaltrials.gov/study/{nct_id}"} # Extract fields using regex fields = { 'title': r'TITLE:\s*([^\n]+)', 'status': r'STATUS:\s*([^\n]+)', 'phase': r'PHASE:\s*([^\n]+)', 'conditions': r'CONDITIONS:\s*([^\n]+)', 'interventions': r'INTERVENTION:\s*([^\n]+)', 'sponsor': r'SPONSOR:\s*([^\n]+)', 'enrollment': r'ENROLLMENT:\s*([^\n]+)', 'primary_outcome': r'PRIMARY OUTCOME:\s*([^\n]+)', 'description': r'DESCRIPTION:\s*([^\n]+)' } for field, pattern in fields.items(): match = re.search(pattern, trial_text, re.IGNORECASE) trial[field] = match.group(1).strip() if match else None return trial def process_query_option_b(query: str, top_k: int = 10) -> Dict: """ Complete Option B pipeline 1. Parse query with LLM 2. RAG search 3. 355M perplexity ranking 4. Return structured JSON Total time: ~7-10 seconds Total cost: $0.001 per query Returns: { 'query': str, 'processing_time': float, 'query_analysis': { 'extracted_entities': {...}, 'optimized_search': str, 'parsing_time': float }, 'results': { 'total_found': int, 'returned': int, 'top_relevance_score': float }, 'trials': [ { 'nct_id': str, 'title': str, 'status': str, ... 'scoring': { 'relevance_score': float, 'perplexity': float, 'rank_before_355m': int, 'rank_after_355m': int }, 'url': str } ], 'benchmarking': { 'query_parsing_time': float, 'rag_search_time': float, '355m_ranking_time': float, 'total_processing_time': float } } """ start_time = time.time() result = { 'query': query, 'processing_time': 0, 'query_analysis': {}, 'results': {}, 'trials': [], 'benchmarking': {} } try: # Step 1: Parse query with LLM step1_start = time.time() parsed_query = parse_query_with_llm(query, hf_token=hf_token) search_query = parsed_query['search_terms'] result['query_analysis'] = { 'extracted_entities': { 'drugs': parsed_query.get('drugs', []), 'diseases': parsed_query.get('diseases', []), 'companies': parsed_query.get('companies', []), 'endpoints': parsed_query.get('endpoints', []) }, 'optimized_search': search_query, 'parsing_time': time.time() - step1_start } # Step 2: RAG search step2_start = time.time() candidates = hybrid_rag_search(search_query, top_k=top_k * 3) rag_time = time.time() - step2_start # Step 3: 355M perplexity ranking step3_start = time.time() ranked_trials = rank_with_355m_perplexity(query, candidates) ranking_time = time.time() - step3_start # Step 4: Format structured output result['results'] = { 'total_found': len(candidates), 'returned': min(top_k, len(ranked_trials)), 'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0 } # Parse trials for trial_data in ranked_trials[:top_k]: trial_dict = parse_trial_to_dict(trial_data['trial_text'], trial_data['nct_id']) trial_dict['scoring'] = { 'relevance_score': trial_data['combined_score'], 'hybrid_score': trial_data['hybrid_score'], 'perplexity': trial_data['perplexity'], 'perplexity_score': trial_data['perplexity_score'], 'rank_before_355m': trial_data['rank_before_355m'], 'rank_after_355m': trial_data['rank_after_355m'], 'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only' } result['trials'].append(trial_dict) # Benchmarking result['benchmarking'] = { 'query_parsing_time': result['query_analysis']['parsing_time'], 'rag_search_time': rag_time, '355m_ranking_time': ranking_time, 'total_processing_time': time.time() - start_time } result['processing_time'] = time.time() - start_time logger.info(f"[OPTION B] ✓ Complete in {result['processing_time']:.1f}s") return result except Exception as e: logger.error(f"[OPTION B] Error: {e}") import traceback result['error'] = str(e) result['traceback'] = traceback.format_exc() result['processing_time'] = time.time() - start_time return result # ============================================================================ # INITIALIZATION # ============================================================================ def load_all_data(): """Load embeddings, chunks, and inverted index at startup""" global doc_chunks, doc_embeddings, inverted_index import pickle logger.info("=" * 60) logger.info("LOADING FOUNDATION RAG - OPTION B") logger.info("=" * 60) # Load chunks if CHUNKS_FILE.exists(): logger.info(f"Loading chunks from {CHUNKS_FILE}...") with open(CHUNKS_FILE, 'rb') as f: doc_chunks = pickle.load(f) logger.info(f"✓ Loaded {len(doc_chunks):,} trial chunks") # Load embeddings if EMBEDDINGS_FILE.exists(): logger.info(f"Loading embeddings from {EMBEDDINGS_FILE}...") doc_embeddings = np.load(EMBEDDINGS_FILE) logger.info(f"✓ Loaded embeddings: {doc_embeddings.shape}") # Load inverted index if INVERTED_INDEX_FILE.exists(): logger.info(f"Loading inverted index from {INVERTED_INDEX_FILE}...") with open(INVERTED_INDEX_FILE, 'rb') as f: inverted_index = pickle.load(f) logger.info(f"✓ Loaded inverted index: {len(inverted_index):,} terms") logger.info("=" * 60) logger.info("READY - Option B Pipeline Active") logger.info("=" * 60) # ============================================================================ # EXAMPLE USAGE # ============================================================================ if __name__ == "__main__": # Load data load_all_data() # Test query test_query = "What are the results for ianalumab in Sjogren's syndrome?" print(f"\nProcessing: {test_query}\n") result = process_query_option_b(test_query, top_k=5) print(f"\n{'='*60}") print("RESULTS") print(f"{'='*60}\n") print(f"Processing Time: {result['processing_time']:.1f}s") print(f"Query Parsing: {result['query_analysis']['parsing_time']:.1f}s") print(f"RAG Search: {result['benchmarking']['rag_search_time']:.1f}s") print(f"355M Ranking: {result['benchmarking']['355m_ranking_time']:.1f}s\n") print(f"Extracted Entities:") for entity_type, values in result['query_analysis']['extracted_entities'].items(): print(f" {entity_type}: {values}") print(f"\nTop {len(result['trials'])} Trials:\n") for i, trial in enumerate(result['trials'], 1): print(f"{i}. {trial['nct_id']}: {trial.get('title', 'No title')}") print(f" Relevance: {trial['scoring']['relevance_score']:.3f}") print(f" Perplexity: {trial['scoring']['perplexity']:.1f if trial['scoring']['perplexity'] else 'N/A'}") print(f" Rank change: {trial['scoring']['rank_before_355m']} → {trial['scoring']['rank_after_355m']}") print()