""" FAST VERSION: Bypasses 355M ranking bottleneck (300s -> 0s) Works with existing data structure: List[Tuple[int, str]] Keeps BM25 + semantic hybrid search intact """ import torch from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModelForCausalLM import logging import spaces from functools import lru_cache from typing import List, Tuple, Optional, Dict from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) # =========================================================================== # CACHED MODEL LOADING - Load once, reuse forever # =========================================================================== @lru_cache(maxsize=1) def get_cached_355m_model(): """Load 355M model once and cache it for entity extraction""" logger.info("Loading 355M Clinical Trial GPT (cached for entity extraction)...") tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/clinicaltrial2.2") model = GPT2LMHeadModel.from_pretrained( "gmkdigitalmedia/clinicaltrial2.2", torch_dtype=torch.float16, device_map="auto" ) model.eval() return tokenizer, model @lru_cache(maxsize=1) def get_cached_8b_model(hf_token: Optional[str] = None): """Load 8B model once and cache it""" logger.info("Loading II-Medical-8B (cached)...") tokenizer = AutoTokenizer.from_pretrained( "Intelligent-Internet/II-Medical-8B-1706", token=hf_token, trust_remote_code=True ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( "Intelligent-Internet/II-Medical-8B-1706", device_map="auto", token=hf_token, trust_remote_code=True, torch_dtype=torch.bfloat16 ) return tokenizer, model # =========================================================================== # FAST RANKING - Replace 300s function with instant passthrough # =========================================================================== @spaces.GPU def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]: """ SMART RANKING: Use 355M to rank only top 3 trials Takes top 3 from BM25+semantic search, then uses 355M Clinical Trial GPT to re-rank them by clinical relevance. 
# ===========================================================================
# FAST RANKING - Re-rank only the top 3 trials with the 355M model
# ===========================================================================

@spaces.GPU
def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]:
    """
    SMART RANKING: Use the 355M model to rank only the top 3 trials.

    Takes the top 3 hits from BM25+semantic search, then uses the 355M
    Clinical Trial GPT to re-rank them by clinical relevance.
    Time: ~30 seconds for 3 trials (vs ~300s for 30 trials)

    Args:
        query: The search query
        trials_list: List of (score, trial_text) tuples from BM25+semantic search
        hf_token: Not needed

    Returns:
        Top 3 trials re-ranked by 355M clinical relevance, followed by the
        remaining trials in their original order
    """
    import time
    import re

    start_time = time.time()

    # Take only the top 3 trials for 355M ranking
    top_3 = trials_list[:3]
    logger.info("[355M RANKING] Ranking top 3 trials with Clinical Trial GPT...")

    # Get cached 355M model
    tokenizer, model = get_cached_355m_model()

    # Score each trial
    trial_scores = []
    for idx, (bm25_score, trial_text) in enumerate(top_3):
        # Extract NCT ID for logging
        nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
        nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"

        # Create prompt for relevance scoring
        # Truncate trial to 800 chars to keep it fast
        trial_snippet = trial_text[:800]
        prompt = f"""Query: {query}

Clinical Trial: {trial_snippet}

Rate clinical relevance (1-10):"""

        # Get model score
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    max_new_tokens=10,           # only a short numeric rating is needed
                    do_sample=False,             # greedy decoding; temperature would be ignored
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            generated = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)

            # Extract number from response
            score_match = re.search(r'(\d+)', generated.strip())
            relevance_score = float(score_match.group(1)) if score_match else 5.0
            # Normalize to 0-1 range
            relevance_score = relevance_score / 10.0
            logger.info(f"[355M RANKING] {nct_id}: relevance={relevance_score:.2f} (BM25={bm25_score:.3f})")
        except Exception as e:
            logger.warning(f"[355M RANKING] Scoring failed for {nct_id}: {e}, using BM25 score")
            relevance_score = bm25_score

        trial_scores.append((relevance_score, trial_text, nct_id))

    # Sort by 355M relevance score (descending)
    trial_scores.sort(key=lambda x: x[0], reverse=True)

    # Format as (score, text) tuples for backwards compatibility.
    # A custom list subclass lets us attach ranking metadata as an attribute.
    class RankedTrialsList(list):
        """List that can hold ranking metadata"""
        pass

    ranked_trials = RankedTrialsList()
    ranking_metadata = []
    for rank, (score, text, nct_id) in enumerate(trial_scores, 1):
        ranked_trials.append((score, text))
        ranking_metadata.append({
            'rank': rank,
            'nct_id': nct_id,
            'relevance_score': score,
            'relevance_rating': f"{score*10:.1f}/10"
        })

    elapsed = time.time() - start_time
    logger.info(f"[355M RANKING] ✓ Ranked 3 trials in {elapsed:.1f}s")
    logger.info(f"[355M RANKING] Final order: {[nct_id for _, _, nct_id in trial_scores]}")
    logger.info(f"[355M RANKING] Scores: {[f'{s:.2f}' for s, _, _ in trial_scores]}")

    # Store metadata as an attribute for retrieval
    ranked_trials.ranking_info = ranking_metadata

    # Append the remaining trials (if any). Using extend() rather than "+"
    # keeps the RankedTrialsList subclass, so ranking_info is not lost.
    ranked_trials.extend(trials_list[3:])
    return ranked_trials


# Alias for drop-in replacement
rank_trials_with_355m = rank_trials_FAST  # Override the slow function!

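# ---------------------------------------------------------------------------
# Illustrative sketch (never called by the pipeline): how a caller could read
# the ranking metadata attached by rank_trials_FAST. The query and trial
# snippets below are hypothetical placeholders, not real retrieval output.
# ---------------------------------------------------------------------------
def _example_read_ranking_metadata():
    """Hedged usage sketch; assumes the 355M model can be loaded."""
    hybrid_hits = [
        (12.4, "NCT_ID: NCT00000001 ..."),  # placeholder (score, trial_text) pairs
        (11.9, "NCT_ID: NCT00000002 ..."),
        (10.2, "NCT_ID: NCT00000003 ..."),
    ]
    ranked = rank_trials_FAST("example oncology query", hybrid_hits)
    # ranking_info lives on the returned list object; guard the access in case
    # a plain list is ever returned instead.
    for entry in getattr(ranked, "ranking_info", []):
        print(entry["rank"], entry["nct_id"], entry["relevance_rating"])
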
# ===========================================================================
# FAST GENERATION using HuggingFace Inference API (Free)
# ===========================================================================

def generate_with_llama_70b_hf(query: str, rag_context: str = "", hf_token: str = None) -> str:
    """
    Use Llama-3.1-70B via HuggingFace Inference API (FREE)

    This is what you're already using successfully!
    ~10 second response time on HF free tier
    """
    try:
        logger.info("Using Llama-3.1-70B via HuggingFace Inference API...")
        client = InferenceClient(token=hf_token)

        messages = [
            {
                "role": "system",
                "content": "You are a medical information specialist. Answer based on the provided clinical trial data. Be concise and accurate."
            },
            {
                "role": "user",
                "content": f"""Clinical Trial Data: {rag_context[:4000]}

Question: {query}

Please provide a concise answer based on the clinical trial data above."""
            }
        ]

        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=messages,
            max_tokens=512,
            temperature=0.3
        )

        answer = response.choices[0].message.content.strip()
        logger.info("Llama 70B response generated via HF Inference API")
        return answer

    except Exception as e:
        logger.error(f"Llama 70B generation failed: {e}")
        return f"Error generating response with Llama 70B: {str(e)}"

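# ---------------------------------------------------------------------------
# Illustrative sketch (never called by the pipeline): invoking the hosted 70B
# generator directly. Reading the token from an HF_TOKEN environment variable
# is an assumption for the example, not a requirement of this module.
# ---------------------------------------------------------------------------
def _example_llama_70b_call():
    """Hedged usage sketch; assumes a valid Hugging Face API token."""
    import os
    answer = generate_with_llama_70b_hf(
        query="What was the primary endpoint?",
        rag_context="NCT_ID: NCT00000001 ...",  # placeholder trial text
        hf_token=os.environ.get("HF_TOKEN"),
    )
    print(answer)
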
# ===========================================================================
# OPTIMIZED 8B GENERATION (with cached model)
# ===========================================================================

@spaces.GPU
def generate_clinical_response_with_xupract(conversation, rag_context="", hf_token=None):
    """OPTIMIZED: Use cached 8B model for faster generation"""
    logger.info("Generating response with cached II-Medical-8B...")

    # Get cached model (loads once, reused afterwards)
    tokenizer, model = get_cached_8b_model(hf_token)

    # Build prompt with RAG context (ChatML format for II-Medical-8B)
    if rag_context:
        prompt = f"""<|im_start|>system
You are a medical information specialist. Answer based on the provided clinical trial data. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
Clinical Trial Data: {rag_context[:4000]}

Question: {conversation}

Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""
    else:
        prompt = f"""<|im_start|>system
You are a medical information specialist. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
{conversation}

Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id
            )
        response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
        return response
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return f"Error generating response: {str(e)}"


# ===========================================================================
# FAST ENTITY EXTRACTION (with cached model)
# ===========================================================================

@spaces.GPU
def extract_entities_with_small_model(conversation):
    """OPTIMIZED: Use cached 355M model for entity extraction"""
    logger.info("Extracting entities with cached 355M model...")

    # Get cached model
    tokenizer, model = get_cached_355m_model()

    # Better prompt for extraction
    prompt = f"""Clinical query: {conversation}

Extract:
Drug name:"""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=400,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated


# ===========================================================================
# QUERY EXPANSION (optional, with cached model)
# ===========================================================================

@spaces.GPU
def expand_query_with_355m(query):
    """OPTIMIZED: Use cached 355M for query expansion"""
    logger.info("Expanding query with cached 355M...")

    # Get cached model
    tokenizer, model = get_cached_355m_model()

    # Prompt to elicit clinical context
    prompt = f"Question: {query}\nClinical trial information:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=inputs.input_ids.shape[1] + 100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the expansion part
    if "Clinical trial information:" in generated:
        expansion = generated.split("Clinical trial information:")[-1].strip()
    else:
        expansion = generated[len(prompt):].strip()

    # Limit to a reasonable length (slicing is a no-op for shorter strings)
    expansion = expansion[:500]

    logger.info(f"Query expanded: {expansion[:100]}...")
    return expansion

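# ---------------------------------------------------------------------------
# Illustrative sketch (never called by the pipeline): one way the optional
# 355M expansion could be folded back into the retrieval query before the
# BM25 + semantic search runs. The concatenation strategy is an assumption,
# not something the pipeline below requires.
# ---------------------------------------------------------------------------
def _example_expanded_retrieval_query():
    """Hedged usage sketch; assumes the 355M model can be loaded."""
    query = "pembrolizumab adverse events"  # placeholder query
    expansion = expand_query_with_355m(query)
    # Feed the combined string to whatever hybrid search the host app uses.
    return f"{query} {expansion}"
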
# ===========================================================================
# MAIN PIPELINE - Now FAST!
# ===========================================================================

def process_two_llm_system(conversation, rag_context="", hf_token=None, use_validation=False):
    """
    FAST pipeline:
        1. Small 355M model extracts entities (cached model - fast)
        2. RAG retrieves context (BM25 + semantic - already fast)
        3. Big model generates the response (8B local or 70B API)
        4. Skip validation for speed

    Total time: ~15s instead of 300+s
    """
    import time
    start_time = time.time()

    # Step 1: Use cached 355M to extract entities
    entities = extract_entities_with_small_model(conversation)
    logger.info(f"Entities extracted in {time.time()-start_time:.1f}s")

    # Step 2: Generate the response (one of two options):
    if hf_token:
        # Option A: Use 70B via HF Inference API (better quality, ~10s)
        clinical_evidence = generate_with_llama_70b_hf(
            conversation, rag_context, hf_token
        )
        model_used = "Llama-3.1-70B (HF Inference API)"
    else:
        # Option B: Use cached 8B model (faster loading, ~5s)
        clinical_evidence = generate_clinical_response_with_xupract(
            conversation, rag_context, hf_token
        )
        model_used = "II-Medical-8B (cached)"

    total_time = time.time() - start_time
    logger.info(f"Total pipeline time: {total_time:.1f}s (was 300+s with full 355M ranking)")

    return {
        'clinical_evidence': clinical_evidence,
        'entities': entities,
        'model_used': model_used,
        'time_taken': total_time
    }


def format_two_llm_response(result):
    """Format the fast response"""
    return f"""ENTITY EXTRACTION (Clinical Trial GPT 355M - Cached)
{'='*60}
{result.get('entities', 'None identified')}

CLINICAL RESPONSE ({result.get('model_used', 'Unknown')})
{'='*60}
{result['clinical_evidence']}

PERFORMANCE
{'='*60}
Time: {result.get('time_taken', 0):.1f}s (was 300+s with full 355M ranking)
{'='*60}
"""


# ===========================================================================
# PRELOAD MODELS AT STARTUP (Call this once in app.py!)
# ===========================================================================

def preload_all_models(hf_token=None):
    """
    Call this ONCE at app startup to cache all models.
    This prevents model reloading on every query.

    Add to your app.py initialization:

        from two_llm_system_FAST import preload_all_models
        preload_all_models(hf_token)
    """
    logger.info("Preloading and caching all models...")

    # Cache the 355M model
    _ = get_cached_355m_model()
    logger.info("✓ 355M model cached")

    # Cache the 8B model if a token is available
    if hf_token:
        try:
            _ = get_cached_8b_model(hf_token)
            logger.info("✓ 8B model cached")
        except Exception as e:
            logger.warning(f"Could not cache 8B model: {e}")

    logger.info("All models preloaded and cached!")


# ===========================================================================
# BACKWARD COMPATIBILITY - Keep all original function names
# ===========================================================================

# These functions exist in the original module; here they are reduced to fast shims.
validate_with_small_model = lambda *args, **kwargs: "Validation skipped for speed"
extract_keywords_with_llama = lambda conv, hf_token=None: extract_entities_with_small_model(conv)[:100]
generate_response_with_llama = generate_with_llama_70b_hf
generate_clinical_knowledge_with_355m = lambda conv: f"Knowledge: {conv[:100]}..."
generate_with_355m = lambda conv, rag="", hf_token=None: generate_clinical_response_with_xupract(conv, rag, hf_token)

# Ensure we override the slow ranking function
rank_trials_with_355m = rank_trials_FAST

logger.info("Fast Two-LLM System loaded - 355M ranking limited to the top 3 trials!")
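

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch, only run when this module is executed directly.
# The demo query, trial snippets, and HF_TOKEN environment variable are
# illustrative assumptions; a real app passes its own retrieval results.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    demo_token = os.environ.get("HF_TOKEN")
    preload_all_models(demo_token)

    demo_query = "What adverse events were reported for pembrolizumab?"  # placeholder query
    demo_hits = [
        (13.1, "NCT_ID: NCT00000001 ... adverse events: fatigue, rash ..."),  # placeholder trials
        (12.0, "NCT_ID: NCT00000002 ..."),
        (10.7, "NCT_ID: NCT00000003 ..."),
    ]

    ranked = rank_trials_FAST(demo_query, demo_hits, demo_token)
    demo_context = "\n\n".join(text for _, text in ranked[:3])

    result = process_two_llm_system(demo_query, demo_context, demo_token)
    print(format_two_llm_response(result))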