|
|
"""
Evo Model Web Interface

A simple Gradio app for testing Evo's predictive and generative capabilities.
"""
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from evo import Evo |
|
|
from evo.scoring import score_sequences |
|
|
from evo.generation import generate |
|
|
from typing import List, Tuple, Dict |
|
|
import io |
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Put this script's own directory at the front of the import search path so
# modules that live next to it can be imported.
sys.path.insert(0, str(Path(__file__).parent))

# Shared model state. Both names stay None until load_model() assigns them.
model = None
tokenizer = None

# Run on the first CUDA device when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
def setup_hf_cache():
    """Overwrite cached Evo remote-code helper files with the local StripedHyena copies.

    Looks for ``tokenizer.py`` (required) and ``utils.py`` (optional) in the
    installed ``stripedhyena`` package, falling back to a ``stripedhyena``
    directory next to this script. Each file that exists is copied into every
    version directory already present under the Hugging Face
    ``transformers_modules`` cache for the two Evo checkpoints listed below.

    Nothing is created: checkpoints that have no cache directory yet are
    skipped silently. The cache root honours ``HF_MODULES_CACHE``, then
    ``HF_HOME``, then ``XDG_CACHE_HOME``, and finally ``~/.cache/huggingface``.

    Returns:
        None. Progress and warnings are printed to stdout.
    """
    import shutil

    app_dir = Path(__file__).parent
    if str(app_dir) not in sys.path:
        sys.path.insert(0, str(app_dir))

    try:
        import stripedhyena
        stripedhyena_path = Path(stripedhyena.__file__).parent
    except ImportError:
        # Package not installed: fall back to a copy shipped next to the app.
        stripedhyena_path = app_dir / "stripedhyena"

    local_tokenizer = stripedhyena_path / "tokenizer.py"
    local_utils = stripedhyena_path / "utils.py"

    if not local_tokenizer.exists():
        print(f"Warning: tokenizer not found at {local_tokenizer}")
        return

    # Only copy files that are really there, so a missing utils.py does not
    # make every version directory fail half-way through.
    patch_files = [local_tokenizer]
    if local_utils.exists():
        patch_files.append(local_utils)
    else:
        print(f"Warning: utils not found at {local_utils}")

    # Resolve the dynamic-module cache the same way the environment overrides
    # are documented, instead of assuming the default location.
    modules_override = os.environ.get("HF_MODULES_CACHE")
    if modules_override:
        modules_root = Path(modules_override).expanduser()
    else:
        hf_home = os.environ.get("HF_HOME")
        if hf_home:
            hf_root = Path(hf_home).expanduser()
        else:
            xdg_cache = os.environ.get("XDG_CACHE_HOME")
            cache_base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
            hf_root = cache_base / "huggingface"
        modules_root = hf_root / "modules"
    hf_cache = modules_root / "transformers_modules"

    model_dirs = [
        "togethercomputer/evo-1-8k-base",
        "togethercomputer/evo-1-131k-base",
    ]

    for model_dir in model_dirs:
        model_path = hf_cache / model_dir
        if not model_path.exists():
            continue

        for version_dir in model_path.iterdir():
            if not version_dir.is_dir():
                continue
            try:
                for source_file in patch_files:
                    shutil.copy2(source_file, version_dir / source_file.name)
                print(f"✓ Fixed tokenizer in {model_dir}/{version_dir.name}")
            except OSError as e:
                # Best effort: report and move on to the next version directory.
                print(f"Warning: Could not copy to {version_dir}: {e}")
|
|
|
|
|
|
|
|
def load_model():
    """Populate the module-level ``model`` and ``tokenizer`` on first use.

    Does nothing when ``model`` is already set. Otherwise it runs the cache
    patch, constructs ``Evo('evo-1-8k-base')``, runs the cache patch a second
    time, then moves the model to ``device`` and switches it to eval mode.
    Failures of either cache patch are printed and otherwise ignored.
    """
    global model, tokenizer

    if model is not None:
        return

    print("Loading Evo model...")

    def _patch_cache(warning_prefix):
        # Cache patching is best effort; it must never block model loading.
        try:
            setup_hf_cache()
        except Exception as err:
            print(f"{warning_prefix}: {err}")

    _patch_cache("Warning: Could not setup HF cache")

    evo_model = Evo('evo-1-8k-base')

    # Run the patch again now that constructing Evo has had a chance to
    # populate the cache directories.
    _patch_cache("Warning: Could not fix HF cache after download")

    model, tokenizer = evo_model.model, evo_model.tokenizer
    model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Letters that mark a sequence as protein in this heuristic.
_PROTEIN_ONLY_LETTERS = frozenset('EFILPQZ')
# Accepted alphabets for the nucleotide classifications.
_DNA_LETTERS = frozenset('ACGTN')
_RNA_LETTERS = frozenset('ACGUN')


def detect_sequence_type(seq: str) -> str:
    """Classify a sequence as ``'protein'``, ``'RNA'``, ``'DNA'`` or ``'unknown'``.

    The check is case-insensitive and applied in this order:

    1. any of ``E F I L P Q Z`` present -> ``'protein'``
    2. ``U`` present and every character in ``ACGUN`` -> ``'RNA'``
    3. every character in ``ACGTN`` -> ``'DNA'``
    4. anything else, including the empty string -> ``'unknown'``

    Args:
        seq: Raw sequence text.

    Returns:
        One of the four labels above.
    """
    residues = set(seq.upper())

    # An empty input has no alphabet; without this guard it would satisfy the
    # "all characters are DNA letters" test vacuously.
    if not residues:
        return 'unknown'

    if residues & _PROTEIN_ONLY_LETTERS:
        return 'protein'

    if 'U' in residues:
        # Require a clean RNA alphabet, mirroring the DNA check below.
        return 'RNA' if residues <= _RNA_LETTERS else 'unknown'

    if residues <= _DNA_LETTERS:
        return 'DNA'

    return 'unknown'
|
|
|
|
|
|
|
|
def parse_fasta_text(text: str) -> List[Tuple[str, str]]:
    """Parse FASTA-formatted text into ``(id, sequence)`` tuples.

    A record starts at a line beginning with ``>``. Its id is the header text
    up to the first ``|``, stripped of surrounding whitespace. All following
    non-header lines are concatenated, with any whitespace removed, to form
    the sequence.

    Blank lines are ignored, text before the first header is discarded, and
    records whose sequence is empty are dropped so callers never receive an
    empty string to score.

    Args:
        text: FASTA text; any common line-ending style is accepted.

    Returns:
        Records in input order. Empty when no record has a sequence.
    """
    sequences: List[Tuple[str, str]] = []
    current_id = None
    current_seq: List[str] = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith('>'):
            # Close the previous record, keeping it only if it has residues.
            if current_id is not None and current_seq:
                sequences.append((current_id, ''.join(current_seq)))
            current_id = line[1:].split('|')[0].strip()
            current_seq = []
        elif current_id is not None:
            # Drop spaces/tabs that are sometimes used to group residues.
            current_seq.append(''.join(line.split()))

    if current_id is not None and current_seq:
        sequences.append((current_id, ''.join(current_seq)))

    return sequences
|
|
|
|
|
|
|
|
def predict_function(sequences_text: str, threshold: float) -> str:
    """Score sequences and label each as functional or non-functional.

    The input is either FASTA text (first non-blank character is ``>``) or a
    single bare sequence, which is given the id ``sequence_1`` after all
    whitespace is removed. Each sequence is scored with ``score_sequences``
    using ``reduce_method='mean'`` and is labelled functional when its score
    is strictly greater than ``threshold``.

    Input is validated before the model is loaded, so empty or unusable input
    returns immediately.

    Args:
        sequences_text: FASTA text or one raw sequence.
        threshold: Score cut-off for the functional label.

    Returns:
        A plain-text report table with a summary, or a warning message when
        there is nothing to score.
    """
    cleaned = (sequences_text or "").strip()
    if not cleaned:
        return "⚠️ Please enter sequences in FASTA format or paste sequences directly."

    if cleaned.startswith('>'):
        seq_data = parse_fasta_text(cleaned)
    else:
        # Bare sequence: remove line breaks, carriage returns and spaces.
        seq_data = [("sequence_1", ''.join(cleaned.split()))]

    if not seq_data:
        return "⚠️ No valid sequences found."

    # Only pay for model loading once there is something to score.
    load_model()

    sequences = [seq for _, seq in seq_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    results = ["# Function Prediction Results\n"]
    results.append(f"{'Sequence ID':<20} {'Type':<10} {'Score':<12} {'Prediction':<15} {'Length':<10}")
    results.append("-" * 70)

    for (seq_id, seq), score in zip(seq_data, scores):
        seq_type = detect_sequence_type(seq)
        prediction = "✓ Functional" if score > threshold else "✗ Non-functional"
        results.append(f"{seq_id:<20} {seq_type:<10} {score:<12.4f} {prediction:<15} {len(seq):<10}")

    results.append("\n" + "=" * 70)
    results.append(f"Total sequences: {len(seq_data)}")
    results.append(f"Functional: {sum(1 for s in scores if s > threshold)}")
    results.append(f"Non-functional: {sum(1 for s in scores if s <= threshold)}")
    results.append(f"Average score: {np.mean(scores):.4f}")

    return "\n".join(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_essentiality(genes_text: str) -> str:
    """Rank genes by score relative to the rest of the submitted batch.

    Genes must be supplied as FASTA text. Each gene is scored with
    ``score_sequences`` using ``reduce_method='mean'``, and the scores are
    converted to z-scores over the batch (0 for every gene when the standard
    deviation is 0, e.g. a single gene). Labels are:

    * z > 0.5 -> essential (confidence "High" above 1.0, else "Medium")
    * z < -0.5 -> non-essential (confidence "High" below -1.0, else "Medium")
    * otherwise -> uncertain, confidence "Low"

    Input is validated before the model is loaded.

    Args:
        genes_text: FASTA text with one record per gene.

    Returns:
        A plain-text report table with a summary, or a warning message when
        the input is empty, not FASTA, or contains no genes.
    """
    cleaned = (genes_text or "").strip()
    if not cleaned:
        return "⚠️ Please enter gene sequences in FASTA format."

    if not cleaned.startswith('>'):
        return "⚠️ Please use FASTA format: >gene_id|organism|function\\nATGC..."

    gene_data = parse_fasta_text(cleaned)
    if not gene_data:
        return "⚠️ No valid genes found."

    # Only pay for model loading once there is something to score.
    load_model()

    sequences = [seq for _, seq in gene_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    scores_mean = np.mean(scores)
    scores_std = np.std(scores)

    results = ["# Gene Essentiality Prediction\n"]
    results.append(f"{'Gene ID':<20} {'Z-Score':<10} {'Score':<12} {'Essentiality':<15} {'Confidence':<12}")
    results.append("-" * 70)

    essential_count = 0
    for (gene_id, seq), score in zip(gene_data, scores):
        # With no spread in the batch there is nothing to rank against.
        z_score = (score - scores_mean) / scores_std if scores_std > 0 else 0.0

        if z_score > 0.5:
            essentiality = "✓ Essential"
            confidence = "High" if z_score > 1.0 else "Medium"
            essential_count += 1
        elif z_score < -0.5:
            essentiality = "✗ Non-essential"
            confidence = "High" if z_score < -1.0 else "Medium"
        else:
            essentiality = "? Uncertain"
            confidence = "Low"

        results.append(f"{gene_id:<20} {z_score:<10.2f} {score:<12.4f} {essentiality:<15} {confidence:<12}")

    results.append("\n" + "=" * 70)
    results.append(f"Total genes: {len(gene_data)}")
    results.append(f"Essential: {essential_count}")
    results.append(f"Mean score: {scores_mean:.4f} (std: {scores_std:.4f})")

    return "\n".join(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_crispr(n_systems: int, cas_type: str, target_seq: str, cas_length: int) -> str:
    """Generate a text report of synthetic CRISPR-Cas systems.

    For each system a Cas sequence is sampled with ``generate`` from a fixed
    start prompt (one for Cas9, one for Cas12) using ``n_tokens=cas_length``.
    The gRNA spacer is either derived from ``target_seq`` or, when no target
    is given, sampled with ``generate`` from the prompt ``'G'``.

    Args:
        n_systems: Number of systems to generate; coerced to ``int``. Values
            below 1 return only the report header without loading the model.
        cas_type: ``"Cas9"``, ``"Cas12"`` or ``"Both"``. ``"Both"``
            alternates, starting with Cas9. Any other value is lower-cased
            and treated as Cas12 for prompt and PAM selection.
        target_seq: Optional target. Whitespace is removed and the text is
            upper-cased; the spacer is the reversed, base-paired RNA
            complement of its first 20 characters, with ``N`` for
            unrecognised characters.
        cas_length: Number of tokens requested for the Cas sequence; coerced
            to ``int``.

    Returns:
        The report as a single newline-joined string.
    """
    # UI sliders can hand over float-typed numbers; range() needs an int.
    n_systems = int(n_systems)
    cas_length = int(cas_length)

    results = ["# CRISPR-Cas System Generation\n"]
    if n_systems < 1:
        # Nothing to generate, so do not pay for loading the model.
        return "\n".join(results)

    load_model()

    cas9_start = 'ATGAACAAGAAC'
    cas12_start = 'ATGAGCAAGCTG'

    cas_types = ['cas9', 'cas12'] if cas_type == 'Both' else [cas_type.lower()]

    # The target-derived spacer is identical for every system, so normalise
    # the input and compute it once. Whitespace-only input counts as absent.
    complement = {'A': 'U', 'T': 'A', 'U': 'A', 'G': 'C', 'C': 'G'}
    target = ''.join((target_seq or '').split()).upper()
    if target:
        target_spacer = ''.join(complement.get(b, 'N') for b in reversed(target[:20]))
    else:
        target_spacer = None

    for i in range(n_systems):
        current_cas = cas_types[i % len(cas_types)]
        prompt = cas9_start if current_cas == 'cas9' else cas12_start

        results.append(f"\n{'='*70}")
        results.append(f"System {i+1}: {current_cas.upper()}")
        results.append('='*70)

        output_seqs, _ = generate(
            [prompt],
            model,
            tokenizer,
            n_tokens=cas_length,
            temperature=0.8,
            top_k=4,
            device=device,
            verbose=0
        )
        cas_protein = output_seqs[0]

        if target_spacer is not None:
            spacer = target_spacer
        else:
            # No target supplied: sample a fresh spacer for every system.
            spacer_seqs, _ = generate(['G'], model, tokenizer, n_tokens=19, temperature=0.7,
                                      top_k=4, device=device, verbose=0)
            spacer = spacer_seqs[0][:20].replace('T', 'U')

        pam = 'NGG' if current_cas == 'cas9' else 'TTTN'

        results.append(f"\n{current_cas.upper()} Protein ({len(cas_protein)} nt):")
        # Only the first 80 characters are shown for long sequences.
        results.append(f"{cas_protein[:80]}..." if len(cas_protein) > 80 else cas_protein)
        results.append(f"\ngRNA Spacer: {spacer}")
        results.append(f"PAM Sequence: {pam}")
        if current_cas == 'cas9':
            results.append("tracrRNA: AGCAUAGCAAGUUAAAAUAAGGCUAGUCCGU")

    return "\n".join(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_spacer_simple(length: int) -> str:
    """Return ``length`` bases drawn one at a time with ``np.random.choice``.

    Each position is an independent draw from ``A``, ``T``, ``G``, ``C``
    using NumPy's global random state.
    """
    alphabet = ['A', 'T', 'G', 'C']
    picked = []
    for _ in range(length):
        picked.append(np.random.choice(alphabet))
    return ''.join(picked)
|
|
|
|
|
|
|
|
def design_regulatory(n_designs: int, expression_level: str) -> str:
    """Assemble promoter + RBS + start-codon designs from fixed templates.

    Each design is ``-35 box + 17 random bases + -10 box + 7 random bases +
    RBS + 7 random bases + ATG``. The boxes and RBS come from the template
    tables below, and the random stretches come from
    ``generate_spacer_simple``. The Evo model is not used, so it is not
    loaded here.

    Args:
        n_designs: Number of designs to emit; coerced to ``int``.
        expression_level: ``"High"``, ``"Medium"``, ``"Low"``, or ``"Mixed"``
            to cycle High -> Medium -> Low across designs.

    Returns:
        The report as a single newline-joined string.

    Raises:
        KeyError: If ``expression_level`` is not one of the values above.
    """
    # UI sliders can hand over float-typed numbers; range() needs an int.
    n_designs = int(n_designs)

    promoter_templates = {
        'High': ('TTGACA', 'TATAAT'),
        'Medium': ('TTGACT', 'TATACT'),
        'Low': ('TTGCCA', 'TATGAT')
    }

    # Written in the DNA alphabet: these are spliced into a region that is
    # otherwise DNA and reported in bp, so the High template must not carry
    # a U.
    rbs_templates = {
        'High': 'AGGAGGT',
        'Medium': 'AGGAGG',
        'Low': 'AGGA'
    }

    results = ["# Regulatory Sequences Design\n"]

    levels = ['High', 'Medium', 'Low']

    for i in range(n_designs):
        level = levels[i % 3] if expression_level == 'Mixed' else expression_level

        results.append(f"\n{'='*70}")
        results.append(f"Design {i+1}: {level} Expression")
        results.append('='*70)

        box_35, box_10 = promoter_templates[level]

        # Draw order is kept fixed so a seeded run reproduces the same designs.
        spacer_35_10 = generate_spacer_simple(17)
        spacer_10_rbs = generate_spacer_simple(7)

        rbs = rbs_templates[level]

        spacer_rbs_atg = generate_spacer_simple(7)

        promoter = box_35 + spacer_35_10 + box_10
        full_region = promoter + spacer_10_rbs + rbs + spacer_rbs_atg + 'ATG'

        gc_content = 100 * (full_region.count('G') + full_region.count('C')) / len(full_region)

        results.append("\nComponents:")
        results.append(f"  -35 box: {box_35}")
        results.append(f"  -10 box: {box_10}")
        results.append(f"  RBS (Shine-Dalgarno): {rbs}")
        results.append("  Start codon: ATG")
        results.append(f"\nFull Regulatory Region ({len(full_region)} bp, GC={gc_content:.1f}%):")
        results.append(full_region)
        results.append("\nPromoter only:")
        results.append(promoter)

    return "\n".join(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_function_prediction_tab():
    """Lay out the function-prediction tab and wire its button to predict_function."""
    with gr.Tab("🔍 Function Prediction"):
        gr.Markdown("### Predict if sequences are functional")
        gr.Markdown("*Enter sequences in FASTA format or paste a single sequence*")

        with gr.Row():
            with gr.Column():
                seq_box = gr.Textbox(
                    label="Input Sequences",
                    placeholder=">seq1|description\nATCGATCGATCG...\n\nOr paste a single sequence directly",
                    lines=8,
                )
                threshold_slider = gr.Slider(
                    minimum=-3.0, maximum=0.0, value=-1.5, step=0.1,
                    label="Functionality Threshold",
                )
                run_btn = gr.Button("Predict Function", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Results", lines=15, show_copy_button=True)

        run_btn.click(
            fn=predict_function,
            inputs=[seq_box, threshold_slider],
            outputs=report_box,
        )

        # Clickable sample inputs that fill both the sequence box and the slider.
        sample_rows = [
            [">functional_gene\nATGGCACAACCCGCGCCGAACTGGTTGACCTGAAAACCACCGCCGCACTGCGTCAGGCCAGCCAGGCGGAACAA", -1.5],
            [">noncoding\nGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC", -1.5],
        ]
        gr.Examples(examples=sample_rows, inputs=[seq_box, threshold_slider])


def _build_gene_essentiality_tab():
    """Lay out the gene-essentiality tab and wire its button to predict_essentiality."""
    with gr.Tab("🧬 Gene Essentiality"):
        gr.Markdown("### Predict essential genes in bacteria/phages")
        gr.Markdown("*Input format: >gene_id|organism|function*")

        with gr.Row():
            with gr.Column():
                genes_box = gr.Textbox(
                    label="Gene Sequences (FASTA)",
                    placeholder=">dnaA|E.coli|Replication initiator\nATGTCGAAAGCCGCAT...",
                    lines=8,
                )
                run_btn = gr.Button("Predict Essentiality", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Results", lines=15, show_copy_button=True)

        run_btn.click(fn=predict_essentiality, inputs=genes_box, outputs=report_box)


def _build_crispr_generation_tab():
    """Lay out the CRISPR-generation tab and wire its button to generate_crispr."""
    with gr.Tab("✂️ CRISPR Generation"):
        gr.Markdown("### Generate synthetic CRISPR-Cas systems")

        with gr.Row():
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Number of Systems",
                )
                type_radio = gr.Radio(
                    choices=["Cas9", "Cas12", "Both"],
                    value="Both",
                    label="Cas Type",
                )
                target_box = gr.Textbox(
                    label="Target Sequence (optional)",
                    placeholder="ATCGATCGATCGATCG",
                    lines=2,
                )
                length_slider = gr.Slider(
                    minimum=500, maximum=2000, value=1000, step=100,
                    label="Cas Protein Length",
                )
                run_btn = gr.Button("Generate CRISPR Systems", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Generated Systems", lines=15, show_copy_button=True)

        run_btn.click(
            fn=generate_crispr,
            inputs=[count_slider, type_radio, target_box, length_slider],
            outputs=report_box,
        )


def _build_regulatory_design_tab():
    """Lay out the regulatory-design tab and wire its button to design_regulatory."""
    with gr.Tab("🎛️ Regulatory Design"):
        gr.Markdown("### Design promoter-RBS pairs for gene expression")

        with gr.Row():
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1, maximum=10, value=3, step=1,
                    label="Number of Designs",
                )
                level_radio = gr.Radio(
                    choices=["High", "Medium", "Low", "Mixed"],
                    value="Mixed",
                    label="Expression Level",
                )
                run_btn = gr.Button("Design Regulatory Sequences", variant="primary")

            with gr.Column():
                report_box = gr.Textbox(label="Designed Sequences", lines=15, show_copy_button=True)

        run_btn.click(
            fn=design_regulatory,
            inputs=[count_slider, level_radio],
            outputs=report_box,
        )


def create_interface():
    """Build and return the Gradio Blocks app.

    The page has a title, four tabs (function prediction, gene essentiality,
    CRISPR generation, regulatory design) and a footer with usage tips. Each
    tab is assembled by its own helper, called inside the ``gr.Tabs`` context
    so the components attach to the page in the same order as before.
    """
    with gr.Blocks(title="Evo Model Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧬 Evo Model Interface")
        gr.Markdown("### Test Evo's predictive and generative capabilities")

        with gr.Tabs():
            _build_function_prediction_tab()
            _build_gene_essentiality_tab()
            _build_crispr_generation_tab()
            _build_regulatory_design_tab()

        gr.Markdown("---")
        gr.Markdown("💡 **Tips:** Higher scores = more functional/essential | All outputs can be copied | Model: evo-1-8k-base")

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start serving it with
    # Gradio's default launch settings.
    demo = create_interface()
    demo.launch()
|
|
|