# Evo-App / app.py
# Last commit 784595b (sochasticbackup): added model support and caching
"""
Evo Model Web Interface
A simple Gradio app for testing Evo's predictive and generative capabilities.
"""
import gradio as gr
import torch
import numpy as np
from evo import Evo
from evo.scoring import score_sequences
from evo.generation import generate
from typing import List, Tuple, Dict
import io
import sys
import os
from pathlib import Path
# Put this file's directory at the front of sys.path; setup_hf_cache() later
# imports the `stripedhyena` package and expects to resolve it from here.
sys.path.insert(0, str(Path(__file__).parent))

# Shared model state: both stay None until load_model() fills them in.
model = None
tokenizer = None

# Use the first CUDA device when PyTorch reports one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
def setup_hf_cache():
    """Patch cached Evo remote-code modules with the bundled tokenizer files.

    Copies ``tokenizer.py`` and ``utils.py`` from the local ``stripedhyena``
    package into every version directory that already exists for the Evo
    models under the HuggingFace ``transformers_modules`` cache.  Models that
    have not been downloaded yet are skipped; no directories are created.

    The cache location follows the ``HF_MODULES_CACHE``, ``HF_HOME`` and
    ``XDG_CACHE_HOME`` environment variables and falls back to
    ``~/.cache/huggingface/modules`` when none of them is set.

    Returns:
        None.  Progress and per-directory copy failures are printed rather
        than raised, so a failed patch never aborts the caller.
    """
    import shutil

    # The bundled stripedhyena package lives next to this file; make sure that
    # directory is importable before trying to locate it.
    app_dir = Path(__file__).parent
    if str(app_dir) not in sys.path:
        sys.path.insert(0, str(app_dir))

    try:
        import stripedhyena
    except ImportError:
        package_file = None
    else:
        # Namespace packages (a directory without __init__.py) import fine but
        # have no __file__, so treat them like a failed import.
        package_file = getattr(stripedhyena, "__file__", None)

    if package_file:
        stripedhyena_path = Path(package_file).parent
    else:
        stripedhyena_path = app_dir / "stripedhyena"

    local_tokenizer = stripedhyena_path / "tokenizer.py"
    local_utils = stripedhyena_path / "utils.py"

    if not local_tokenizer.exists():
        print(f"Warning: tokenizer not found at {local_tokenizer}")
        return

    # Resolve the dynamic-modules cache the same way the HuggingFace libraries
    # do, instead of assuming the default location under the home directory.
    xdg_cache = os.environ.get("XDG_CACHE_HOME") or str(Path.home() / ".cache")
    hf_home = os.environ.get("HF_HOME") or os.path.join(xdg_cache, "huggingface")
    modules_cache = os.environ.get("HF_MODULES_CACHE") or os.path.join(hf_home, "modules")
    hf_cache = Path(modules_cache).expanduser() / "transformers_modules"

    model_dirs = [
        "togethercomputer/evo-1-8k-base",
        "togethercomputer/evo-1-131k-base",
    ]

    for model_dir in model_dirs:
        model_path = hf_cache / model_dir
        if not model_path.is_dir():
            # Nothing cached for this model yet.
            continue

        # Each sub-directory is one cached revision of the remote code.
        for version_dir in model_path.iterdir():
            if not version_dir.is_dir():
                continue
            try:
                shutil.copy2(local_tokenizer, version_dir / "tokenizer.py")
                shutil.copy2(local_utils, version_dir / "utils.py")
                print(f"✓ Fixed tokenizer in {model_dir}/{version_dir.name}")
            except OSError as e:
                # Best effort: report and keep patching the other revisions.
                print(f"Warning: Could not copy to {version_dir}: {e}")
def load_model():
    """Load the Evo model on first use and cache it in the module globals.

    The model is loaded lazily: the first call builds ``Evo('evo-1-8k-base')``,
    moves it to ``device`` and switches it to eval mode; later calls return
    immediately because ``model`` is already set.

    ``setup_hf_cache`` runs both before and after the model is constructed so
    that an already-cached copy and a freshly downloaded one both get the
    bundled tokenizer files.  Cache patching is best effort and only warns.

    Returns:
        None.  The results are published through the ``model`` and
        ``tokenizer`` globals.
    """
    global model, tokenizer

    if model is not None:
        return

    print("Loading Evo model...")

    # Patch any revision that is already in the HF cache.
    try:
        setup_hf_cache()
    except Exception as e:
        print(f"Warning: Could not setup HF cache: {e}")

    evo_model = Evo('evo-1-8k-base')

    # Patch again in case constructing the model just downloaded the code.
    try:
        setup_hf_cache()
    except Exception as e:
        print(f"Warning: Could not fix HF cache after download: {e}")

    loaded_model = evo_model.model
    loaded_model.to(device)
    loaded_model.eval()

    # Publish the globals only once the model is fully prepared; otherwise a
    # failure in .to()/.eval() would leave a half-initialised model that every
    # later call treats as "already loaded".
    model, tokenizer = loaded_model, evo_model.tokenizer
    print("✓ Model loaded successfully")
# ============================================================================
# TASK 1: Function Prediction
# ============================================================================
def detect_sequence_type(seq: str) -> str:
    """Classify a sequence as protein, RNA or DNA from its alphabet.

    The check is case-insensitive and applied in this order:

    1. ``'protein'`` if any of the letters E, F, I, L, P, Q or Z occurs.
    2. ``'RNA'`` if the sequence contains ``U``.
    3. ``'DNA'`` if every character is one of A, C, G, T or N.

    Args:
        seq: Raw sequence text.

    Returns:
        ``'protein'``, ``'RNA'``, ``'DNA'`` or ``'unknown'``.  An empty
        sequence is ``'unknown'`` (it used to be reported as ``'DNA'``
        because ``all()`` over no characters is true).
    """
    # Work on the set of distinct characters: one pass over the sequence
    # instead of rebuilding the reference alphabet for every character.
    letters = set(seq.upper())
    if not letters:
        return 'unknown'
    if letters & set('EFILPQZ'):
        return 'protein'
    if 'U' in letters:
        return 'RNA'
    if letters <= set('ACGTN'):
        return 'DNA'
    return 'unknown'
def parse_fasta_text(text: str) -> List[Tuple[str, str]]:
    """Split FASTA text into ``(record_id, sequence)`` pairs.

    A record starts at a line beginning with ``>``; its id is the header text
    up to the first ``|``, stripped of surrounding whitespace.  All following
    lines (stripped) are concatenated into the sequence until the next header.
    Lines that appear before the first header are discarded.

    Args:
        text: FASTA-formatted text.

    Returns:
        Records in input order; a header with no sequence lines yields an
        empty sequence string.
    """
    records: List[Tuple[str, str]] = []
    header = None
    chunks: List[str] = []

    # A trailing sentinel header flushes the final record through the same
    # code path as every other record.
    for raw_line in text.strip().split('\n') + ['>']:
        entry = raw_line.strip()
        if not entry.startswith('>'):
            chunks.append(entry)
            continue
        if header is not None:
            records.append((header, ''.join(chunks)))
        header = entry[1:].split('|')[0].strip()
        chunks = []

    return records
def predict_function(sequences_text: str, threshold: float) -> str:
    """Score sequences with Evo and label each one functional or not.

    Input is validated before the model is loaded, so an empty or unusable
    submission returns immediately instead of triggering a model load.

    Args:
        sequences_text: FASTA text (first non-blank character is ``>``) or a
            single raw sequence, which is reported as ``sequence_1``.
        threshold: A sequence is "functional" when its score is strictly
            greater than this value.

    Returns:
        A plain-text report: one row per sequence followed by summary counts
        and the average score, or a warning message for unusable input.
    """
    text = (sequences_text or "").strip()
    if not text:
        return "⚠️ Please enter sequences in FASTA format or paste sequences directly."

    # Detect FASTA on the stripped text so leading blank lines or spaces do
    # not turn a FASTA submission into one bogus raw sequence.
    if text.startswith('>'):
        seq_data = parse_fasta_text(text)
    else:
        # Raw sequence: drop every kind of whitespace, not only "\n".
        seq_data = [("sequence_1", ''.join(text.split()))]

    # Header-only records have no residues to score; skip them.
    seq_data = [(seq_id, seq) for seq_id, seq in seq_data if seq]
    if not seq_data:
        return "⚠️ No valid sequences found."

    load_model()

    sequences = [seq for _, seq in seq_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    results = ["# Function Prediction Results\n"]
    results.append(f"{'Sequence ID':<20} {'Type':<10} {'Score':<12} {'Prediction':<15} {'Length':<10}")
    results.append("-" * 70)

    for (seq_id, seq), score in zip(seq_data, scores):
        seq_type = detect_sequence_type(seq)
        prediction = "✓ Functional" if score > threshold else "✗ Non-functional"
        results.append(f"{seq_id:<20} {seq_type:<10} {score:<12.4f} {prediction:<15} {len(seq):<10}")

    functional_count = sum(1 for s in scores if s > threshold)
    non_functional_count = sum(1 for s in scores if s <= threshold)

    results.append("\n" + "=" * 70)
    results.append(f"Total sequences: {len(seq_data)}")
    results.append(f"Functional: {functional_count}")
    results.append(f"Non-functional: {non_functional_count}")
    results.append(f"Average score: {np.mean(scores):.4f}")

    return "\n".join(results)
# ============================================================================
# TASK 2: Gene Essentiality
# ============================================================================
def predict_essentiality(genes_text: str) -> str:
    """Rank genes by Evo score and label them by z-score within the batch.

    Each gene's mean score is converted to a z-score relative to the other
    genes in the same submission: above 0.5 is "Essential", below -0.5 is
    "Non-essential", anything in between is "Uncertain".  Confidence is
    "High" beyond +/-1.0, "Medium" otherwise, and "Low" for uncertain calls.

    Input is validated before the model is loaded, so malformed submissions
    return immediately.

    Args:
        genes_text: FASTA text; headers look like ``>gene_id|organism|function``.

    Returns:
        A plain-text report with one row per gene and a summary, or a warning
        message when the input is empty, not FASTA, or has no sequences.
    """
    text = (genes_text or "").strip()
    if not text:
        return "⚠️ Please enter gene sequences in FASTA format."

    # Check the stripped text so leading whitespace does not reject FASTA.
    if not text.startswith('>'):
        return "⚠️ Please use FASTA format: >gene_id|organism|function\\nATGC..."

    # Header-only records have no residues to score; skip them.
    gene_data = [(gene_id, seq) for gene_id, seq in parse_fasta_text(text) if seq]
    if not gene_data:
        return "⚠️ No valid genes found."

    load_model()

    sequences = [seq for _, seq in gene_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    scores_mean = np.mean(scores)
    scores_std = np.std(scores)

    results = ["# Gene Essentiality Prediction\n"]
    results.append(f"{'Gene ID':<20} {'Z-Score':<10} {'Score':<12} {'Essentiality':<15} {'Confidence':<12}")
    results.append("-" * 70)

    essential_count = 0
    for (gene_id, seq), score in zip(gene_data, scores):
        # With zero spread (e.g. a single gene) every z-score is 0, so the
        # gene is reported as "Uncertain".
        z_score = (score - scores_mean) / scores_std if scores_std > 0 else 0

        if z_score > 0.5:
            essentiality = "✓ Essential"
            confidence = "High" if z_score > 1.0 else "Medium"
            essential_count += 1
        elif z_score < -0.5:
            essentiality = "✗ Non-essential"
            confidence = "High" if z_score < -1.0 else "Medium"
        else:
            essentiality = "? Uncertain"
            confidence = "Low"

        results.append(f"{gene_id:<20} {z_score:<10.2f} {score:<12.4f} {essentiality:<15} {confidence:<12}")

    results.append("\n" + "=" * 70)
    results.append(f"Total genes: {len(gene_data)}")
    results.append(f"Essential: {essential_count}")
    results.append(f"Mean score: {scores_mean:.4f} (std: {scores_std:.4f})")

    return "\n".join(results)
# ============================================================================
# TASK 3: CRISPR Generation
# ============================================================================
def generate_crispr(n_systems: int, cas_type: str, target_seq: str, cas_length: int) -> str:
    """Generate CRISPR-Cas systems with Evo and format them as text.

    For every system a Cas sequence is sampled from a short type-specific
    prompt, a 20-character gRNA spacer is derived, and the matching PAM is
    reported (plus a fixed tracrRNA for Cas9).

    Args:
        n_systems: Number of systems to generate.
        cas_type: ``"Cas9"``, ``"Cas12"`` or ``"Both"`` (alternates, starting
            with Cas9).
        target_seq: Optional target sequence.  When given, the spacer is the
            RNA reverse complement of its first 20 bases; whitespace is
            removed and case is ignored, and any character other than
            A/C/G/T becomes ``N``.  When empty, the spacer is sampled from
            the model.
        cas_length: Number of tokens to generate for the Cas sequence.

    Returns:
        A plain-text report with one section per generated system.
    """
    load_model()

    # Slider values may arrive as floats; range() and n_tokens need ints.
    n_systems = int(n_systems)
    cas_length = int(cas_length)

    # Prompts used to seed generation for each Cas type.
    cas9_start = 'ATGAACAAGAAC'
    cas12_start = 'ATGAGCAAGCTG'

    # Normalise the target once: pasted sequences are often lowercase or
    # wrapped over several lines, which previously produced an all-N spacer.
    target = ''.join((target_seq or '').split()).upper()
    target_spacer = None
    if target:
        complement = {'A': 'U', 'T': 'A', 'G': 'C', 'C': 'G'}
        # The spacer depends only on the target, so compute it outside the loop.
        target_spacer = ''.join(complement.get(b, 'N') for b in reversed(target[:20]))

    results = ["# CRISPR-Cas System Generation\n"]

    cas_types = ['cas9', 'cas12'] if cas_type == 'Both' else [cas_type.lower()]

    for i in range(n_systems):
        current_cas = cas_types[i % len(cas_types)]
        prompt = cas9_start if current_cas == 'cas9' else cas12_start

        results.append(f"\n{'='*70}")
        results.append(f"System {i+1}: {current_cas.upper()}")
        results.append('='*70)

        output_seqs, _ = generate(
            [prompt],
            model,
            tokenizer,
            n_tokens=cas_length,
            temperature=0.8,
            top_k=4,
            device=device,
            verbose=0
        )
        cas_protein = output_seqs[0]

        if target_spacer is not None:
            spacer = target_spacer
        else:
            # NOTE(review): whether generate() includes the 'G' prompt in its
            # output is not visible here - confirm the spacer length is 20.
            spacer_seqs, _ = generate(['G'], model, tokenizer, n_tokens=19, temperature=0.7,
                                      top_k=4, device=device, verbose=0)
            spacer = spacer_seqs[0][:20].replace('T', 'U')

        pam = 'NGG' if current_cas == 'cas9' else 'TTTN'

        results.append(f"\n{current_cas.upper()} Protein ({len(cas_protein)} nt):")
        # Show only the first 80 characters of long sequences.
        results.append(f"{cas_protein[:80]}..." if len(cas_protein) > 80 else cas_protein)
        results.append(f"\ngRNA Spacer: {spacer}")
        results.append(f"PAM Sequence: {pam}")

        if current_cas == 'cas9':
            results.append("tracrRNA: AGCAUAGCAAGUUAAAAUAAGGCUAGUCCGU")

    return "\n".join(results)
# ============================================================================
# TASK 4: Regulatory Design
# ============================================================================
def generate_spacer_simple(length: int) -> str:
    """Return a random string of ``length`` bases from A, T, G and C.

    Each base is drawn with a separate ``np.random.choice`` call, so the
    result depends on NumPy's global random state.
    """
    alphabet = ['A', 'T', 'G', 'C']
    drawn = []
    for _ in range(length):
        drawn.append(np.random.choice(alphabet))
    return ''.join(drawn)
def design_regulatory(n_designs: int, expression_level: str) -> str:
    """Assemble promoter + RBS regulatory regions from fixed templates.

    Each design is ``-35 box + 17 nt spacer + -10 box + 7 nt spacer + RBS +
    7 nt spacer + ATG``.  The boxes and RBS come from per-level templates and
    the spacers are random, so the Evo model is not involved and is no longer
    loaded here.

    Args:
        n_designs: Number of designs to produce.
        expression_level: ``"High"``, ``"Medium"``, ``"Low"``, or ``"Mixed"``
            to cycle through the three levels in that order.

    Returns:
        A plain-text report listing the components, the full region with its
        length and GC content, and the promoter alone for every design.

    Raises:
        KeyError: If ``expression_level`` is not one of the values above.
    """
    # (-35 box, -10 box) per expression level.
    promoter_templates = {
        'High': ('TTGACA', 'TATAAT'),
        'Medium': ('TTGACT', 'TATACT'),
        'Low': ('TTGCCA', 'TATGAT')
    }

    # Written in the DNA alphabet like the rest of the construct; the High
    # template used to end in the RNA base 'U', which put a U into an
    # otherwise-DNA sequence.
    rbs_templates = {
        'High': 'AGGAGGT',
        'Medium': 'AGGAGG',
        'Low': 'AGGA'
    }

    levels = ['High', 'Medium', 'Low']
    results = ["# Regulatory Sequences Design\n"]

    # Slider values may arrive as floats; range() needs an int.
    for i in range(int(n_designs)):
        level = levels[i % len(levels)] if expression_level == 'Mixed' else expression_level

        results.append(f"\n{'='*70}")
        results.append(f"Design {i+1}: {level} Expression")
        results.append('='*70)

        box_35, box_10 = promoter_templates[level]
        spacer_35_10 = generate_spacer_simple(17)
        spacer_10_rbs = generate_spacer_simple(7)
        rbs = rbs_templates[level]
        spacer_rbs_atg = generate_spacer_simple(7)

        promoter = box_35 + spacer_35_10 + box_10
        full_region = promoter + spacer_10_rbs + rbs + spacer_rbs_atg + 'ATG'

        gc_content = 100 * (full_region.count('G') + full_region.count('C')) / len(full_region)

        results.append("\nComponents:")
        results.append(f" -35 box: {box_35}")
        results.append(f" -10 box: {box_10}")
        results.append(f" RBS (Shine-Dalgarno): {rbs}")
        results.append(" Start codon: ATG")
        results.append(f"\nFull Regulatory Region ({len(full_region)} bp, GC={gc_content:.1f}%):")
        results.append(full_region)
        results.append("\nPromoter only:")
        results.append(promoter)

    return "\n".join(results)
# ============================================================================
# Gradio Interface
# ============================================================================
def _function_prediction_tab():
    """Lay out the "Function Prediction" tab and wire it to predict_function."""
    with gr.Tab("🔍 Function Prediction"):
        gr.Markdown("### Predict if sequences are functional")
        gr.Markdown("*Enter sequences in FASTA format or paste a single sequence*")

        with gr.Row():
            with gr.Column():
                sequence_box = gr.Textbox(
                    label="Input Sequences",
                    placeholder=">seq1|description\nATCGATCGATCG...\n\nOr paste a single sequence directly",
                    lines=8,
                )
                threshold_slider = gr.Slider(
                    minimum=-3.0,
                    maximum=0.0,
                    value=-1.5,
                    step=0.1,
                    label="Functionality Threshold",
                )
                run_button = gr.Button("Predict Function", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Results", lines=15, show_copy_button=True)

        run_button.click(
            fn=predict_function,
            inputs=[sequence_box, threshold_slider],
            outputs=report_box,
        )

        # Clickable sample inputs: one coding-like and one repetitive sequence.
        sample_rows = [
            [">functional_gene\nATGGCACAACCCGCGCCGAACTGGTTGACCTGAAAACCACCGCCGCACTGCGTCAGGCCAGCCAGGCGGAACAA", -1.5],
            [">noncoding\nGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC", -1.5],
        ]
        gr.Examples(examples=sample_rows, inputs=[sequence_box, threshold_slider])


def _gene_essentiality_tab():
    """Lay out the "Gene Essentiality" tab and wire it to predict_essentiality."""
    with gr.Tab("🧬 Gene Essentiality"):
        gr.Markdown("### Predict essential genes in bacteria/phages")
        gr.Markdown("*Input format: >gene_id|organism|function*")

        with gr.Row():
            with gr.Column():
                genes_box = gr.Textbox(
                    label="Gene Sequences (FASTA)",
                    placeholder=">dnaA|E.coli|Replication initiator\nATGTCGAAAGCCGCAT...",
                    lines=8,
                )
                run_button = gr.Button("Predict Essentiality", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Results", lines=15, show_copy_button=True)

        run_button.click(fn=predict_essentiality, inputs=genes_box, outputs=report_box)


def _crispr_generation_tab():
    """Lay out the "CRISPR Generation" tab and wire it to generate_crispr."""
    with gr.Tab("✂️ CRISPR Generation"):
        gr.Markdown("### Generate synthetic CRISPR-Cas systems")

        with gr.Row():
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Systems",
                )
                type_radio = gr.Radio(
                    choices=["Cas9", "Cas12", "Both"],
                    value="Both",
                    label="Cas Type",
                )
                target_box = gr.Textbox(
                    label="Target Sequence (optional)",
                    placeholder="ATCGATCGATCGATCG",
                    lines=2,
                )
                length_slider = gr.Slider(
                    minimum=500,
                    maximum=2000,
                    value=1000,
                    step=100,
                    label="Cas Protein Length",
                )
                run_button = gr.Button("Generate CRISPR Systems", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Generated Systems", lines=15, show_copy_button=True)

        run_button.click(
            fn=generate_crispr,
            inputs=[count_slider, type_radio, target_box, length_slider],
            outputs=report_box,
        )


def _regulatory_design_tab():
    """Lay out the "Regulatory Design" tab and wire it to design_regulatory."""
    with gr.Tab("🎛️ Regulatory Design"):
        gr.Markdown("### Design promoter-RBS pairs for gene expression")

        with gr.Row():
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Designs",
                )
                level_radio = gr.Radio(
                    choices=["High", "Medium", "Low", "Mixed"],
                    value="Mixed",
                    label="Expression Level",
                )
                run_button = gr.Button("Design Regulatory Sequences", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Designed Sequences", lines=15, show_copy_button=True)

        run_button.click(
            fn=design_regulatory,
            inputs=[count_slider, level_radio],
            outputs=report_box,
        )


def create_interface():
    """Build the four-tab Gradio Blocks app and return it without launching.

    The tabs are created in a fixed order (function prediction, gene
    essentiality, CRISPR generation, regulatory design), followed by a
    divider and a tips line shown beneath the tabs.
    """
    with gr.Blocks(title="Evo Model Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧬 Evo Model Interface")
        gr.Markdown("### Test Evo's predictive and generative capabilities")

        # Helpers are called inside the Tabs context so each tab attaches to it.
        with gr.Tabs():
            _function_prediction_tab()
            _gene_essentiality_tab()
            _crispr_generation_tab()
            _regulatory_design_tab()

        gr.Markdown("---")
        gr.Markdown("💡 **Tips:** Higher scores = more functional/essential | All outputs can be copied | Model: evo-1-8k-base")

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server
    # with its default launch settings.
    demo = create_interface()
    demo.launch()