| """ | |
| Foundation RAG - Option B: Clean 1-LLM Architecture | |
| ==================================================== | |
| Pipeline: | |
| 1. Query Parser LLM (Llama-70B) β Extract entities + synonyms (3s, $0.001) | |
| 2. RAG Search (BM25 + Semantic + Inverted Index) β Retrieve candidates (2s, free) | |
| 3. 355M Perplexity Ranking β Rank by clinical relevance (2-5s, free) | |
| 4. Structured JSON Output β Return ranked trials (instant, free) | |
| Total: ~7-10 seconds, $0.001 per query | |
| No response generation - clients handle that with their own LLMs | |
| """ | |

import os
import time
import logging
import numpy as np
import torch
import re
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from sentence_transformers import SentenceTransformer
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from huggingface_hub import InferenceClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# CONFIGURATION
# ============================================================================

hf_token = os.getenv("HF_TOKEN")

# Data paths (check /tmp first, then local)
DATA_DIR = Path("/tmp/foundation_data")
if not DATA_DIR.exists():
    DATA_DIR = Path(__file__).parent

CHUNKS_FILE = DATA_DIR / "dataset_chunks_TRIAL_AWARE.pkl"
EMBEDDINGS_FILE = DATA_DIR / "dataset_embeddings_TRIAL_AWARE_FIXED.npy"
INVERTED_INDEX_FILE = DATA_DIR / "inverted_index_COMPREHENSIVE.pkl"

# Global state
embedder = None
doc_chunks = []
doc_embeddings = None
inverted_index = None
model_355m = None
tokenizer_355m = None

# ============================================================================
# STEP 1: QUERY PARSER LLM (Llama-70B)
# ============================================================================

def parse_query_with_llm(query: str, hf_token: Optional[str] = None) -> Dict:
    """
    Use Llama-70B to parse query and extract entities

    Cost: $0.001 per query
    Time: ~3 seconds

    Returns:
        {
            'drugs': [...],
            'diseases': [...],
            'companies': [...],
            'endpoints': [...],
            'search_terms': "optimized search query"
        }
    """
    try:
        logger.info("[QUERY PARSER] Analyzing query with Llama-70B...")
        client = InferenceClient(token=hf_token, timeout=30)

        parse_prompt = f"""You are an expert in clinical trial terminology. Extract entities from this query.
Query: "{query}"
Extract ALL possible names and synonyms:
DRUGS:
- Brand names, generic names, research codes (e.g., BNT162b2)
- Chemical names, abbreviations
- Company+drug combinations (e.g., Pfizer-BioNTech vaccine)
DISEASES:
- Medical synonyms, ICD-10 terms
- Technical and colloquial terms
- Related conditions
COMPANIES:
- Parent companies, subsidiaries
- Previous names, partnerships
ENDPOINTS:
- Specific outcomes or measures mentioned
SEARCH_TERMS:
- Comprehensive keywords for search
Format EXACTLY as:
DRUGS: [list or "none"]
DISEASES: [list or "none"]
COMPANIES: [list or "none"]
ENDPOINTS: [list or "none"]
SEARCH_TERMS: [comprehensive keyword list]"""

        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": parse_prompt}],
            max_tokens=500,
            temperature=0.3
        )

        parsed = response.choices[0].message.content.strip()
        logger.info("[QUERY PARSER] ✓ Entities extracted")

        # Parse response
        result = {
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query
        }

        # Map labelled response lines back onto result keys
        list_fields = {
            'DRUGS:': 'drugs',
            'DISEASES:': 'diseases',
            'COMPANIES:': 'companies',
            'ENDPOINTS:': 'endpoints',
        }
        for line in parsed.split('\n'):
            line = line.strip()
            matched = False
            for prefix, key in list_fields.items():
                if line.startswith(prefix):
                    values = line.replace(prefix, '').strip().strip('[]')
                    if values and values.lower() != 'none':
                        result[key] = [v.strip().strip('"\'') for v in values.split(',')]
                    matched = True
                    break
            if not matched and line.startswith('SEARCH_TERMS:'):
                terms = line.replace('SEARCH_TERMS:', '').strip().strip('[]')
                if terms:
                    result['search_terms'] = terms.strip('"\'')

        return result
    except Exception as e:
        logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
        return {
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query,
            'error': str(e)
        }
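
# Illustrative shape of parse_query_with_llm() output (hypothetical values,
# not actual model output -- the real entity lists depend on what the LLM
# returns for a given query):
#
#   parse_query_with_llm("Pfizer COVID vaccine phase 3 efficacy")
#   -> {'drugs': ['BNT162b2', 'Comirnaty', 'tozinameran'],
#       'diseases': ['COVID-19', 'SARS-CoV-2 infection'],
#       'companies': ['Pfizer', 'BioNTech'],
#       'endpoints': ['efficacy'],
#       'search_terms': 'BNT162b2 Comirnaty Pfizer BioNTech COVID-19 vaccine phase 3 efficacy'}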

# ============================================================================
# STEP 2: RAG SEARCH (Hybrid: Keyword + Semantic + Inverted Index)
# ============================================================================

def load_embedder():
    """Load embedding model for semantic search"""
    global embedder
    if embedder is None:
        logger.info("[RAG] Loading MiniLM-L6 embedding model...")
        embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        logger.info("[RAG] ✓ Embedder loaded")

def hybrid_rag_search(search_query: str, top_k: int = 30) -> List[Tuple[float, str]]:
    """
    Hybrid RAG search combining:
    1. Inverted index (O(1) keyword lookup)
    2. Semantic embeddings (MiniLM-L6)
    3. Smart scoring (drugs get 1000x boost)

    Time: ~2 seconds
    Cost: $0 (all local)

    Returns:
        List of (score, trial_text) tuples
    """
    global doc_chunks, doc_embeddings, embedder, inverted_index

    if doc_embeddings is None or len(doc_chunks) == 0:
        raise RuntimeError("Embeddings not loaded - call load_all_data() first")
| logger.info(f"[RAG] Searching {len(doc_chunks):,} trials...") | |
| # Extract keywords | |
| stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', | |
| 'for', 'of', 'with', 'is', 'are', 'was', 'were', 'be', 'been'} | |
| words = re.findall(r'\b\w+\b', search_query.lower()) | |
| query_terms = [w for w in words if len(w) > 2 and w not in stop_words] | |
| # Keyword scoring with inverted index | |
| keyword_scores = {} | |
| if inverted_index is not None: | |
| inv_index_candidates = set() | |
| for term in query_terms: | |
| if term in inverted_index: | |
| inv_index_candidates.update(inverted_index[term]) | |
| if inv_index_candidates: | |
| # Identify drug-specific terms (rare = specific) | |
| drug_specific_terms = {term for term in query_terms | |
| if term in inverted_index and len(inverted_index[term]) < 100} | |
| for idx in inv_index_candidates: | |
| chunk_text = doc_chunks[idx][1] if isinstance(doc_chunks[idx], tuple) else doc_chunks[idx] | |
| chunk_lower = chunk_text.lower() | |
| # Drug match gets 1000x boost (critical for pharma queries) | |
| has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms) | |
| keyword_scores[idx] = 1000.0 if has_drug_match else 1.0 | |

    # Semantic scoring
    load_embedder()
    query_embedding = embedder.encode([search_query])[0]
    semantic_similarities = np.dot(doc_embeddings, query_embedding)

    # Normalize scores
    if keyword_scores:
        max_kw = max(keyword_scores.values())
        keyword_scores_norm = {idx: score / max_kw for idx, score in keyword_scores.items()}
    else:
        keyword_scores_norm = {}

    max_sem = semantic_similarities.max()
    min_sem = semantic_similarities.min()
    semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
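
    # Worked example of the min-max normalization above (illustrative numbers,
    # not real similarities): raw [0.2, 0.5, 0.9] -> min=0.2, max=0.9, range=0.7,
    # so normalized = [(0.2-0.2)/0.7, (0.5-0.2)/0.7, (0.9-0.2)/0.7] ~= [0.0, 0.43, 1.0].
    # The 1e-10 term guards against division by zero when all similarities are equal.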

    # Combine: 50% keyword, 50% semantic (keyword-matched trials prioritized)
    combined_scores = np.zeros(len(doc_chunks))
    for idx in range(len(doc_chunks)):
        kw_score = keyword_scores_norm.get(idx, 0.0)
        sem_score = semantic_scores_norm[idx]
        combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score if kw_score > 0 else sem_score

    # Get top candidates
    top_indices = np.argsort(combined_scores)[-top_k:][::-1]
    results = [
        (combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
        for i in top_indices
    ]

    logger.info(f"[RAG] ✓ Found {len(results)} candidates (top score: {results[0][0]:.3f})")
    return results
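
# Usage sketch (assumes load_all_data() has populated the globals; the query
# text below is illustrative):
#
#   candidates = hybrid_rag_search("ianalumab Sjogren syndrome", top_k=30)
#   for score, trial_text in candidates[:3]:
#       print(f"{score:.3f}  {trial_text[:80]}")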

# ============================================================================
# STEP 3: 355M PERPLEXITY RANKING
# ============================================================================

def load_355m_model():
    """Load 355M Clinical Trial GPT model (cached)"""
    global model_355m, tokenizer_355m
    if model_355m is None:
        logger.info("[355M] Loading CT2 model for perplexity ranking...")
        tokenizer_355m = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2")
        model_355m = GPT2LMHeadModel.from_pretrained(
            "gmkdigitalmedia/CT2",
            torch_dtype=torch.float16,
            device_map="auto"
        )
        model_355m.eval()
        tokenizer_355m.pad_token = tokenizer_355m.eos_token
        logger.info("[355M] ✓ Model loaded")

def rank_with_355m_perplexity(query: str, candidates: List[Tuple[float, str]]) -> List[Dict]:
    """
    Rank trials using 355M model's perplexity scores

    Perplexity = "How natural does this query-trial pairing seem?"
    Lower perplexity = more relevant

    Time: ~2-5 seconds (depends on GPU)
    Cost: $0 (local model)

    Returns:
        List of dicts with trial data and scores
    """
    load_355m_model()

    # Only rank top 10 (balance accuracy vs speed)
    top_10 = candidates[:10]
    logger.info(f"[355M] Ranking {len(top_10)} trials with perplexity...")

    ranked_trials = []
    for idx, (hybrid_score, trial_text) in enumerate(top_10):
        # Extract NCT ID
        nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
        nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"

        # Format test text
        test_text = f"""Query: {query}
Relevant Clinical Trial:
{trial_text[:800]}
This trial is highly relevant because"""

        # Calculate perplexity
        inputs = tokenizer_355m(
            test_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(model_355m.device)

        with torch.no_grad():
            outputs = model_355m(**inputs, labels=inputs.input_ids)
            perplexity = torch.exp(outputs.loss).item()

        # Convert perplexity to 0-1 score
        perplexity_score = 1.0 / (1.0 + perplexity / 100)
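        # Worked examples of this conversion (illustrative values, not real output):
        #   perplexity=10  -> 1/(1+0.1) ~= 0.909 (very natural query-trial pairing)
        #   perplexity=50  -> 1/(1+0.5) ~= 0.667
        #   perplexity=300 -> 1/(1+3.0)  = 0.250 (poor pairing)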

        # Combine: 70% hybrid search, 30% perplexity
        combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score

        logger.info(f"[355M] {nct_id}: Perplexity={perplexity:.1f}, Combined={combined_score:.3f}")

        ranked_trials.append({
            'nct_id': nct_id,
            'trial_text': trial_text,
            'hybrid_score': float(hybrid_score),
            'perplexity': float(perplexity),
            'perplexity_score': float(perplexity_score),
            'combined_score': float(combined_score),
            'rank_before_355m': idx + 1
        })

    # Sort by combined score
    ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True)

    # Add final ranks
    for idx, trial in enumerate(ranked_trials):
        trial['rank_after_355m'] = idx + 1

    logger.info("[355M] ✓ Ranking complete")

    # Add remaining trials (without 355M scoring)
    for idx, (hybrid_score, trial_text) in enumerate(candidates[10:], start=10):
        nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
        nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
        ranked_trials.append({
            'nct_id': nct_id,
            'trial_text': trial_text,
            'hybrid_score': float(hybrid_score),
            'perplexity': None,
            'perplexity_score': None,
            'combined_score': float(hybrid_score),
            'rank_before_355m': idx + 1,
            'rank_after_355m': len(ranked_trials) + 1
        })

    return ranked_trials
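
# Illustrative result entry (schema from the code above; the scores are
# made-up numbers that satisfy the formulas, not real output):
#
#   ranked = rank_with_355m_perplexity(query, candidates)
#   ranked[0] -> {'nct_id': 'NCT...', 'hybrid_score': 0.91, 'perplexity': 42.3,
#                 'perplexity_score': 0.703, 'combined_score': 0.848,
#                 'rank_before_355m': 2, 'rank_after_355m': 1, ...}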

# ============================================================================
# STEP 4: STRUCTURED JSON OUTPUT
# ============================================================================

def parse_trial_to_dict(trial_text: str, nct_id: str) -> Dict:
    """
    Parse trial text into structured fields

    Extracts:
    - title, status, phase, conditions, interventions
    - sponsor, enrollment, dates
    - description, outcomes
    """
    trial = {'nct_id': nct_id, 'url': f"https://clinicaltrials.gov/study/{nct_id}"}

    # Extract fields using regex
    fields = {
        'title': r'TITLE:\s*([^\n]+)',
        'status': r'STATUS:\s*([^\n]+)',
        'phase': r'PHASE:\s*([^\n]+)',
        'conditions': r'CONDITIONS:\s*([^\n]+)',
        'interventions': r'INTERVENTION:\s*([^\n]+)',
        'sponsor': r'SPONSOR:\s*([^\n]+)',
        'enrollment': r'ENROLLMENT:\s*([^\n]+)',
        'primary_outcome': r'PRIMARY OUTCOME:\s*([^\n]+)',
        'description': r'DESCRIPTION:\s*([^\n]+)'
    }

    for field, pattern in fields.items():
        match = re.search(pattern, trial_text, re.IGNORECASE)
        trial[field] = match.group(1).strip() if match else None

    return trial
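
# Illustrative example (hypothetical trial text using the field labels the
# regexes above expect; missing fields come back as None):
#
#   text = "NCT_ID: NCT01234567\nTITLE: A Phase 3 Study\nSTATUS: Completed"
#   parse_trial_to_dict(text, "NCT01234567")
#   -> {'nct_id': 'NCT01234567',
#       'url': 'https://clinicaltrials.gov/study/NCT01234567',
#       'title': 'A Phase 3 Study', 'status': 'Completed', 'phase': None, ...}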

def process_query_option_b(query: str, top_k: int = 10) -> Dict:
    """
    Complete Option B pipeline

    1. Parse query with LLM
    2. RAG search
    3. 355M perplexity ranking
    4. Return structured JSON

    Total time: ~7-10 seconds
    Total cost: $0.001 per query

    Returns:
        {
            'query': str,
            'processing_time': float,
            'query_analysis': {
                'extracted_entities': {...},
                'optimized_search': str,
                'parsing_time': float
            },
            'results': {
                'total_found': int,
                'returned': int,
                'top_relevance_score': float
            },
            'trials': [
                {
                    'nct_id': str,
                    'title': str,
                    'status': str,
                    ...
                    'scoring': {
                        'relevance_score': float,
                        'perplexity': float,
                        'rank_before_355m': int,
                        'rank_after_355m': int
                    },
                    'url': str
                }
            ],
            'benchmarking': {
                'query_parsing_time': float,
                'rag_search_time': float,
                '355m_ranking_time': float,
                'total_processing_time': float
            }
        }
    """
    start_time = time.time()

    result = {
        'query': query,
        'processing_time': 0,
        'query_analysis': {},
        'results': {},
        'trials': [],
        'benchmarking': {}
    }

    try:
        # Step 1: Parse query with LLM
        step1_start = time.time()
        parsed_query = parse_query_with_llm(query, hf_token=hf_token)
        search_query = parsed_query['search_terms']

        result['query_analysis'] = {
            'extracted_entities': {
                'drugs': parsed_query.get('drugs', []),
                'diseases': parsed_query.get('diseases', []),
                'companies': parsed_query.get('companies', []),
                'endpoints': parsed_query.get('endpoints', [])
            },
            'optimized_search': search_query,
            'parsing_time': time.time() - step1_start
        }

        # Step 2: RAG search
        step2_start = time.time()
        candidates = hybrid_rag_search(search_query, top_k=top_k * 3)
        rag_time = time.time() - step2_start

        # Step 3: 355M perplexity ranking
        step3_start = time.time()
        ranked_trials = rank_with_355m_perplexity(query, candidates)
        ranking_time = time.time() - step3_start

        # Step 4: Format structured output
        result['results'] = {
            'total_found': len(candidates),
            'returned': min(top_k, len(ranked_trials)),
            'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0
        }

        # Parse trials
        for trial_data in ranked_trials[:top_k]:
            trial_dict = parse_trial_to_dict(trial_data['trial_text'], trial_data['nct_id'])
            trial_dict['scoring'] = {
                'relevance_score': trial_data['combined_score'],
                'hybrid_score': trial_data['hybrid_score'],
                'perplexity': trial_data['perplexity'],
                'perplexity_score': trial_data['perplexity_score'],
                'rank_before_355m': trial_data['rank_before_355m'],
                'rank_after_355m': trial_data['rank_after_355m'],
                'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only'
            }
            result['trials'].append(trial_dict)

        # Benchmarking
        result['benchmarking'] = {
            'query_parsing_time': result['query_analysis']['parsing_time'],
            'rag_search_time': rag_time,
            '355m_ranking_time': ranking_time,
            'total_processing_time': time.time() - start_time
        }
        result['processing_time'] = time.time() - start_time

        logger.info(f"[OPTION B] ✓ Complete in {result['processing_time']:.1f}s")
        return result
    except Exception as e:
        logger.error(f"[OPTION B] Error: {e}")
        import traceback
        result['error'] = str(e)
        result['traceback'] = traceback.format_exc()
        result['processing_time'] = time.time() - start_time
        return result

# ============================================================================
# INITIALIZATION
# ============================================================================

def load_all_data():
    """Load embeddings, chunks, and inverted index at startup"""
    global doc_chunks, doc_embeddings, inverted_index
    import pickle

    logger.info("=" * 60)
    logger.info("LOADING FOUNDATION RAG - OPTION B")
    logger.info("=" * 60)

    # Load chunks
    if CHUNKS_FILE.exists():
        logger.info(f"Loading chunks from {CHUNKS_FILE}...")
        with open(CHUNKS_FILE, 'rb') as f:
            doc_chunks = pickle.load(f)
        logger.info(f"✓ Loaded {len(doc_chunks):,} trial chunks")
    else:
        logger.warning(f"Chunks file not found: {CHUNKS_FILE}")

    # Load embeddings
    if EMBEDDINGS_FILE.exists():
        logger.info(f"Loading embeddings from {EMBEDDINGS_FILE}...")
        doc_embeddings = np.load(EMBEDDINGS_FILE)
        logger.info(f"✓ Loaded embeddings: {doc_embeddings.shape}")
    else:
        logger.warning(f"Embeddings file not found: {EMBEDDINGS_FILE}")

    # Load inverted index
    if INVERTED_INDEX_FILE.exists():
        logger.info(f"Loading inverted index from {INVERTED_INDEX_FILE}...")
        with open(INVERTED_INDEX_FILE, 'rb') as f:
            inverted_index = pickle.load(f)
        logger.info(f"✓ Loaded inverted index: {len(inverted_index):,} terms")
    else:
        logger.warning(f"Inverted index not found: {INVERTED_INDEX_FILE} (keyword scoring disabled)")

    logger.info("=" * 60)
    logger.info("READY - Option B Pipeline Active")
    logger.info("=" * 60)

# ============================================================================
# EXAMPLE USAGE
# ============================================================================
| if __name__ == "__main__": | |
| # Load data | |
| load_all_data() | |
| # Test query | |
| test_query = "What are the results for ianalumab in Sjogren's syndrome?" | |
| print(f"\nProcessing: {test_query}\n") | |
| result = process_query_option_b(test_query, top_k=5) | |
| print(f"\n{'='*60}") | |
| print("RESULTS") | |
| print(f"{'='*60}\n") | |
| print(f"Processing Time: {result['processing_time']:.1f}s") | |
| print(f"Query Parsing: {result['query_analysis']['parsing_time']:.1f}s") | |
| print(f"RAG Search: {result['benchmarking']['rag_search_time']:.1f}s") | |
| print(f"355M Ranking: {result['benchmarking']['355m_ranking_time']:.1f}s\n") | |
| print(f"Extracted Entities:") | |
| for entity_type, values in result['query_analysis']['extracted_entities'].items(): | |
| print(f" {entity_type}: {values}") | |
| print(f"\nTop {len(result['trials'])} Trials:\n") | |
| for i, trial in enumerate(result['trials'], 1): | |
| print(f"{i}. {trial['nct_id']}: {trial.get('title', 'No title')}") | |
| print(f" Relevance: {trial['scoring']['relevance_score']:.3f}") | |
| print(f" Perplexity: {trial['scoring']['perplexity']:.1f if trial['scoring']['perplexity'] else 'N/A'}") | |
| print(f" Rank change: {trial['scoring']['rank_before_355m']} β {trial['scoring']['rank_after_355m']}") | |
| print() | |