""" Foundation 1.2 Clinical trial query system with 355M foundation model """ import gradio as gr import os from pathlib import Path import pickle import numpy as np from sentence_transformers import SentenceTransformer import logging from rank_bm25 import BM25Okapi import re from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize hf_token = os.getenv("HF_TOKEN") # Paths for data storage # Files will be downloaded from HF Dataset on first run DATASET_FILE = Path(__file__).parent / "complete_dataset_WITH_RESULTS_FULL.txt" CHUNKS_FILE = Path(__file__).parent / "dataset_chunks_TRIAL_AWARE.pkl" EMBEDDINGS_FILE = Path(__file__).parent / "dataset_embeddings_TRIAL_AWARE_FIXED.npy" # FIXED version to avoid cache INVERTED_INDEX_FILE = Path(__file__).parent / "inverted_index_COMPREHENSIVE.pkl" # Pre-built inverted index (307MB) # HF Dataset containing the large files DATASET_REPO = "gmkdigitalmedia/foundation1.2-data" # Global storage embedder = None doc_chunks = [] doc_embeddings = None bm25_index = None # BM25 index for fast keyword search inverted_index = None # Inverted index for instant drug lookup # ============================================================================ # ANALYTICS TRACKING # ============================================================================ from collections import defaultdict, Counter import time as time_module class QueryAnalytics: """Track query patterns and performance for monitoring""" def __init__(self): self.query_types = Counter() self.response_times = defaultdict(list) self.error_count = 0 self.total_queries = 0 self.start_time = time_module.time() def record_query(self, query_type: str, response_time: float, success: bool = True): """Record a query execution""" self.total_queries += 1 self.query_types[query_type] += 1 self.response_times[query_type].append(response_time) if not success: self.error_count += 1 logger.info(f"[ANALYTICS] Recorded: {query_type}, {response_time:.2f}s, success={success}") def get_stats(self): """Get analytics summary""" uptime = time_module.time() - self.start_time stats = { 'total_queries': self.total_queries, 'uptime_seconds': uptime, 'error_rate': self.error_count / self.total_queries if self.total_queries > 0 else 0, 'query_type_distribution': dict(self.query_types), 'avg_response_times': {} } for query_type, times in self.response_times.items(): if times: stats['avg_response_times'][query_type] = sum(times) / len(times) return stats # Initialize global analytics query_analytics = QueryAnalytics() # ============================================================================ # RAG FUNCTIONS # ============================================================================ def load_embedder(): """Load L6 embedding model (matches generated embeddings)""" global embedder if embedder is None: logger.info("Loading MiniLM-L6 embedding model...") # Force CPU to avoid CUDA init in main process embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu') logger.info("L6 model loaded on CPU") def build_inverted_index(chunks): """ Build targeted inverted index for clinical search Maps drugs, diseases, companies, and endpoints to trial indices for O(1) lookup Indexes ONLY what matters: 1. INTERVENTION - drug/device names 2. CONDITIONS - diseases being treated 3. SPONSOR/COLLABORATOR/MANUFACTURER - company names 4. OUTCOME - trial endpoints (what's being measured) Does NOT index trial names (unnecessary noise) """ import time t_start = time.time() inv_index = {} logger.info("Building targeted index: drugs, diseases, companies, endpoints...") # Generic words to skip skip_words = { 'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial', 'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active', 'randomized', 'multicenter', 'open', 'label', 'crossover' } for idx, chunk_data in enumerate(chunks): if idx % 100000 == 0 and idx > 0: logger.info(f" Indexed {idx:,}/{len(chunks):,} trials...") text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data text_lower = text.lower() # 1. DRUGS from INTERVENTION field intervention_match = re.search(r'intervention[:\s]+([^\n]+)', text_lower) if intervention_match: intervention_text = intervention_match.group(1) drugs = re.split(r'[,;\-\s]+', intervention_text) for drug in drugs: drug = drug.strip('.,;:() ') if len(drug) > 3 and drug not in skip_words: if drug not in inv_index: inv_index[drug] = [] if idx not in inv_index[drug]: inv_index[drug].append(idx) # 2. DISEASES from CONDITIONS field conditions_match = re.search(r'conditions?[:\s]+([^\n]+)', text_lower) if conditions_match: conditions_text = conditions_match.group(1) diseases = re.split(r'[,;\|]+', conditions_text) for disease in diseases: disease = disease.strip('.,;:() ') # Split multi-word conditions and index each significant word disease_words = re.findall(r'\b\w{4,}\b', disease) for word in disease_words: if word not in skip_words: if word not in inv_index: inv_index[word] = [] if idx not in inv_index[word]: inv_index[word].append(idx) # 3. COMPANIES from SPONSOR field sponsor_match = re.search(r'sponsor[:\s]+([^\n]+)', text_lower) if sponsor_match: sponsor_text = sponsor_match.group(1) sponsors = re.split(r'[,;\|]+', sponsor_text) for sponsor in sponsors: sponsor = sponsor.strip('.,;:() ') if len(sponsor) > 3: if sponsor not in inv_index: inv_index[sponsor] = [] if idx not in inv_index[sponsor]: inv_index[sponsor].append(idx) # 4. COMPANIES from COLLABORATOR field collab_match = re.search(r'collaborator[:\s]+([^\n]+)', text_lower) if collab_match: collab_text = collab_match.group(1) collaborators = re.split(r'[,;\|]+', collab_text) for collab in collaborators: collab = collab.strip('.,;:() ') if len(collab) > 3: if collab not in inv_index: inv_index[collab] = [] if idx not in inv_index[collab]: inv_index[collab].append(idx) # 5. COMPANIES from MANUFACTURER field manuf_match = re.search(r'manufacturer[:\s]+([^\n]+)', text_lower) if manuf_match: manuf_text = manuf_match.group(1) manufacturers = re.split(r'[,;\|]+', manuf_text) for manuf in manufacturers: manuf = manuf.strip('.,;:() ') if len(manuf) > 3: if manuf not in inv_index: inv_index[manuf] = [] if idx not in inv_index[manuf]: inv_index[manuf].append(idx) # 6. ENDPOINTS from OUTCOME fields # Look for outcome measures (what's being measured) outcome_matches = re.findall(r'outcome[:\s]+([^\n]+)', text_lower) for outcome_match in outcome_matches[:5]: # First 5 outcomes only # Extract meaningful endpoint terms endpoint_words = re.findall(r'\b\w{5,}\b', outcome_match) # 5+ char words for word in endpoint_words[:3]: # First 3 words per outcome if word not in skip_words and word not in {'outcome', 'measure', 'primary', 'secondary'}: if word not in inv_index: inv_index[word] = [] if idx not in inv_index[word]: inv_index[word].append(idx) t_elapsed = time.time() - t_start logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms") # Log sample entries for debugging (drugs, diseases, companies, endpoints) sample_terms = { 'drugs': ['keytruda', 'opdivo', 'humira'], 'diseases': ['cancer', 'diabetes', 'melanoma'], 'companies': ['novartis', 'pfizer', 'merck'], 'endpoints': ['survival', 'response', 'remission'] } for category, terms in sample_terms.items(): logger.info(f" {category.upper()} samples:") for term in terms: if term in inv_index: logger.info(f" '{term}' -> {len(inv_index[term])} trials") return inv_index def download_from_dataset(filename): """Download file from HF Dataset if not present locally""" from huggingface_hub import hf_hub_download import tempfile import os # Use /tmp for downloads (has write permissions in Docker) download_dir = Path("/tmp/foundation_data") download_dir.mkdir(exist_ok=True) local_file = download_dir / filename if local_file.exists(): logger.info(f"Found cached {filename}") return local_file try: logger.info(f"Downloading {filename} from {DATASET_REPO}...") # Get HF_TOKEN from environment or global hf_token variable token = os.environ.get('HF_TOKEN') or hf_token downloaded_file = hf_hub_download( repo_id=DATASET_REPO, filename=filename, repo_type="dataset", local_dir=download_dir, local_dir_use_symlinks=False, token=token # Pass authentication token ) logger.info(f"Downloaded {filename}") return Path(downloaded_file) except Exception as e: logger.error(f"Failed to download {filename}: {e}") return None def load_embeddings(): """Load pre-generated embeddings (download from dataset if needed)""" global doc_chunks, doc_embeddings, bm25_index # Try to download if not present - store paths returned by download chunks_path = CHUNKS_FILE embeddings_path = EMBEDDINGS_FILE dataset_path = DATASET_FILE index_path = INVERTED_INDEX_FILE if not CHUNKS_FILE.exists(): downloaded = download_from_dataset("dataset_chunks_TRIAL_AWARE.pkl") if downloaded: chunks_path = downloaded if not EMBEDDINGS_FILE.exists(): downloaded = download_from_dataset("dataset_embeddings_TRIAL_AWARE_FIXED.npy") # FIXED version if downloaded: embeddings_path = downloaded if not DATASET_FILE.exists(): downloaded = download_from_dataset("complete_dataset_WITH_RESULTS_FULL.txt") if downloaded: dataset_path = downloaded # Download inverted index from dataset (307 MB, truly comprehensive) if not INVERTED_INDEX_FILE.exists(): downloaded = download_from_dataset("inverted_index_COMPREHENSIVE.pkl") if downloaded: index_path = downloaded logger.info(f"✓ Downloaded comprehensive inverted index from dataset") if chunks_path.exists() and embeddings_path.exists(): try: logger.info("Loading embeddings from disk...") with open(chunks_path, 'rb') as f: doc_chunks = pickle.load(f) # Load embeddings loaded_embeddings = np.load(embeddings_path, allow_pickle=True) logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}") # Check if it's already a proper numpy array if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2: doc_embeddings = loaded_embeddings logger.info(f"✓ Embeddings are proper numpy array with shape: {doc_embeddings.shape}") elif isinstance(loaded_embeddings, list): logger.info(f"Converting embeddings from list to numpy array (memory efficient)...") # Convert in chunks to avoid memory spike chunk_size = 10000 total = len(loaded_embeddings) # DEBUG: Print first 3 items to see format logger.info(f"DEBUG: Total embeddings: {total}") logger.info(f"DEBUG: Type of first item: {type(loaded_embeddings[0])}") # Check if this is actually the chunks file (wrong file uploaded) if isinstance(loaded_embeddings[0], tuple) and len(loaded_embeddings[0]) == 2: if isinstance(loaded_embeddings[0][0], int) and isinstance(loaded_embeddings[0][1], str): raise ValueError( f"ERROR: The embeddings file contains (int, string) tuples!\n" f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n" f"First item: {loaded_embeddings[0][:2]}\n\n" f"Please re-upload the correct file:\n" f" CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n" f" WRONG: dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n" f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct." ) if isinstance(loaded_embeddings[0], tuple): logger.info(f"DEBUG: Tuple length: {len(loaded_embeddings[0])}") for i, item in enumerate(loaded_embeddings[0][:5] if len(loaded_embeddings[0]) > 5 else loaded_embeddings[0]): logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}") # Get embedding dimension from first item first_emb = loaded_embeddings[0] emb_idx = None # Initialize # Handle different formats if isinstance(first_emb, tuple): # Try both positions - could be (id, emb) or (emb, id) logger.info(f"DEBUG: Trying to find embedding vector in tuple...") emb_vector = None for idx, elem in enumerate(first_emb): if isinstance(elem, (list, np.ndarray)): emb_vector = elem emb_idx = idx logger.info(f"DEBUG: Found embedding at position {idx}") break if emb_vector is None: raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}") emb_dim = len(emb_vector) logger.info(f"DEBUG: Embedding dimension: {emb_dim}") elif isinstance(first_emb, list): emb_dim = len(first_emb) emb_idx = None elif isinstance(first_emb, np.ndarray): emb_dim = first_emb.shape[0] emb_idx = None else: raise ValueError(f"Unknown embedding format: {type(first_emb)}") logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}") # Pre-allocate array doc_embeddings = np.zeros((total, emb_dim), dtype=np.float32) # Fill in chunks for i in range(0, total, chunk_size): end = min(i + chunk_size, total) # Extract embeddings from tuples if needed if isinstance(first_emb, tuple) and emb_idx is not None: # Extract just the embedding vector from each tuple at the correct position batch = [item[emb_idx] for item in loaded_embeddings[i:end]] doc_embeddings[i:end] = batch else: doc_embeddings[i:end] = loaded_embeddings[i:end] if i % 50000 == 0: logger.info(f"Converted {i}/{total} embeddings...") logger.info(f"✓ Converted to array with shape: {doc_embeddings.shape}") else: doc_embeddings = loaded_embeddings logger.info(f"Embeddings already numpy array with shape: {doc_embeddings.shape}") logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings") # Skip BM25 (too memory-heavy for Docker), use inverted index only global inverted_index # Try to load pre-built comprehensive inverted index (77MB) from dataset if index_path.exists(): logger.info(f"Loading comprehensive inverted index from {index_path.name}...") try: with open(index_path, 'rb') as f: inverted_index = pickle.load(f) logger.info(f"✓ Loaded comprehensive index with {len(inverted_index):,} terms") logger.info(f" Includes: TITLE (all words), INTERVENTION, CONDITIONS, SPONSOR, SUMMARY/DESCRIPTION (companies)") except Exception as e: logger.warning(f"Failed to load comprehensive index: {e}, building basic index...") inverted_index = build_inverted_index(doc_chunks) else: logger.info("Comprehensive inverted index not found, building basic index (15 minutes)...") inverted_index = build_inverted_index(doc_chunks) logger.info("Will use inverted index + semantic search (no BM25)") return True except Exception as e: logger.error(f"Failed to load embeddings: {e}") raise RuntimeError("Embeddings are required but failed to load") from e raise RuntimeError("Embeddings files not found - system cannot function without embeddings") def filter_trial_for_clinical_summary(trial_text): """ Filter trial data to keep essential clinical information including SOME results. COMPREHENSIVE FILTERING: - Keeps all core trial info (title, summary, conditions, interventions) - Keeps sponsor/collaborator/manufacturer (WHO is running the trial) - Keeps first 5 outcomes (to show key endpoints) - Keeps first 5 result values per trial (to show actual data) - Filters out overwhelming statistical noise (hundreds of baseline/adverse event lines) This ensures the LLM sees comprehensive context including company information. """ if not trial_text: return trial_text lines = trial_text.split('\n') filtered_lines = [] # Counters to limit repetitive data outcome_count = 0 outcome_desc_count = 0 result_value_count = 0 # Limits MAX_OUTCOMES = 5 MAX_OUTCOME_DESC = 5 MAX_RESULT_VALUES = 5 for line in lines: line_stripped = line.strip() # Skip empty lines if not line_stripped: continue # ALWAYS SKIP: Overwhelming noise always_skip = [ 'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:', 'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:', 'OUTCOME_OTHER:', 'OUTCOME_NUMBER:' ] should_skip = False for marker in always_skip: if line_stripped.startswith(marker): should_skip = True break if should_skip: continue # LIMITED KEEP: Outcomes (first N only) if line_stripped.startswith('OUTCOME:'): outcome_count += 1 if outcome_count <= MAX_OUTCOMES: filtered_lines.append(line) continue # LIMITED KEEP: Outcome descriptions (first N only) if line_stripped.startswith('OUTCOME_DESCRIPTION:'): outcome_desc_count += 1 if outcome_desc_count <= MAX_OUTCOME_DESC: filtered_lines.append(line) continue # LIMITED KEEP: Result values (first N only) if line_stripped.startswith('RESULT_VALUE:'): result_value_count += 1 if result_value_count <= MAX_RESULT_VALUES: filtered_lines.append(line) continue # ALWAYS KEEP: Core trial information + context always_keep = [ 'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:', 'SUMMARY:', 'DESCRIPTION:', 'CONDITIONS:', 'INTERVENTION:', # WHAT disease, WHAT drug 'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:', # WHO is running/funding 'ELIGIBILITY:' # Note: OUTCOME/OUTCOME_DESCRIPTION handled in LIMITED KEEP section above ] for marker in always_keep: if line_stripped.startswith(marker): filtered_lines.append(line) break return '\n'.join(filtered_lines) def retrieve_context_with_embeddings(query, top_k=10, entities=None): """ ENTERPRISE HYBRID SEARCH with STRICT ENTITY FILTERING - Enforces HARD FILTERS for companies (sponsor/collaborator) - Extracts meaningful terms from query (case-insensitive) - Scores each trial by keyword frequency (TF-IDF style) - Also gets semantic similarity scores - Merges both scores with weighted combination Args: query: Search query string top_k: Number of results to return entities: Dict with 'drugs', 'diseases', 'companies' - if provided, STRICTLY filters """ import time import re from collections import Counter global doc_chunks, doc_embeddings, embedder if doc_embeddings is None or len(doc_chunks) == 0: logger.error("Embeddings not loaded!") return "" t0 = time.time() # Extract ALL meaningful words from query (stop words removed) stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know', 'about', 'that', 'this', 'there', 'it'} query_lower = query.lower() # Remove punctuation and split words = re.findall(r'\b\w+\b', query_lower) # Filter out stop words and short words query_terms = [w for w in words if len(w) > 2 and w not in stop_words] logger.info(f"[HYBRID] Query terms extracted: {query_terms}") # PARALLEL SEARCH: Run both keyword and semantic simultaneously # 1. KEYWORD SCORING WITH BM25 (Fast!) t_kw = time.time() # Use inverted index for drug lookup (lightweight, no BM25) global bm25_index, inverted_index keyword_scores = {} if inverted_index is not None: # Check if any query terms are in our drug/intervention inverted index inv_index_candidates = set() for term in query_terms: if term in inverted_index: inv_index_candidates.update(inverted_index[term]) logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'") # FAST PATH: If we have inverted index hits (drug names), score those trials if inv_index_candidates: logger.info(f"[FAST PATH] Checking {len(inv_index_candidates)} inverted index candidates") # CRITICAL: Identify which terms are specific drugs (low frequency) drug_specific_terms = set() for term in query_terms: if term in inverted_index and len(inverted_index[term]) < 100: # This term appears in <100 trials - likely a specific drug name! drug_specific_terms.add(term) logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name") for idx in inv_index_candidates: # No BM25, use simple match count as base score base_score = 1.0 # Check if this trial contains a drug-specific term chunk_data = doc_chunks[idx] chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data chunk_lower = chunk_text.lower() has_drug_match = False for drug_term in drug_specific_terms: if drug_term in chunk_lower: has_drug_match = True break # MASSIVE PRIORITY for drug-specific trials if has_drug_match: # Drug-specific trials get GUARANTEED top ranking score = 1000.0 + base_score logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}") else: # Regular inverted index hits (generic terms) if base_score <= 0: base_score = 0.1 score = base_score # Apply field-specific boosting for non-drug terms max_field_boost = 1.0 for term in query_terms: if term not in chunk_lower or term in drug_specific_terms: continue # INTERVENTION field - medium priority for non-drug terms if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower: max_field_boost = max(max_field_boost, 3.0) # TITLE field - low priority elif 'title:' in chunk_lower: title_pos = chunk_lower.find('title:') term_pos = chunk_lower.find(term) if title_pos < term_pos < title_pos + 200: max_field_boost = max(max_field_boost, 2.0) score *= max_field_boost keyword_scores[idx] = score else: logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search") logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)") # 1.5. STRICT COMPANY FILTERING (if companies specified) company_filter_failed = False if entities and entities.get('companies'): companies = [c.lower() for c in entities['companies']] logger.info(f"[STRICT FILTER] Enforcing company filter: {companies}") # Save original scores in case we need to fall back original_keyword_scores = keyword_scores.copy() # Filter keyword_scores to ONLY trials with these companies filtered_keyword_scores = {} sponsor_field_patterns = ['sponsor:', 'collaborator:', 'manufacturer:'] for idx, score in keyword_scores.items(): chunk_data = doc_chunks[idx] chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data chunk_lower = chunk_text.lower() # Check if ANY company appears in sponsor/collaborator/manufacturer fields has_company = False for company in companies: # Look for company name in sponsor-related fields for field in sponsor_field_patterns: if field in chunk_lower: field_start = chunk_lower.find(field) field_text = chunk_lower[field_start:field_start+500] # Next 500 chars if company in field_text: has_company = True logger.info(f"[COMPANY MATCH] Trial {idx} has '{company}' in {field}") break if has_company: break if has_company: filtered_keyword_scores[idx] = score * 10.0 # 10x boost for company match # If no company match, EXCLUDE this trial before_count = len(keyword_scores) after_count = len(filtered_keyword_scores) logger.info(f"[STRICT FILTER] Filtered {before_count} → {after_count} trials (only those from {companies})") # If no company matches, fall back to original search but flag it if len(filtered_keyword_scores) == 0: logger.warning(f"[STRICT FILTER] No trials found from companies {companies}, falling back to general search") company_filter_failed = True keyword_scores = original_keyword_scores # Restore original else: keyword_scores = filtered_keyword_scores # 2. SEMANTIC SCORING load_embedder() t_sem = time.time() query_embedding = embedder.encode([query])[0] semantic_similarities = np.dot(doc_embeddings, query_embedding) logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)") # 3. MERGE SCORES # Normalize both scores to 0-1 range if keyword_scores: max_kw = max(keyword_scores.values()) keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()} else: keyword_scores_norm = {} max_sem = semantic_similarities.max() min_sem = semantic_similarities.min() semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10) # Combined score: 50% keyword (with IDF/field boost), 50% semantic (context) # Balanced approach: IDF-weighted keywords + semantic understanding combined_scores = np.zeros(len(doc_chunks)) for idx in range(len(doc_chunks)): kw_score = keyword_scores_norm.get(idx, 0.0) sem_score = semantic_scores_norm[idx] # If keyword match exists, balance keyword + semantic if kw_score > 0: combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score else: # Pure semantic if no keyword match combined_scores[idx] = sem_score # Get top K by combined score (get more candidates to sort by recency) # We'll get 10 candidates, then sort by NCT ID to find the 3 most recent candidate_k = max(top_k * 3, 10) # Get 3x requested, minimum 10 top_indices = np.argsort(combined_scores)[-candidate_k:][::-1] logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}") logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}") logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}") # Extract text and scores for 355M ranking # Format as (score, text) tuples for rank_trials_with_355m candidate_trials_for_ranking = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices] # SORT BY NCT ID (higher = newer) before 355M ranking def extract_nct_number(trial_tuple): """Extract NCT number from trial text for sorting (higher = newer)""" _, text = trial_tuple match = re.search(r'NCT_ID:\s*NCT(\d+)', text) return int(match.group(1)) if match else 0 # Sort candidates by NCT ID (descending = newest first) candidate_trials_for_ranking.sort(key=extract_nct_number, reverse=True) # Log top 5 NCT IDs to show recency sorting top_ncts = [] for score, text in candidate_trials_for_ranking[:5]: match = re.search(r'NCT_ID:\s*(NCT\d+)', text) if match: top_ncts.append(match.group(1)) logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}") # SKIP 355M RANKING - It's broken (gives 0.50 to everything) and wastes 10 seconds # Just use the hybrid-scored + recency-sorted candidates logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)") ranked_trials = candidate_trials_for_ranking # Take top K from ranked results top_ranked = ranked_trials[:top_k] logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)") # Extract just the text raw_chunks = [trial_text for _, trial_text in top_ranked] # Apply clinical filter to each trial context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks] if context_chunks: first_trial_preview = context_chunks[0][:200] logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}") # Add ranking information if available from 355M if hasattr(ranked_trials, 'ranking_info'): ranking_header = "[TRIAL RANKING BY CLINICAL RELEVANCE GPT]\n" for info in ranked_trials.ranking_info: ranking_header += f"Rank {info['rank']}: {info['nct_id']} - Relevance {info['relevance_rating']}\n" ranking_header += "---\n\n" # Prepend ranking info to first trial if context_chunks: context_chunks[0] = ranking_header + context_chunks[0] logger.info(f"[355M RANKING] Added ranking metadata to context for final LLM") context = "\n\n---\n\n".join(context_chunks) # Use --- as separator between trials logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s") logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)") return context def keyword_search_query_text(query, max_results=10, hf_token=None): """Search dataset using ALL meaningful words from the full query""" if not DATASET_FILE.exists(): logger.error("Dataset file not found") return "" # Extract all meaningful words from the full query # Remove common stopwords but keep medical/clinical terms stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with', 'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on', 'what', 'you', 'know', 'that', 'relevant'} # Extract words, filter stopwords and short words words = query.lower().split() search_terms = [w.strip('?.,!;:()[]{}') for w in words if w.lower() not in stopwords and len(w) >= 3] if not search_terms: logger.warning("No search terms extracted from query") return "" logger.info(f"Search terms from full query: {search_terms}") # Store trials with match scores trials_with_scores = [] current_trial = "" try: with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f: for line in f: # Check if new trial starts if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"): # Score previous trial if current_trial: trial_lower = current_trial.lower() # Count matches for all search terms score = sum(1 for term in search_terms if term in trial_lower) if score > 0: trials_with_scores.append((score, current_trial)) current_trial = line else: current_trial += line # Check last trial if current_trial: trial_lower = current_trial.lower() score = sum(1 for term in search_terms if term in trial_lower) if score > 0: trials_with_scores.append((score, current_trial)) # Sort by score (highest first) and take top results trials_with_scores.sort(reverse=True, key=lambda x: x[0]) matching_trials = [(score, trial) for score, trial in trials_with_scores[:max_results]] if matching_trials: logger.info(f"Keyword search found {len(matching_trials)} trials") return matching_trials # Return list of (score, trial) tuples else: logger.warning("Keyword search found no matching trials") return [] except Exception as e: logger.error(f"Keyword search failed: {e}") return [] def keyword_search_in_dataset(entities, max_results=10): """Legacy: Search dataset file for keyword matches using extracted entities""" if not DATASET_FILE.exists(): logger.error("Dataset file not found") return "" drugs = [d.lower() for d in entities.get('drugs', [])] conditions = [c.lower() for c in entities.get('conditions', [])] if not drugs and not conditions: logger.warning("No search terms for keyword search") return "" logger.info(f"Keyword search - Drugs: {drugs}, Conditions: {conditions}") # Store trials with match scores trials_with_scores = [] current_trial = "" try: with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f: for line in f: # Check if new trial starts if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"): # Score previous trial if current_trial: trial_lower = current_trial.lower() # Count matches drug_matches = sum(1 for d in drugs if d in trial_lower) condition_matches = sum(1 for c in conditions if c in trial_lower) # Only include trials that match at least the drug (if drug was specified) if drugs: if drug_matches > 0: score = drug_matches * 10 + condition_matches trials_with_scores.append((score, current_trial)) elif condition_matches > 0: # No drug specified, just match conditions trials_with_scores.append((condition_matches, current_trial)) current_trial = line else: current_trial += line # Check last trial if current_trial: trial_lower = current_trial.lower() drug_matches = sum(1 for d in drugs if d in trial_lower) condition_matches = sum(1 for c in conditions if c in trial_lower) if drugs: if drug_matches > 0: score = drug_matches * 10 + condition_matches trials_with_scores.append((score, current_trial)) elif condition_matches > 0: trials_with_scores.append((condition_matches, current_trial)) # Sort by score (highest first) and take top results trials_with_scores.sort(reverse=True, key=lambda x: x[0]) matching_trials = [trial for score, trial in trials_with_scores[:max_results]] if matching_trials: context = "\n\n---\n\n".join(matching_trials) if len(context) > 6000: context = context[:6000] + "..." logger.info(f"Keyword search found {len(matching_trials)} trials (from {len(trials_with_scores)} candidates)") return context else: logger.warning("Keyword search found no trials matching drug") return "" except Exception as e: logger.error(f"Keyword search failed: {e}") return "" # ============================================================================ # ENTITY EXTRACTION # ============================================================================ def parse_entities_from_query(conversation, hf_token=None): """Parse entities from query using both 355M and 8B models + regex fallback""" entities = {'drugs': [], 'conditions': []} # Use 355M model for entity extraction extracted_355m = extract_entities_with_small_model(conversation) # Also use 8B model for more reliable extraction extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token) # Combine both extractions extracted = (extracted_355m or "") + "\n" + (extracted_8b or "") # Parse model output if extracted: lines = extracted.split('\n') for line in lines: lower_line = line.lower() if 'drug:' in lower_line or 'medication:' in lower_line: drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip() if drug: entities['drugs'].append(drug) elif 'condition:' in lower_line or 'disease:' in lower_line: condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip() if condition: entities['conditions'].append(condition) # Regex fallback for standard drug naming patterns drug_patterns = [ r'\b([A-Z][a-z]+mab)\b', # Monoclonal antibodies: -mab suffix r'\b([A-Z][a-z]+nib)\b', # Kinase inhibitors: -nib suffix r'\b([A-Z]\d+[A-Z]+\d+)\b' # Alphanumeric codes like F8IL10 ] for pattern in drug_patterns: matches = re.findall(pattern, conversation) for match in matches: if match.lower() not in [d.lower() for d in entities['drugs']]: entities['drugs'].append(match) condition_patterns = [ r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b' ] for pattern in condition_patterns: matches = re.findall(pattern, conversation, re.IGNORECASE) for match in matches: if match not in [c.lower() for c in entities['conditions']]: entities['conditions'].append(match) logger.info(f"Extracted entities: {entities}") return entities # ============================================================================ # MAIN QUERY PROCESSING # ============================================================================ def extract_entities_simple(query): """Simple entity extraction using regex patterns - no model needed""" entities = {'drugs': [], 'conditions': []} # Drug patterns drug_patterns = [ r'\b([A-Z][a-z]+mab)\b', # Monoclonal antibodies: ianalumab, rituximab, etc. r'\b([A-Z][a-z]+nib)\b', # Kinase inhibitors: imatinib, etc. r'\b([A-Z]\d+[A-Z]+\d+)\b', # Alphanumeric codes r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b', # Common drugs ] # Condition patterns condition_patterns = [ r'\b(sjogren\'?s?\s+syndrome)\b', r'\b(rheumatoid arthritis)\b', r'\b(lupus)\b', r'\b(myelofibrosis)\b', r'\b(diabetes)\b', r'\b(cancer|carcinoma|melanoma)\b', ] query_lower = query.lower() # Extract drugs for pattern in drug_patterns: matches = re.findall(pattern, query, re.IGNORECASE) for match in matches: if match.lower() not in [d.lower() for d in entities['drugs']]: entities['drugs'].append(match) # Extract conditions for pattern in condition_patterns: matches = re.findall(pattern, query, re.IGNORECASE) for match in matches: if match.lower() not in [c.lower() for c in entities['conditions']]: entities['conditions'].append(match) logger.info(f"Extracted entities: {entities}") return entities def parse_query_with_llm(query, hf_token=None): """ Use fast LLM to parse query and extract structured information Extracts: - Drug names - Diseases/conditions - Companies (sponsors/manufacturers) - Endpoints (what's being measured) - Search terms (optimized for RAG) Returns: Dict with extracted entities and optimized search query """ try: from huggingface_hub import InferenceClient logger.info("[QUERY PARSER] Analyzing user query with LLM...") client = InferenceClient(token=hf_token, timeout=30) parse_prompt = f"""You are an expert in clinical trial terminology. Extract and expand entities from this query. Query: "{query}" Your task is to think creatively about ALL possible ways these entities might appear in clinical trial databases. For each entity type, brainstorm extensively: DRUGS: - Start with drugs explicitly mentioned - Add ALL possible names: brand names, generic names, research codes (like BNT162b2), manufacturer+drug combos (Pfizer-BioNTech), chemical names, common abbreviations - Think: "What would a pharmaceutical company call this in a trial?" - Example: "Pfizer COVID vaccine" → ["Comirnaty", "BNT162b2", "tozinameran", "Pfizer-BioNTech COVID-19 vaccine", "mRNA-1273"] DISEASES: - Include the disease/condition mentioned - Add medical synonyms, ICD-10 terms, related conditions - Both technical and colloquial terms - Example: "COVID" → ["COVID-19", "SARS-CoV-2", "coronavirus disease 2019", "severe acute respiratory syndrome coronavirus 2"] COMPANIES: - Company mentioned plus parent companies, subsidiaries - Include previous names, merged entities, partnership names - Example: "Pfizer" → ["Pfizer", "Pfizer Inc.", "Pfizer-BioNTech", "BioNTech SE"] ENDPOINTS: - Any specific outcomes, measures, or endpoints mentioned - Include related clinical measures SEARCH_TERMS: - Comprehensive keywords combining above entities - Include partial matches that might be relevant Format EXACTLY as: DRUGS: [list or "none"] DISEASES: [list or "none"] COMPANIES: [list or "none"] ENDPOINTS: [list or "none"] SEARCH_TERMS: [comprehensive keyword list] Be expansive - more synonyms mean better trial matching.""" response = client.chat_completion( model="meta-llama/Llama-3.1-70B-Instruct", messages=[{"role": "user", "content": parse_prompt}], max_tokens=500, # Increased for comprehensive synonyms temperature=0.3 # Slightly higher for creative synonym generation ) parsed = response.choices[0].message.content.strip() logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}") # Parse the response into dict result = { 'raw_parsed': parsed, 'drugs': [], 'diseases': [], 'companies': [], 'endpoints': [], 'search_terms': query # fallback } lines = parsed.split('\n') for line in lines: line = line.strip() if line.startswith('DRUGS:'): drugs = line.replace('DRUGS:', '').strip() # Remove brackets if present: [item1, item2] → item1, item2 drugs = drugs.strip('[]') if drugs and drugs.lower() != 'none': result['drugs'] = [d.strip().strip('"\'') for d in drugs.split(',') if d.strip()] elif line.startswith('DISEASES:'): diseases = line.replace('DISEASES:', '').strip() diseases = diseases.strip('[]') if diseases and diseases.lower() != 'none': result['diseases'] = [d.strip().strip('"\'') for d in diseases.split(',') if d.strip()] elif line.startswith('COMPANIES:'): companies = line.replace('COMPANIES:', '').strip() companies = companies.strip('[]') if companies and companies.lower() != 'none': result['companies'] = [c.strip().strip('"\'') for c in companies.split(',') if c.strip()] elif line.startswith('ENDPOINTS:'): endpoints = line.replace('ENDPOINTS:', '').strip() endpoints = endpoints.strip('[]') if endpoints and endpoints.lower() != 'none': result['endpoints'] = [e.strip().strip('"\'') for e in endpoints.split(',') if e.strip()] elif line.startswith('SEARCH_TERMS:'): terms = line.replace('SEARCH_TERMS:', '').strip() terms = terms.strip('[]') result['search_terms'] = terms if terms else query # FALLBACK: If LLM returned empty, try regex extraction from query if not result['drugs'] and not result['diseases'] and not result['companies']: logger.warning("[QUERY PARSER] LLM returned empty entities, using regex fallback") # Extract drug-like terms (capitalized words, could be drug names) import re query_lower = query.lower() # Common drug patterns drug_patterns = [ r'\b(ianalumab|pembrolizumab|nivolumab|rituximab|tocilizumab)\b', r'\b(keytruda|opdivo|humira|enbrel|remicade)\b', r'\b([A-Z][a-z]+mab)\b', # -mab suffix (monoclonal antibodies) r'\b([A-Z][a-z]+nib)\b', # -nib suffix (kinase inhibitors) ] for pattern in drug_patterns: matches = re.findall(pattern, query, re.IGNORECASE) for match in matches: if match.lower() not in [d.lower() for d in result['drugs']]: result['drugs'].append(match) # Extract disease terms disease_patterns = [ r"\b(sjogren'?s?|sjogrens)\s*(syndrome|disease)?\b", r'\b(lupus|arthritis|melanoma|diabetes|cancer)\b', r'\b(rheumatoid\s+arthritis|multiple\s+sclerosis)\b', ] for pattern in disease_patterns: matches = re.findall(pattern, query, re.IGNORECASE) for match in matches: disease = match if isinstance(match, str) else ' '.join(match).strip() if disease and disease.lower() not in [d.lower() for d in result['diseases']]: result['diseases'].append(disease) logger.info(f"[QUERY PARSER] Regex fallback found - Drugs: {result['drugs']}, Diseases: {result['diseases']}") logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}") return result except Exception as e: logger.warning(f"[QUERY PARSER] Failed: {e}, using regex fallback on query") # Emergency fallback - extract from query directly import re query_lower = query.lower() drugs = [] diseases = [] # Extract Ianalumab specifically if 'ianalumab' in query_lower: drugs.append('Ianalumab') # Extract Sjogren's if 'sjogren' in query_lower: diseases.append("Sjogren's syndrome") return { 'drugs': drugs, 'diseases': diseases, 'companies': [], 'endpoints': [], 'search_terms': query, 'raw_parsed': '' } def plan_query_action(query, parsed_entities, hf_token=None): """ Use HuggingFace Llama-70B to decide the best action for this query. Actions: - SEARCH_TRIALS: Specific drug/disease questions (use RAG with top 30 trials) - COUNT_AGGREGATE: "How many" or "list all" questions (use index counts) - COMPARE: Compare two or more treatments - GENERAL_KNOWLEDGE: Definitions or general info (skip RAG, use LLM knowledge) Returns: Dict with action, reasoning, and parameters """ try: from huggingface_hub import InferenceClient logger.info("[PLANNING AGENT] Deciding action with HuggingFace Llama-70B...") client = InferenceClient(token=hf_token, timeout=30) planning_prompt = f"""You are a clinical trial search strategist. Route this query to the best action. Query: "{query}" Extracted entities: - Drugs: {parsed_entities.get('drugs', [])} - Diseases: {parsed_entities.get('diseases', [])} - Companies: {parsed_entities.get('companies', [])} - Endpoints: {parsed_entities.get('endpoints', [])} ROUTING RULES: 1. SEARCH_TRIALS (default): Any question about specific drugs, treatments, efficacy, safety, trial results, side effects, or when entities are extracted 2. COUNT_AGGREGATE: Only when explicitly asking "how many", "list all", "total number" 3. COMPARE: Only when explicitly comparing with "vs", "versus", "compare", "better than", "difference between" 4. GENERAL_KNOWLEDGE: Only for pure definitions with no trial data needed When in doubt, choose SEARCH_TRIALS - real trial data is almost always helpful. Analyze the user's intent: - Are they asking about specific trial outcomes? → SEARCH_TRIALS - Do they want data about a drug/disease? → SEARCH_TRIALS - Are they asking for counts or lists? → COUNT_AGGREGATE - Are they comparing treatments? → COMPARE - Is this purely definitional? → GENERAL_KNOWLEDGE Respond with: ACTION: [choose one action] REASONING: [one clear sentence explaining why] SEARCH_TERMS: [refined search terms to find the most relevant trials] FOCUS: [what aspect to emphasize in the final answer - efficacy, safety, trial status, etc.]""" response = client.chat_completion( model="meta-llama/Llama-3.1-70B-Instruct", messages=[{"role": "user", "content": planning_prompt}], max_tokens=150, temperature=0.1 # Low temp for consistent routing ) result_text = response.choices[0].message.content.strip() logger.info(f"[PLANNING AGENT] Decision:\n{result_text}") # Parse the response result = { 'action': 'SEARCH_TRIALS', # Default fallback 'reasoning': 'Could not parse response', 'params': query, 'focus': 'comprehensive trial data', # New field 'raw': result_text } lines = result_text.split('\n') for line in lines: line = line.strip() if line.startswith('ACTION:'): action = line.replace('ACTION:', '').strip() if action in ['SEARCH_TRIALS', 'COUNT_AGGREGATE', 'COMPARE', 'GENERAL_KNOWLEDGE']: result['action'] = action elif line.startswith('REASONING:'): result['reasoning'] = line.replace('REASONING:', '').strip() elif line.startswith('SEARCH_TERMS:'): params = line.replace('SEARCH_TERMS:', '').strip() if params.lower() != 'none': result['params'] = params elif line.startswith('FOCUS:'): result['focus'] = line.replace('FOCUS:', '').strip() logger.info(f"[PLANNING AGENT] ✓ Action: {result['action']}, Focus: {result['focus']}, Reasoning: {result['reasoning']}") return result except Exception as e: logger.warning(f"[PLANNING AGENT] Failed: {e}, defaulting to SEARCH_TRIALS") return { 'action': 'SEARCH_TRIALS', 'reasoning': f'Planning failed: {e}', 'params': query, 'focus': 'available trial data' } def generate_llama_response(query, rag_context, hf_token=None, parsed_entities=None, planning_context=None): """ Intelligent synthesis that ALWAYS provides substantive answers from available data Args: query: User's question rag_context: Retrieved trial data hf_token: HuggingFace API token parsed_entities: Dict with extracted entities (drugs, diseases, companies) planning_context: Dict with planning agent output (action, focus, reasoning) """ # Build entity context string for better guidance entity_context = "" if parsed_entities: drugs_list = parsed_entities.get('drugs', [])[:10] diseases_list = parsed_entities.get('diseases', [])[:10] companies_list = parsed_entities.get('companies', [])[:10] if drugs_list or diseases_list or companies_list: entity_context = f""" Key entities to look for (including synonyms): - Drugs/Treatments: {', '.join(drugs_list) if drugs_list else 'none'} - Diseases: {', '.join(diseases_list) if diseases_list else 'none'} - Companies: {', '.join(companies_list) if companies_list else 'none'}""" # Focus area from planning focus_area = planning_context.get('focus', 'comprehensive analysis') if planning_context else 'comprehensive analysis' try: # Try Groq first (much faster), fallback to HuggingFace groq_api_key = os.getenv("GROQ_API_KEY") system_prompt = """You are a leading clinical trials analyst. Your role is to provide the most helpful, informative answer possible using available trial data. You excel at finding connections and insights even from imperfect data matches. CORE PRINCIPLES: 1. ALWAYS provide substantive, useful answers 2. Find relevant information even in partially-matching trials 3. Extract specific numbers, dates, phases, outcomes wherever available 4. Connect information across trials to build comprehensive insights 5. Never say "no relevant trials found" - work with what you have""" user_prompt = f"""Question: {query} Focus for this analysis: {focus_area} {entity_context} Clinical Trials Retrieved: {rag_context[:12000]} YOUR MISSION: Provide the most comprehensive, helpful answer possible by intelligently analyzing ALL available trials. ANALYSIS APPROACH: 1. SCAN all trials for ANY relevance to the query: - Direct matches (same drug + disease) → Primary focus - Same drug, different disease → Still valuable (shows drug profile) - Same disease, different drug → Provides treatment landscape context - Same company → Shows research pipeline - Similar mechanisms/drug classes → Offers comparative insights 2. EXTRACT concrete information: - Trial phases, enrollment numbers, completion dates - Efficacy percentages, response rates, survival data - Safety profiles, adverse events, tolerability - Dosing regimens, administration routes - Patient populations, inclusion/exclusion criteria 3. SYNTHESIZE intelligently: - If asking about Drug X for Disease Y but only find Drug X for Disease Z, discuss what this reveals about Drug X's mechanism and potential - Find patterns across trials (e.g., consistent safety profile) - Note trial progression (Phase 1 → 2 → 3) showing development status ## YOUR RESPONSE STRUCTURE: ### DIRECT ANSWER [Immediately address the query with the best available information. Be confident and helpful. If asking about "Sinopharm COVID vaccine" and trials mention "BBIBP-CorV" - recognize these as the same. Lead with what you KNOW, not what you don't know.] ### KEY CLINICAL TRIALS EVIDENCE [For each relevant trial, extract meaningful information:] - **NCT#####**: [Specific findings relevant to query - be detailed with numbers/outcomes] - **NCT#####**: [What this tells us - phases, enrollment, results if available] [Include even partially relevant trials with appropriate context] ### CLINICAL INSIGHTS [Synthesize patterns and meaningful conclusions:] - What do these trials collectively reveal? - Treatment landscape and development status - Efficacy signals or safety patterns - How different trials complement each other - Comparison with similar drugs/approaches if relevant ### ADDITIONAL CONTEXT [Brief, if needed - but keep positive and informative:] - If data is from different indications, explain transferable insights - If only early phase data, discuss what this means for development - Focus on what the data DOES tell us REMEMBER: - Users want actionable information, not disclaimers - Even Phase 1 safety data is valuable information - Cross-indication data provides mechanism insights - Company trial portfolios reveal strategic priorities - Similar drug classes offer comparative context - ALWAYS find something valuable to report""" if groq_api_key: logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...") from groq import Groq client = Groq(api_key=groq_api_key) response = client.chat.completions.create( model="llama-3.1-70b-versatile", messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], max_tokens=2000, # Increased for comprehensive answers temperature=0.3, timeout=30 ) return response.choices[0].message.content.strip() else: # Fallback to HuggingFace (slower) logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...") from huggingface_hub import InferenceClient client = InferenceClient(token=hf_token, timeout=120) response = client.chat_completion( model="meta-llama/Meta-Llama-3.1-70B-Instruct", messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], max_tokens=2000, # Increased for comprehensive answers temperature=0.3 ) return response.choices[0].message.content.strip() except Exception as e: logger.error(f"Llama error: {e}") return f"Llama API error: {str(e)}" def process_query_simple_test(conversation): """TEST JUST THE RAG - no models""" try: import time output = [] output.append(f"QUERY: {conversation}\n") # Check if embeddings loaded if doc_embeddings is None or len(doc_chunks) == 0: return "FAIL: Embeddings not loaded" output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n") output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n") # Try to search start = time.time() context = retrieve_context_with_embeddings(conversation, top_k=10) search_time = time.time() - start if not context: return "".join(output) + "\nFAIL: RAG returned empty" output.append(f"✓ RAG search took: {search_time:.2f}s\n") output.append(f"✓ Retrieved {context.count('NCT')} trials\n\n") output.append("FIRST 1000 CHARS:\n") output.append(context[:1000]) return "".join(output) except Exception as e: import traceback return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}" def process_query(conversation): """ Complete pipeline with LLM query parsing, planning agent, and natural language generation Flow: 0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s) 0.5. Planning Agent - Decide action: SEARCH_TRIALS / COUNT_AGGREGATE / GENERAL_KNOWLEDGE (~1s) 1. Execute Action - Based on plan: RAG search, index count, or skip to LLM (~2s) 2. Skipped - 355M ranking removed (was broken) 3. LLM Response - Llama 70B generates natural language (~15s) Total: ~21 seconds """ import time import traceback import sys # MASTER try/except - catches EVERYTHING try: start_time = time.time() output_parts = [f"QUERY: {conversation}\n\n"] # Step 0: Parse query with LLM to extract structured info try: step0_start = time.time() logger.info("Step 0: Parsing query with LLM...") output_parts.append("✓ Step 0: LLM query parser started...\n") parsed_query = parse_query_with_llm(conversation, hf_token=hf_token) # Use optimized search terms from parser search_query = parsed_query['search_terms'] step0_time = time.time() - step0_start output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n") output_parts.append(f" Drugs: {parsed_query['drugs']}\n") output_parts.append(f" Diseases: {parsed_query['diseases']}\n") output_parts.append(f" Companies: {parsed_query['companies']}\n") output_parts.append(f" Optimized search: {search_query}\n") logger.info(f"Query parsing successful in {step0_time:.1f}s") except Exception as e: error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query" logger.warning(error_msg) output_parts.append(f"{error_msg}\n") search_query = conversation # Fallback to original parsed_query = {'drugs': [], 'diseases': [], 'companies': []} # Step 0.5: Planning agent decides action try: planning_start = time.time() logger.info("Step 0.5: Planning agent deciding action...") output_parts.append("✓ Step 0.5: Planning agent started...\n") plan = plan_query_action(conversation, parsed_query, hf_token=hf_token) planning_time = time.time() - planning_start output_parts.append(f"✓ Step 0.5 Complete: Action decided ({planning_time:.1f}s)\n") output_parts.append(f" Action: {plan['action']}\n") output_parts.append(f" Reasoning: {plan['reasoning']}\n") logger.info(f"Planning complete: {plan['action']} - {plan['reasoning']}") except Exception as e: error_msg = f"✗ Step 0.5 WARNING (Planning): {str(e)}, defaulting to SEARCH_TRIALS" logger.warning(error_msg) output_parts.append(f"{error_msg}\n") plan = {'action': 'SEARCH_TRIALS', 'reasoning': 'Planning failed', 'params': search_query} # Step 1: Execute action based on plan if plan['action'] == 'GENERAL_KNOWLEDGE': # Skip RAG entirely, go straight to LLM try: step1_start = time.time() logger.info("Step 1: GENERAL_KNOWLEDGE - Skipping RAG...") output_parts.append("✓ Step 1: Skipped RAG (general knowledge query)\n") context = "" # Empty context step1_time = time.time() - step1_start output_parts.append(f"✓ Step 1 Complete: Using LLM knowledge only ({step1_time:.1f}s)\n") except Exception as e: error_msg = f"✗ Step 1 FAILED: {str(e)}" logger.error(error_msg) return error_msg elif plan['action'] == 'COUNT_AGGREGATE': # Use index to count, pass summary to LLM try: step1_start = time.time() logger.info("Step 1: COUNT_AGGREGATE - Using inverted index...") output_parts.append("✓ Step 1: Count/aggregation started...\n") # Get search terms from plan search_terms = plan['params'].lower().split() # Find matching trials from inverted index global inverted_index matching_trial_ids = set() if inverted_index: for term in search_terms: if term in inverted_index: matching_trial_ids.update(inverted_index[term]) logger.info(f" Found {len(inverted_index[term])} trials for '{term}'") # Create summary context if matching_trial_ids: context = f"Found {len(matching_trial_ids)} trials matching the query.\n\n" context += f"Note: This is an aggregate count. For detailed information about specific trials, " context += f"please ask a more specific question about individual drugs or treatments." else: context = "No trials found matching the query." step1_time = time.time() - step1_start output_parts.append(f"✓ Step 1 Complete: Found {len(matching_trial_ids)} matching trials ({step1_time:.1f}s)\n") logger.info(f"Count aggregation complete - {len(matching_trial_ids)} trials in {step1_time:.1f}s") except Exception as e: error_msg = f"✗ Step 1 FAILED (Count): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg elif plan['action'] == 'COMPARE': # Compare treatments - retrieve trials for each and let LLM analyze try: step1_start = time.time() logger.info("Step 1: COMPARE - Retrieving trials for comparison...") output_parts.append("✓ Step 1: Comparison search started...\n") # Extract treatments to compare from parsed drugs treatments = parsed_query.get('drugs', []) if len(treatments) < 2: # Try to extract from query text if not in parsed drugs import re compare_patterns = [ r'(\w+)\s+(?:vs|versus|vs\.)\s+(\w+)', r'compare\s+(\w+)\s+(?:and|with|to)\s+(\w+)' ] for pattern in compare_patterns: match = re.search(pattern, conversation.lower()) if match: treatments = [match.group(1), match.group(2)] break if len(treatments) < 2: context = "Could not identify two treatments to compare. Please specify which treatments you'd like to compare." else: logger.info(f"[COMPARE] Comparing: {treatments[0]} vs {treatments[1]}") # Search for trials for each treatment context_parts = [] for i, treatment in enumerate(treatments[:2], 1): # Compare first 2 logger.info(f"[COMPARE] Searching trials for {treatment}...") treatment_trials = retrieve_context_with_embeddings(treatment, top_k=10, entities=parsed_query) if treatment_trials: context_parts.append(f"=== TRIALS FOR {treatment.upper()} ===\n{treatment_trials}\n") else: context_parts.append(f"=== TRIALS FOR {treatment.upper()} ===\nNo trials found.\n") # Combine all trials for LLM comparison context = "\n".join(context_parts) context += f"\n\nPLEASE COMPARE: {treatments[0]} vs {treatments[1]}\n" context += "Analyze the trials above and provide a side-by-side comparison including:\n" context += "- Number of trials for each\n" context += "- Key indications/diseases studied\n" context += "- Trial phases\n" context += "- Notable efficacy or safety findings\n" context += "- Head-to-head comparison trials (if any)" step1_time = time.time() - step1_start output_parts.append(f"✓ Step 1 Complete: Retrieved comparison data ({step1_time:.1f}s)\n") logger.info(f"Comparison search complete in {step1_time:.1f}s") except Exception as e: error_msg = f"✗ Step 1 FAILED (Compare): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg else: # SEARCH_TRIALS - normal RAG search (using optimized search query) try: step1_start = time.time() logger.info("Step 1: RAG search...") output_parts.append("✓ Step 1: RAG search started...\n") # Pass entities for STRICT company filtering context = retrieve_context_with_embeddings(search_query, top_k=10, entities=parsed_query) if not context: return "No matching trials found in RAG search." # No limit - use complete trials step1_time = time.time() - step1_start output_parts.append(f"✓ Step 1 Complete: Found {context.count('NCT')} trials ({step1_time:.1f}s)\n") logger.info(f"RAG search successful - found trials in {step1_time:.1f}s") except Exception as e: error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg # Step 2: Skipped (355M ranking removed - was broken) output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n") # Step 3: Llama 70B try: step3_start = time.time() logger.info("Step 3: Generating response with Llama-3.1-70B...") output_parts.append("✓ Step 3: Llama 70B generation started...\n") llama_response = generate_llama_response( conversation, context, hf_token=hf_token, parsed_entities=parsed_query, planning_context=plan ) step3_time = time.time() - step3_start output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n") logger.info(f"Llama 70B generation successful in {step3_time:.1f}s") except Exception as e: error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) llama_response = f"[Llama 70B error: {str(e)}]" output_parts.append(f"✗ Step 3 Failed: {str(e)}\n") total_time = time.time() - start_time # Format output - handle missing variables try: context_display = context if 'context' in locals() else "[No context retrieved]" clinical_display = clinical_context_355m if 'clinical_context_355m' in locals() else "[355M not run]" llama_display = llama_response if 'llama_response' in locals() else "[Llama 70B not run]" output = f"""{''.join(output_parts)} CLINICAL SUMMARY (Llama-3.1-70B-Instruct): {llama_display} --- RAG RETRIEVED TRIALS (Top 3 Most Relevant): {context_display} --- Total Time: {total_time:.1f}s """ # Record analytics query_type = plan.get('action', 'UNKNOWN') if 'plan' in locals() else 'UNKNOWN' query_analytics.record_query(query_type, total_time, success=True) return output except Exception as e: # Absolute fallback error_info = f""" CRITICAL ERROR IN OUTPUT FORMATTING: {str(e)} TRACEBACK: {traceback.format_exc()} OUTPUT PARTS: {''.join(output_parts)} Variables defined: {locals().keys()} """ logger.error(error_info) return error_info # MASTER EXCEPTION HANDLER - catches ANY unhandled error except Exception as master_error: master_error_msg = f""" ======================================== MASTER ERROR HANDLER CAUGHT EXCEPTION ======================================== Error Type: {type(master_error).__name__} Error Message: {str(master_error)} FULL TRACEBACK: {traceback.format_exc()} System Info: - Python version: {sys.version} - Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'} ======================================== """ logger.error(master_error_msg) # Record analytics for error elapsed_time = time.time() - start_time if 'start_time' in locals() else 0 query_analytics.record_query('ERROR', elapsed_time, success=False) return master_error_msg def get_analytics_report(): """ Get analytics report for monitoring Returns formatted string with query statistics """ stats = query_analytics.get_stats() uptime_hours = stats['uptime_seconds'] / 3600 report = f""" === ANALYTICS REPORT === Uptime: {uptime_hours:.1f} hours Total Queries: {stats['total_queries']} Error Rate: {stats['error_rate']*100:.1f}% Query Type Distribution: """ for query_type, count in stats['query_type_distribution'].items(): percentage = (count / stats['total_queries'] * 100) if stats['total_queries'] > 0 else 0 avg_time = stats['avg_response_times'].get(query_type, 0) report += f" {query_type}: {count} queries ({percentage:.1f}%) - avg {avg_time:.2f}s\n" report += "\n=== END REPORT ===\n" return report # ============================================================================ # 355M PERPLEXITY-BASED RANKING (FOR STRUCTURED JSON API) # ============================================================================ def rank_trials_with_355m_perplexity(query, trials_list, hf_token=None): """ Rank trials using 355M Clinical Trial GPT perplexity scoring This uses the model for SCORING not GENERATION to avoid hallucinations Lower perplexity = more relevant trial Args: query: User query trials_list: List of (score, trial_text) tuples from hybrid search hf_token: Not needed (model runs locally) Returns: List of dicts with trial data and perplexity scores """ import time import re import torch from transformers import GPT2LMHeadModel, GPT2TokenizerFast start_time = time.time() # Only rank top 10 trials (balance between accuracy and speed) top_10 = trials_list[:10] logger.info(f"[355M PERPLEXITY] Ranking {len(top_10)} trials with CT2 model...") try: # Load 355M model tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/CT2") model = GPT2LMHeadModel.from_pretrained( "gmkdigitalmedia/CT2", torch_dtype=torch.float16, device_map="auto" ) model.eval() tokenizer.pad_token = tokenizer.eos_token ranked_trials = [] for idx, (hybrid_score, trial_text) in enumerate(top_10): # Extract NCT ID nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" # Format test text for perplexity calculation # The model calculates: "How natural is this query-trial pairing?" test_text = f"""Query: {query} Relevant Clinical Trial: {trial_text[:800]} This trial is highly relevant because""" # Calculate perplexity (lower = more relevant) inputs = tokenizer( test_text, return_tensors="pt", truncation=True, max_length=512, padding=True ).to(model.device) with torch.no_grad(): outputs = model(**inputs, labels=inputs.input_ids) perplexity = torch.exp(outputs.loss).item() # Convert perplexity to relevance score (0-1) # Typical range: 10-1000, lower is better perplexity_score = 1.0 / (1.0 + perplexity / 100) # Combine with hybrid score (70% hybrid, 30% perplexity) combined_score = 0.7 * hybrid_score + 0.3 * perplexity_score logger.info(f"[355M] {nct_id}: Hybrid={hybrid_score:.3f}, " f"Perplexity={perplexity:.1f}, " f"Perplexity_Score={perplexity_score:.3f}, " f"Combined={combined_score:.3f}") ranked_trials.append({ 'nct_id': nct_id, 'trial_text': trial_text, 'hybrid_score': float(hybrid_score), 'perplexity': float(perplexity), 'perplexity_score': float(perplexity_score), 'combined_score': float(combined_score), 'rank_before_355m': idx + 1 }) # Sort by combined score (descending) ranked_trials.sort(key=lambda x: x['combined_score'], reverse=True) # Add final rank for idx, trial in enumerate(ranked_trials): trial['rank_after_355m'] = idx + 1 elapsed = time.time() - start_time logger.info(f"[355M PERPLEXITY] ✓ Ranking complete in {elapsed:.1f}s") # Add remaining trials (beyond top 10) without 355M scoring for idx, (hybrid_score, trial_text) in enumerate(trials_list[10:], start=10): nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" ranked_trials.append({ 'nct_id': nct_id, 'trial_text': trial_text, 'hybrid_score': float(hybrid_score), 'perplexity': None, 'perplexity_score': None, 'combined_score': float(hybrid_score), 'rank_before_355m': idx + 1, 'rank_after_355m': len(ranked_trials) + 1 }) return ranked_trials except Exception as e: logger.error(f"[355M PERPLEXITY] Error: {e}") logger.warning("[355M PERPLEXITY] Falling back to hybrid scores only") # Fallback: return trials with hybrid scores only fallback_trials = [] for idx, (hybrid_score, trial_text) in enumerate(trials_list): nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" fallback_trials.append({ 'nct_id': nct_id, 'trial_text': trial_text, 'hybrid_score': float(hybrid_score), 'perplexity': None, 'perplexity_score': None, 'combined_score': float(hybrid_score), 'rank_before_355m': idx + 1, 'rank_after_355m': idx + 1 }) return fallback_trials def parse_trial_text_to_dict(trial_text, nct_id): """ Parse trial text into structured dictionary Args: trial_text: Raw trial text nct_id: NCT ID Returns: Dict with parsed trial fields """ import re trial_dict = { 'nct_id': nct_id, 'title': '', 'sponsor': '', 'collaborators': [], 'phase': '', 'status': '', 'enrollment': None, 'conditions': [], 'interventions': [], 'primary_outcome': '', 'results_summary': '', 'start_date': '', 'completion_date': '', 'last_update': '', 'locations': [] } # Extract fields using regex patterns try: # Title title_match = re.search(r'TITLE:\s*([^\n]+)', trial_text, re.IGNORECASE) if title_match: trial_dict['title'] = title_match.group(1).strip() # Sponsor sponsor_match = re.search(r'SPONSOR:\s*([^\n]+)', trial_text, re.IGNORECASE) if sponsor_match: trial_dict['sponsor'] = sponsor_match.group(1).strip() # Collaborators collab_match = re.search(r'COLLABORATOR[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE) if collab_match: collabs = collab_match.group(1).strip().split(',') trial_dict['collaborators'] = [c.strip() for c in collabs if c.strip()] # Phase phase_match = re.search(r'PHASE:\s*([^\n]+)', trial_text, re.IGNORECASE) if phase_match: trial_dict['phase'] = phase_match.group(1).strip() # Status status_match = re.search(r'STATUS:\s*([^\n]+)', trial_text, re.IGNORECASE) if status_match: trial_dict['status'] = status_match.group(1).strip() # Enrollment enrollment_match = re.search(r'ENROLLMENT:\s*(\d+)', trial_text, re.IGNORECASE) if enrollment_match: trial_dict['enrollment'] = int(enrollment_match.group(1)) # Conditions condition_match = re.search(r'CONDITION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE) if condition_match: conditions = condition_match.group(1).strip().split(',') trial_dict['conditions'] = [c.strip() for c in conditions if c.strip()] # Interventions intervention_match = re.search(r'INTERVENTION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE) if intervention_match: interventions = intervention_match.group(1).strip().split(',') trial_dict['interventions'] = [i.strip() for i in interventions if i.strip()] # Primary outcome outcome_match = re.search(r'PRIMARY[_ ]OUTCOME:\s*([^\n]+)', trial_text, re.IGNORECASE) if outcome_match: trial_dict['primary_outcome'] = outcome_match.group(1).strip() # Results summary results_match = re.search(r'RESULTS:\s*([^\n]+)', trial_text, re.IGNORECASE) if results_match: trial_dict['results_summary'] = results_match.group(1).strip() # Dates start_match = re.search(r'START[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE) if start_match: trial_dict['start_date'] = start_match.group(1).strip() completion_match = re.search(r'COMPLETION[_ ]DATE:\s*([^\n]+)', trial_text, re.IGNORECASE) if completion_match: trial_dict['completion_date'] = completion_match.group(1).strip() # Locations location_match = re.search(r'LOCATION[S]?:\s*([^\n]+)', trial_text, re.IGNORECASE) if location_match: locations = location_match.group(1).strip().split(',') trial_dict['locations'] = [l.strip() for l in locations if l.strip()] except Exception as e: logger.warning(f"Error parsing trial {nct_id}: {e}") return trial_dict def process_query_structured(query, top_k=10): """ Process query and return structured JSON (no LLM response generation) This is the new API endpoint that: 1. Uses LLM for query parsing/entity extraction 2. Performs hybrid RAG search 3. Ranks with 355M perplexity scoring 4. Returns structured JSON Args: query: User query top_k: Number of trials to return Returns: Dict with structured response """ import time start_time = time.time() result = { 'query': query, 'processing_time': 0, 'query_analysis': {}, 'results': {}, 'trials': [], 'benchmarking': {}, 'metadata': {} } try: # Step 1: Parse query with LLM step1_start = time.time() logger.info("[STRUCTURED API] Step 1: Parsing query with LLM...") try: parsed_query = parse_query_with_llm(query, hf_token=hf_token) search_query = parsed_query['search_terms'] result['query_analysis'] = { 'extracted_entities': { 'drugs': parsed_query.get('drugs', []), 'diseases': parsed_query.get('diseases', []), 'companies': parsed_query.get('companies', []), 'endpoints': parsed_query.get('endpoints', []) }, 'optimized_search': search_query, 'parsing_time': time.time() - step1_start } logger.info(f"[STRUCTURED API] Query parsed in {time.time() - step1_start:.1f}s") except Exception as e: logger.warning(f"[STRUCTURED API] Query parsing failed: {e}, using original query") search_query = query parsed_query = {'drugs': [], 'diseases': [], 'companies': [], 'endpoints': []} result['query_analysis'] = { 'extracted_entities': parsed_query, 'optimized_search': search_query, 'parsing_time': time.time() - step1_start, 'error': str(e) } # Step 2: Hybrid RAG search step2_start = time.time() logger.info("[STRUCTURED API] Step 2: Hybrid RAG search...") # Get more candidates for 355M ranking candidate_k = top_k * 3 # We need to get the candidate trials with scores # Re-implement the key parts of retrieve_context_with_embeddings to get structured data from collections import Counter global doc_chunks, doc_embeddings, embedder, inverted_index if doc_embeddings is None or len(doc_chunks) == 0: raise Exception("Embeddings not loaded!") # Extract keywords stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know', 'about', 'that', 'this', 'there', 'it'} query_lower = search_query.lower() import re words = re.findall(r'\b\w+\b', query_lower) query_terms = [w for w in words if len(w) > 2 and w not in stop_words] # Keyword scoring with inverted index keyword_scores = {} if inverted_index is not None: inv_index_candidates = set() for term in query_terms: if term in inverted_index: inv_index_candidates.update(inverted_index[term]) if inv_index_candidates: drug_specific_terms = set() for term in query_terms: if term in inverted_index and len(inverted_index[term]) < 100: drug_specific_terms.add(term) for idx in inv_index_candidates: chunk_data = doc_chunks[idx] chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data chunk_lower = chunk_text.lower() has_drug_match = any(drug_term in chunk_lower for drug_term in drug_specific_terms) if has_drug_match: keyword_scores[idx] = 1000.0 else: keyword_scores[idx] = 1.0 # Semantic scoring load_embedder() query_embedding = embedder.encode([search_query])[0] semantic_similarities = np.dot(doc_embeddings, query_embedding) # Normalize and combine scores if keyword_scores: max_kw = max(keyword_scores.values()) keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()} else: keyword_scores_norm = {} max_sem = semantic_similarities.max() min_sem = semantic_similarities.min() semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10) # Combined scores combined_scores = np.zeros(len(doc_chunks)) for idx in range(len(doc_chunks)): kw_score = keyword_scores_norm.get(idx, 0.0) sem_score = semantic_scores_norm[idx] combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score if kw_score > 0 else sem_score # Get top candidates top_indices = np.argsort(combined_scores)[-candidate_k:][::-1] # Format as (score, text) tuples candidate_trials = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices] rag_time = time.time() - step2_start logger.info(f"[STRUCTURED API] RAG search complete in {rag_time:.1f}s, found {len(candidate_trials)} candidates") # Step 3: Rank with 355M perplexity step3_start = time.time() logger.info("[STRUCTURED API] Step 3: Ranking with 355M perplexity...") ranked_trials = rank_trials_with_355m_perplexity(query, candidate_trials, hf_token=hf_token) ranking_time = time.time() - step3_start logger.info(f"[STRUCTURED API] 355M ranking complete in {ranking_time:.1f}s") # Format results result['results'] = { 'total_found': len(candidate_trials), 'returned': min(top_k, len(ranked_trials)), 'top_relevance_score': ranked_trials[0]['combined_score'] if ranked_trials else 0 } # Parse trials and add to results for trial_data in ranked_trials[:top_k]: trial_dict = parse_trial_text_to_dict(trial_data['trial_text'], trial_data['nct_id']) trial_dict['scoring'] = { 'relevance_score': trial_data['combined_score'], 'hybrid_score': trial_data['hybrid_score'], 'perplexity': trial_data['perplexity'], 'perplexity_score': trial_data['perplexity_score'], 'rank_before_355m': trial_data['rank_before_355m'], 'rank_after_355m': trial_data['rank_after_355m'], 'ranking_method': '355m_perplexity' if trial_data['perplexity'] is not None else 'hybrid_only' } trial_dict['url'] = f"https://clinicaltrials.gov/study/{trial_data['nct_id']}" result['trials'].append(trial_dict) # Benchmarking data if ranked_trials: # Calculate how much 355M changed the ranking rank_changes = [] for trial in ranked_trials[:top_k]: if trial['perplexity'] is not None: rank_change = trial['rank_before_355m'] - trial['rank_after_355m'] rank_changes.append(rank_change) result['benchmarking'] = { 'rag_search_time': rag_time, '355m_ranking_time': ranking_time, 'total_processing_time': time.time() - start_time, 'trials_ranked_by_355m': len([t for t in ranked_trials if t['perplexity'] is not None]), 'average_rank_change': sum(rank_changes) / len(rank_changes) if rank_changes else 0, 'max_rank_improvement': max(rank_changes) if rank_changes else 0, 'top_3_perplexity_scores': [t['perplexity'] for t in ranked_trials[:3] if t['perplexity'] is not None] } # Metadata result['metadata'] = { 'database_version': '2025-01-06', 'total_trials_searched': len(doc_chunks), 'api_version': '2.0.0', 'model_info': { 'query_parser': 'Llama-3.1-70B-Instruct', 'ranking_model': 'gmkdigitalmedia/CT2-355M', 'embedding_model': 'all-MiniLM-L6-v2' } } result['processing_time'] = time.time() - start_time logger.info(f"[STRUCTURED API] ✓ Complete in {result['processing_time']:.1f}s") return result except Exception as e: logger.error(f"[STRUCTURED API] Error: {e}") import traceback result['error'] = str(e) result['traceback'] = traceback.format_exc() result['processing_time'] = time.time() - start_time return result # ============================================================================ # GRADIO INTERFACE # ============================================================================ with gr.Blocks(title="Foundation 1.2") as demo: gr.Markdown("# Foundation 1.2 - Clinical Trial AI") query_input = gr.Textbox( label="Ask about clinical trials", placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?", lines=3 ) submit_btn = gr.Button("Generate Response", variant="primary") output = gr.Textbox( label="AI Response", lines=30 ) submit_btn.click( fn=process_query, # Full pipeline: RAG + 355M + Llama inputs=query_input, outputs=output ) gr.Markdown(""" **Production Pipeline - Optimized for Clinical Accuracy** """) # ============================================================================ # STARTUP # ============================================================================ # Embeddings will be loaded by FastAPI startup event in app.py # Do NOT load here - causes Docker permission errors logger.info("=== Foundation 1.2 Module Loaded ===") logger.info("Call load_embeddings() to initialize the system") if DATASET_FILE.exists(): file_size_mb = DATASET_FILE.stat().st_size / (1024 * 1024) logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB") else: logger.error("✗ Dataset file not found!") logger.info("=== Startup Complete ===") if __name__ == "__main__": demo.launch()