"""
Evo Model Web Interface

A simple Gradio app for testing Evo's predictive and generative capabilities.
"""

import gradio as gr
import torch
import numpy as np
from evo import Evo
from evo.scoring import score_sequences
from evo.generation import generate
from typing import List, Tuple, Dict
import io
import sys
import os
from pathlib import Path

# Make modules bundled next to this file (e.g. the local `stripedhyena`
# package used by setup_hf_cache) importable.
sys.path.insert(0, str(Path(__file__).parent))

# Global model variables, populated lazily by load_model().
model = None
tokenizer = None
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def setup_hf_cache():
    """Copy the bundled stripedhyena ``tokenizer.py``/``utils.py`` into cached Evo model code.

    Only version directories that already exist under the HuggingFace
    ``transformers_modules`` cache are patched; nothing is pre-created.
    ``load_model`` therefore calls this both before and after constructing the
    model so a freshly downloaded version is also patched.
    """
    import shutil

    # Ensure the local stripedhyena package is importable.
    app_dir = Path(__file__).parent
    if str(app_dir) not in sys.path:
        sys.path.insert(0, str(app_dir))

    try:
        import stripedhyena
    except ImportError:
        stripedhyena_file = None
    else:
        # Namespace packages have __file__ == None; fall back to the direct path.
        stripedhyena_file = getattr(stripedhyena, "__file__", None)

    if stripedhyena_file:
        stripedhyena_path = Path(stripedhyena_file).parent
    else:
        stripedhyena_path = app_dir / "stripedhyena"

    local_tokenizer = stripedhyena_path / "tokenizer.py"
    local_utils = stripedhyena_path / "utils.py"

    if not local_tokenizer.exists():
        print(f"Warning: tokenizer not found at {local_tokenizer}")
        return

    files_to_copy = [local_tokenizer]
    if local_utils.exists():
        files_to_copy.append(local_utils)
    else:
        print(f"Warning: utils not found at {local_utils}")

    hf_cache = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
    model_dirs = [
        "togethercomputer/evo-1-8k-base",
        "togethercomputer/evo-1-131k-base",
    ]

    for model_dir in model_dirs:
        model_path = hf_cache / model_dir
        if not model_path.exists():
            continue
        # Model code already downloaded: patch every cached version.
        for version_dir in model_path.iterdir():
            if not version_dir.is_dir():
                continue
            try:
                for src in files_to_copy:
                    shutil.copy2(src, version_dir / src.name)
                print(f"✓ Fixed tokenizer in {model_dir}/{version_dir.name}")
            except OSError as e:  # shutil.Error is an OSError subclass
                print(f"Warning: Could not copy to {version_dir}: {e}")


def load_model():
    """Load the Evo model and tokenizer once and cache them in module globals.

    Subsequent calls return immediately. The globals are only assigned after the
    model has been moved to ``device`` and switched to eval mode, so a failure
    part-way through does not leave a half-initialised model behind.
    """
    global model, tokenizer
    if model is not None:
        return

    print("Loading Evo model...")

    # Setup HF cache BEFORE loading model
    try:
        setup_hf_cache()
    except Exception as e:
        print(f"Warning: Could not setup HF cache: {e}")

    evo_model = Evo('evo-1-8k-base')

    # Fix cache again AFTER download (in case it just downloaded)
    try:
        setup_hf_cache()
    except Exception as e:
        print(f"Warning: Could not fix HF cache after download: {e}")

    loaded_model = evo_model.model
    loaded_model.to(device)
    loaded_model.eval()

    model, tokenizer = loaded_model, evo_model.tokenizer
    print("✓ Model loaded successfully")


# ============================================================================
# TASK 1: Function Prediction
# ============================================================================

# Letters that occur in amino-acid sequences but not in nucleotide alphabets.
_PROTEIN_ONLY_CHARS = frozenset('EFILPQZ')
_DNA_CHARS = frozenset('ACGTN')


def _clean_sequence(seq: str) -> str:
    """Return *seq* with all whitespace (spaces, tabs, CR/LF) removed and upper-cased."""
    return ''.join(seq.split()).upper()


def detect_sequence_type(seq: str) -> str:
    """Detect if sequence is DNA, RNA, or protein.

    Returns ``'protein'`` if any protein-only letter is present, ``'RNA'`` if a
    ``U`` is present, ``'DNA'`` if every character is in ``ACGTN`` and
    ``'unknown'`` otherwise. Matching is case-insensitive.
    """
    seq_upper = seq.upper()
    if any(c in _PROTEIN_ONLY_CHARS for c in seq_upper):
        return 'protein'
    if 'U' in seq_upper:
        return 'RNA'
    if all(c in _DNA_CHARS for c in seq_upper):
        return 'DNA'
    return 'unknown'


def parse_fasta_text(text: str) -> List[Tuple[str, str]]:
    """Parse FASTA format text into ``(id, sequence)`` tuples.

    The id is the header text up to the first ``|``. Sequence lines are
    concatenated after stripping surrounding whitespace. Lines before the
    first header are ignored.
    """
    sequences = []
    current_id = None
    current_seq = []

    for line in text.strip().split('\n'):
        line = line.strip()
        if line.startswith('>'):
            if current_id is not None:
                sequences.append((current_id, ''.join(current_seq)))
            current_id = line[1:].split('|')[0].strip()
            current_seq = []
        else:
            current_seq.append(line)

    if current_id is not None:
        sequences.append((current_id, ''.join(current_seq)))

    return sequences


def predict_function(sequences_text: str, threshold: float) -> str:
    """Score sequences with Evo and classify them as functional or not.

    Args:
        sequences_text: FASTA text, or a single raw sequence (possibly wrapped
            over several lines).
        threshold: sequences whose mean score is strictly above this value are
            reported as functional.

    Returns:
        A plain-text report table, or a warning message for empty/invalid input.
    """
    text = sequences_text.strip() if sequences_text else ''
    if not text:
        return "⚠️ Please enter sequences in FASTA format or paste sequences directly."

    # Parse input (strip first so leading blank lines don't hide the '>' header).
    if text.startswith('>'):
        seq_data = parse_fasta_text(text)
    else:
        seq_data = [("sequence_1", text)]

    # Normalise and drop records whose sequence is empty (e.g. header-only).
    seq_data = [(seq_id, _clean_sequence(seq)) for seq_id, seq in seq_data]
    seq_data = [(seq_id, seq) for seq_id, seq in seq_data if seq]

    if not seq_data:
        return "⚠️ No valid sequences found."

    load_model()

    # Score sequences
    sequences = [seq for _, seq in seq_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    # Format results
    results = ["# Function Prediction Results\n"]
    results.append(f"{'Sequence ID':<20} {'Type':<10} {'Score':<12} {'Prediction':<15} {'Length':<10}")
    results.append("-" * 70)

    for (seq_id, seq), score in zip(seq_data, scores):
        seq_type = detect_sequence_type(seq)
        prediction = "✓ Functional" if score > threshold else "✗ Non-functional"
        results.append(f"{seq_id:<20} {seq_type:<10} {score:<12.4f} {prediction:<15} {len(seq):<10}")

    results.append("\n" + "=" * 70)
    results.append(f"Total sequences: {len(seq_data)}")
    results.append(f"Functional: {sum(1 for s in scores if s > threshold)}")
    results.append(f"Non-functional: {sum(1 for s in scores if s <= threshold)}")
    results.append(f"Average score: {np.mean(scores):.4f}")

    return "\n".join(results)


# ============================================================================
# TASK 2: Gene Essentiality
# ============================================================================

def predict_essentiality(genes_text: str) -> str:
    """Rank FASTA genes by Evo score z-score and label them essential/non-essential.

    Genes with z > 0.5 are labelled essential, z < -0.5 non-essential, and the
    rest uncertain. With zero score spread every gene gets z = 0.
    """
    text = genes_text.strip() if genes_text else ''
    if not text:
        return "⚠️ Please enter gene sequences in FASTA format."

    # Parse FASTA (strip first so leading blank lines don't hide the '>' header).
    if not text.startswith('>'):
        return "⚠️ Please use FASTA format: >gene_id|organism|function\\nATGC..."

    gene_data = [(gene_id, _clean_sequence(seq)) for gene_id, seq in parse_fasta_text(text)]
    gene_data = [(gene_id, seq) for gene_id, seq in gene_data if seq]

    if not gene_data:
        return "⚠️ No valid genes found."

    load_model()

    # Score genes
    sequences = [seq for _, seq in gene_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    # Calculate statistics
    scores_mean = np.mean(scores)
    scores_std = np.std(scores)

    # Format results
    results = ["# Gene Essentiality Prediction\n"]
    results.append(f"{'Gene ID':<20} {'Z-Score':<10} {'Score':<12} {'Essentiality':<15} {'Confidence':<12}")
    results.append("-" * 70)

    essential_count = 0
    for (gene_id, seq), score in zip(gene_data, scores):
        z_score = (score - scores_mean) / scores_std if scores_std > 0 else 0

        if z_score > 0.5:
            essentiality = "✓ Essential"
            confidence = "High" if z_score > 1.0 else "Medium"
            essential_count += 1
        elif z_score < -0.5:
            essentiality = "✗ Non-essential"
            confidence = "High" if z_score < -1.0 else "Medium"
        else:
            essentiality = "? Uncertain"
            confidence = "Low"

        results.append(f"{gene_id:<20} {z_score:<10.2f} {score:<12.4f} {essentiality:<15} {confidence:<12}")

    results.append("\n" + "=" * 70)
    results.append(f"Total genes: {len(gene_data)}")
    results.append(f"Essential: {essential_count}")
    results.append(f"Mean score: {scores_mean:.4f} (std: {scores_std:.4f})")

    return "\n".join(results)


# ============================================================================
# TASK 3: CRISPR Generation
# ============================================================================

def generate_crispr(n_systems: int, cas_type: str, target_seq: str, cas_length: int) -> str:
    """Generate CRISPR-Cas systems.

    Args:
        n_systems: number of systems to generate (coerced to int).
        cas_type: ``"Cas9"``, ``"Cas12"`` or ``"Both"`` (alternates between them).
        target_seq: optional DNA target. When non-blank, the spacer is derived from
            its first 20 bases. Otherwise a spacer is sampled from the model.
        cas_length: number of tokens to generate for the Cas sequence (coerced to int).
    """
    load_model()

    # Gradio sliders may deliver floats; range()/n_tokens need ints.
    n_systems = int(n_systems)
    cas_length = int(cas_length)
    # Accept lowercase / line-wrapped input; whitespace-only means "not provided".
    target = _clean_sequence(target_seq or '')

    # Templates
    cas9_start = 'ATGAACAAGAAC'
    cas12_start = 'ATGAGCAAGCTG'
    complement = {'A': 'U', 'T': 'A', 'G': 'C', 'C': 'G'}

    results = ["# CRISPR-Cas System Generation\n"]

    cas_types = ['cas9', 'cas12'] if cas_type == 'Both' else [cas_type.lower()]

    for i in range(n_systems):
        current_cas = cas_types[i % len(cas_types)]
        prompt = cas9_start if current_cas == 'cas9' else cas12_start

        results.append(f"\n{'='*70}")
        results.append(f"System {i+1}: {current_cas.upper()}")
        results.append('=' * 70)

        # Generate Cas protein
        output_seqs, _ = generate(
            [prompt], model, tokenizer, n_tokens=cas_length,
            temperature=0.8, top_k=4, device=device, verbose=0
        )
        cas_protein = output_seqs[0]

        # Generate gRNA spacer
        if target:
            spacer = ''.join(complement.get(b, 'N') for b in reversed(target[:20]))
        else:
            spacer_seqs, _ = generate(['G'], model, tokenizer, n_tokens=19,
                                      temperature=0.7, top_k=4, device=device, verbose=0)
            spacer = spacer_seqs[0][:20].replace('T', 'U')

        # PAM sequence
        pam = 'NGG' if current_cas == 'cas9' else 'TTTN'

        results.append(f"\n{current_cas.upper()} Protein ({len(cas_protein)} nt):")
        results.append(f"{cas_protein[:80]}..." if len(cas_protein) > 80 else cas_protein)
        results.append(f"\ngRNA Spacer: {spacer}")
        results.append(f"PAM Sequence: {pam}")
        if current_cas == 'cas9':
            results.append("tracrRNA: AGCAUAGCAAGUUAAAAUAAGGCUAGUCCGU")

    return "\n".join(results)


# ============================================================================
# TASK 4: Regulatory Design
# ============================================================================

def generate_spacer_simple(length: int) -> str:
    """Generate a uniformly random A/T/G/C spacer of the given length."""
    bases = ['A', 'T', 'G', 'C']
    return ''.join(np.random.choice(bases) for _ in range(length))


def design_regulatory(n_designs: int, expression_level: str) -> str:
    """Design promoter/RBS regulatory regions from fixed templates plus random spacers.

    This task is purely template-based and does not use the Evo model, so the
    model is intentionally not loaded here.

    Args:
        n_designs: number of designs (coerced to int).
        expression_level: ``"High"``, ``"Medium"``, ``"Low"`` or ``"Mixed"``
            (cycles High/Medium/Low).
    """
    # Gradio sliders may deliver floats; range() needs an int.
    n_designs = int(n_designs)

    # Templates (all DNA, so the Shine-Dalgarno site uses T rather than U).
    promoter_templates = {
        'High': ('TTGACA', 'TATAAT'),
        'Medium': ('TTGACT', 'TATACT'),
        'Low': ('TTGCCA', 'TATGAT')
    }

    rbs_templates = {
        'High': 'AGGAGGT',
        'Medium': 'AGGAGG',
        'Low': 'AGGA'
    }

    results = ["# Regulatory Sequences Design\n"]

    levels = ['High', 'Medium', 'Low']

    for i in range(n_designs):
        level = levels[i % 3] if expression_level == 'Mixed' else expression_level

        results.append(f"\n{'='*70}")
        results.append(f"Design {i+1}: {level} Expression")
        results.append('=' * 70)

        # Get promoter boxes
        box_35, box_10 = promoter_templates[level]

        # Generate spacers
        spacer_35_10 = generate_spacer_simple(17)
        spacer_10_rbs = generate_spacer_simple(7)

        # Get RBS
        rbs = rbs_templates[level]

        # Generate RBS-ATG spacer
        spacer_rbs_atg = generate_spacer_simple(7)

        # Assemble
        promoter = box_35 + spacer_35_10 + box_10
        full_region = promoter + spacer_10_rbs + rbs + spacer_rbs_atg + 'ATG'

        gc_content = 100 * (full_region.count('G') + full_region.count('C')) / len(full_region)

        results.append("\nComponents:")
        results.append(f"  -35 box: {box_35}")
        results.append(f"  -10 box: {box_10}")
        results.append(f"  RBS (Shine-Dalgarno): {rbs}")
        results.append("  Start codon: ATG")
        results.append(f"\nFull Regulatory Region ({len(full_region)} bp, GC={gc_content:.1f}%):")
        results.append(full_region)
        results.append("\nPromoter only:")
        results.append(promoter)

    return "\n".join(results)


# ============================================================================
# Gradio Interface
# ============================================================================

def create_interface():
    """Create the Gradio interface with one tab per task."""
    with gr.Blocks(title="Evo Model Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧬 Evo Model Interface")
        gr.Markdown("### Test Evo's predictive and generative capabilities")

        with gr.Tabs():
            # Task 1: Function Prediction
            with gr.Tab("🔍 Function Prediction"):
                gr.Markdown("### Predict if sequences are functional")
                gr.Markdown("*Enter sequences in FASTA format or paste a single sequence*")

                with gr.Row():
                    with gr.Column():
                        func_input = gr.Textbox(
                            label="Input Sequences",
                            placeholder=">seq1|description\nATCGATCGATCG...\n\nOr paste a single sequence directly",
                            lines=8
                        )
                        func_threshold = gr.Slider(
                            minimum=-3.0, maximum=0.0, value=-1.5, step=0.1,
                            label="Functionality Threshold"
                        )
                        func_btn = gr.Button("Predict Function", variant="primary")

                    with gr.Column():
                        func_output = gr.Textbox(
                            label="Results",
                            lines=15,
                            show_copy_button=True
                        )

                func_btn.click(
                    fn=predict_function,
                    inputs=[func_input, func_threshold],
                    outputs=func_output
                )

                gr.Examples(
                    examples=[
                        [">functional_gene\nATGGCACAACCCGCGCCGAACTGGTTGACCTGAAAACCACCGCCGCACTGCGTCAGGCCAGCCAGGCGGAACAA", -1.5],
                        [">noncoding\nGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC", -1.5],
                    ],
                    inputs=[func_input, func_threshold]
                )

            # Task 2: Gene Essentiality
            with gr.Tab("🧬 Gene Essentiality"):
                gr.Markdown("### Predict essential genes in bacteria/phages")
                gr.Markdown("*Input format: >gene_id|organism|function*")

                with gr.Row():
                    with gr.Column():
                        ess_input = gr.Textbox(
                            label="Gene Sequences (FASTA)",
                            placeholder=">dnaA|E.coli|Replication initiator\nATGTCGAAAGCCGCAT...",
                            lines=8
                        )
                        ess_btn = gr.Button("Predict Essentiality", variant="primary")

                    with gr.Column():
                        ess_output = gr.Textbox(
                            label="Results",
                            lines=15,
                            show_copy_button=True
                        )

                ess_btn.click(
                    fn=predict_essentiality,
                    inputs=ess_input,
                    outputs=ess_output
                )

            # Task 3: CRISPR Generation
            with gr.Tab("✂️ CRISPR Generation"):
                gr.Markdown("### Generate synthetic CRISPR-Cas systems")

                with gr.Row():
                    with gr.Column():
                        crispr_n = gr.Slider(
                            minimum=1, maximum=5, value=2, step=1,
                            label="Number of Systems"
                        )
                        crispr_type = gr.Radio(
                            choices=["Cas9", "Cas12", "Both"],
                            value="Both",
                            label="Cas Type"
                        )
                        crispr_target = gr.Textbox(
                            label="Target Sequence (optional)",
                            placeholder="ATCGATCGATCGATCG",
                            lines=2
                        )
                        crispr_length = gr.Slider(
                            minimum=500, maximum=2000, value=1000, step=100,
                            label="Cas Protein Length"
                        )
                        crispr_btn = gr.Button("Generate CRISPR Systems", variant="primary")

                    with gr.Column():
                        crispr_output = gr.Textbox(
                            label="Generated Systems",
                            lines=15,
                            show_copy_button=True
                        )

                crispr_btn.click(
                    fn=generate_crispr,
                    inputs=[crispr_n, crispr_type, crispr_target, crispr_length],
                    outputs=crispr_output
                )

            # Task 4: Regulatory Design
            with gr.Tab("🎛️ Regulatory Design"):
                gr.Markdown("### Design promoter-RBS pairs for gene expression")

                with gr.Row():
                    with gr.Column():
                        reg_n = gr.Slider(
                            minimum=1, maximum=10, value=3, step=1,
                            label="Number of Designs"
                        )
                        reg_level = gr.Radio(
                            choices=["High", "Medium", "Low", "Mixed"],
                            value="Mixed",
                            label="Expression Level"
                        )
                        reg_btn = gr.Button("Design Regulatory Sequences", variant="primary")

                    with gr.Column():
                        reg_output = gr.Textbox(
                            label="Designed Sequences",
                            lines=15,
                            show_copy_button=True
                        )

                reg_btn.click(
                    fn=design_regulatory,
                    inputs=[reg_n, reg_level],
                    outputs=reg_output
                )

        gr.Markdown("---")
        gr.Markdown("💡 **Tips:** Higher scores = more functional/essential | All outputs can be copied | Model: evo-1-8k-base")

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()