# Origin: Hugging Face Space file app.py, uploaded by mgbam
# Commit d5e7a20 ("Update app.py", verified)
"""
Genesis RNA - BRCA Variant Classifier
Developer: Oluwafemi Idiakhoa | Genesis AI Research
Professional AI system for CURING BREAST CANCER through variant pathogenicity prediction.
Trained on 54,943 ClinVar variants using a 35M-parameter transformer model.
⚠️ RESEARCH USE ONLY - Not for clinical diagnosis
"""
import html
import re
from dataclasses import dataclass
from typing import Dict, Optional

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
# Startup banner: the seven lines are emitted with a single print call.
_RULE = "=" * 70
print(
    _RULE,
    "GENESIS RNA - BRCA VARIANT CLASSIFIER",
    _RULE,
    "\nDeveloper: Oluwafemi Idiakhoa",
    "Institution: Genesis AI Research",
    "Mission: Cure Breast Cancer Through AI",
    _RULE,
    sep="\n",
)

# Fetch the checkpoint from the HuggingFace Model Hub into ./cache.
# A failed download is reported and re-raised, so the app does not start
# without a model.
print("\n📥 Downloading Genesis RNA model...")
try:
    model_path = hf_hub_download(
        repo_id="mgbam/genesis-rna-base",
        filename="models/best_model.pt",
        cache_dir="./cache",
    )
    print(f"✓ Model downloaded: {model_path}")
except Exception as e:
    print(f"❌ Error downloading model: {e}")
    raise

print("\n📦 Loading Genesis RNA model...")
# ============================================================================
# MODEL ARCHITECTURE (EXACT MATCH TO CHECKPOINT)
# ============================================================================
@dataclass
class GenesisRNAConfig:
    """Hyperparameters consumed by the Genesis RNA model classes in this file.

    The defaults are used for any key missing from the checkpoint's config.
    """

    vocab_size: int = 9  # rows of the token embedding / MLM output width
    d_model: int = 512  # hidden width used throughout the encoder
    n_heads: int = 8  # attention heads per transformer block
    n_layers: int = 8  # number of stacked transformer blocks
    d_ff: int = 2048  # inner width of the feed-forward sub-layer
    max_position_embeddings: int = 512  # rows of the position embedding table
    dropout: float = 0.1
    layer_norm_eps: float = 1e-12
    initializer_range: float = 0.02  # std of the normal init for embeddings
    structure_num_labels: int = 4  # output width of the structure head
    # NOTE(review): not read by any class visible in this file.
    use_rotary_embeddings: bool = False

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a mapping, dropping keys that are not fields."""
        declared = cls.__annotations__
        accepted = {}
        for key, value in config_dict.items():
            if key in declared:
                accepted[key] = value
        return cls(**accepted)
class RNATokenizer:
    """Character-level tokenizer mapping nucleotides to integer ids."""

    def __init__(self):
        special_tokens = ['[PAD]', '[MASK]', '[CLS]', '[SEP]']
        nucleotides = ['A', 'C', 'G', 'U', 'N']
        # ids are assigned in listing order: specials 0-3, nucleotides 4-8
        self.vocab = {
            token: index
            for index, token in enumerate(special_tokens + nucleotides)
        }

    def encode(self, sequence: str, max_len: int = 512) -> torch.Tensor:
        """Return a 1-D long tensor of token ids for ``sequence``.

        The input is upper-cased, T is rewritten to U, and the result is
        truncated to ``max_len``. Characters outside the vocabulary map to
        the id of 'N'; shorter inputs are right-padded with the '[PAD]' id.
        """
        normalized = sequence.upper().replace('T', 'U')[:max_len]
        unknown_id = self.vocab['N']
        token_ids = [self.vocab.get(ch, unknown_id) for ch in normalized]
        missing = max_len - len(token_ids)
        if missing > 0:
            token_ids.extend([self.vocab['[PAD]']] * missing)
        return torch.tensor(token_ids[:max_len], dtype=torch.long)
class RNAEmbedding(nn.Module):
    """Sum of token and learned position embeddings, then LayerNorm + dropout."""

    def __init__(self, cfg):
        super().__init__()
        width = cfg.d_model
        self.token_embeddings = nn.Embedding(cfg.vocab_size, width)
        self.pos_embeddings = nn.Embedding(cfg.max_position_embeddings, width)
        self.layer_norm = nn.LayerNorm(width, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.dropout)
        # Both tables get the same normal init (token table first).
        for table in (self.token_embeddings, self.pos_embeddings):
            nn.init.normal_(table.weight, std=cfg.initializer_range)

    def forward(self, input_ids):
        """Embed a 2-D tensor of token ids; output adds a trailing d_model axis."""
        n_rows, n_cols = input_ids.size()
        # Positions 0..n_cols-1, repeated for every row of the batch.
        positions = torch.arange(
            n_cols, dtype=torch.long, device=input_ids.device
        )
        positions = positions.unsqueeze(0).expand(n_rows, -1)
        combined = self.token_embeddings(input_ids) + self.pos_embeddings(positions)
        normalized = self.layer_norm(combined)
        return self.dropout(normalized)
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and a GELU feed-forward network.

    Each sub-layer's output is added to its input and the sum is passed
    through a LayerNorm (normalisation happens after the residual add).
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.d_model
        drop_p = cfg.dropout
        norm_eps = cfg.layer_norm_eps

        self.attention = nn.MultiheadAttention(
            width, cfg.n_heads, dropout=drop_p, batch_first=True
        )
        self.norm1 = nn.LayerNorm(width, eps=norm_eps)
        self.norm2 = nn.LayerNorm(width, eps=norm_eps)
        # Keep this order: checkpoint keys refer to the Linear layers by
        # their position in the Sequential (ffn.0 and ffn.3).
        feed_forward_layers = [
            nn.Linear(width, cfg.d_ff),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(cfg.d_ff, width),
            nn.Dropout(drop_p),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

    def forward(self, x, key_padding_mask=None):
        """Apply the layer to ``x``; ``key_padding_mask`` is forwarded to attention."""
        attended, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        after_attention = self.norm1(x + attended)
        after_ffn = after_attention + self.ffn(after_attention)
        return self.norm2(after_ffn)
class GenesisRNAEncoder(nn.Module):
    """Embedding layer, ``cfg.n_layers`` transformer blocks, and a final LayerNorm."""

    def __init__(self, cfg):
        super().__init__()
        self.embeddings = RNAEmbedding(cfg)
        blocks = []
        for _ in range(cfg.n_layers):
            blocks.append(TransformerBlock(cfg))
        self.layers = nn.ModuleList(blocks)
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids``.

        ``attention_mask`` is optional; positions where it equals 0 are
        passed to every block as keys to be ignored by attention.
        """
        hidden_states = self.embeddings(input_ids)
        if attention_mask is None:
            key_padding_mask = None
        else:
            key_padding_mask = attention_mask == 0
        for block in self.layers:
            hidden_states = block(hidden_states, key_padding_mask=key_padding_mask)
        return self.final_norm(hidden_states)
class MLMHead(nn.Module):
    """Per-position vocabulary logits: Linear -> GELU -> LayerNorm -> Linear."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(d_model)
        self.decoder = nn.Linear(d_model, vocab_size)
        # Same init for both linear layers: N(0, 0.02) weights, zero bias.
        for linear in (self.dense, self.decoder):
            nn.init.normal_(linear.weight, std=0.02)
            nn.init.zeros_(linear.bias)

    def forward(self, hidden_states):
        """Map hidden states to logits over the vocabulary."""
        transformed = self.layer_norm(self.activation(self.dense(hidden_states)))
        return self.decoder(transformed)
class StructureHead(nn.Module):
    """Per-position classifier: dropout (p=0.1) followed by one Linear layer."""

    def __init__(self, d_model, num_labels):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(d_model, num_labels)
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, hidden_states):
        """Return ``num_labels`` logits for every position."""
        return self.classifier(self.dropout(hidden_states))
class PairHead(nn.Module):
    """Pairwise position scores from two bias-free projections and a learned scale."""

    def __init__(self, d_model):
        super().__init__()
        self.proj_left = nn.Linear(d_model, d_model, bias=False)
        self.proj_right = nn.Linear(d_model, d_model, bias=False)
        self.scale = nn.Parameter(torch.ones(1))
        for projection in (self.proj_left, self.proj_right):
            nn.init.normal_(projection.weight, std=0.02)

    def forward(self, hidden_states):
        """Return scaled dot products between every pair of positions."""
        left = self.proj_left(hidden_states)
        right_t = self.proj_right(hidden_states).transpose(-1, -2)
        scores = torch.matmul(left, right_t)
        return scores * self.scale
class GenesisRNAModel(nn.Module):
    """Encoder plus three output heads (masked-LM, structure, pairing)."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        width = cfg.d_model
        self.encoder = GenesisRNAEncoder(cfg)
        self.mlm_head = MLMHead(width, cfg.vocab_size)
        self.struct_head = StructureHead(width, cfg.structure_num_labels)
        self.pair_head = PairHead(width)

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder once and feed its output to every head.

        Returns a dict with keys ``hidden_states``, ``mlm_logits``,
        ``struct_logits`` and ``pair_logits``.
        """
        hidden_states = self.encoder(input_ids, attention_mask)
        outputs = {"hidden_states": hidden_states}
        outputs["mlm_logits"] = self.mlm_head(hidden_states)
        outputs["struct_logits"] = self.struct_head(hidden_states)
        outputs["pair_logits"] = self.pair_head(hidden_states)
        return outputs
# ============================================================================
# ENSEMBL API FOR REAL SEQUENCES
# ============================================================================
# Base URL of the Ensembl REST service.
ENSEMBL_API = "https://rest.ensembl.org"

# Ensembl transcript used as the reference for each supported gene.
BRCA_TRANSCRIPTS = {
    "BRCA1": "ENST00000357654",
    "BRCA2": "ENST00000380152",
}


def fetch_transcript_sequence(gene: str) -> Optional[str]:
    """Fetch the coding sequence (CDS) of a BRCA reference transcript from Ensembl.

    HGVS ``c.`` positions count from the A of the start codon, so the
    sequence is requested with ``type=cds``. Without that parameter the
    endpoint defaults to the unspliced genomic sequence, on which every
    ``c.`` coordinate lands on the wrong nucleotide.

    Args:
        gene: Gene symbol; matched case-insensitively against
            ``BRCA_TRANSCRIPTS``.

    Returns:
        The full (untruncated) sequence, upper-cased with T rewritten to U,
        or ``None`` when the gene is unknown, the request fails, or the
        response carries no sequence.
    """
    if not gene:
        return None
    transcript_id = BRCA_TRANSCRIPTS.get(gene.upper())
    if not transcript_id:
        return None

    url = f"{ENSEMBL_API}/sequence/id/{transcript_id}"
    try:
        response = requests.get(
            url,
            params={"type": "cds"},
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # Best effort: the caller shows a "could not fetch" message on None,
        # but the cause is logged instead of being silently dropped.
        print(f"⚠️ Ensembl request for {transcript_id} failed: {exc}")
        return None

    raw_sequence = payload.get("seq", "") if isinstance(payload, dict) else ""
    sequence = (raw_sequence or "").upper().replace("T", "U")
    return sequence or None
def apply_simple_variant(ref_seq: str, hgvs: str) -> Optional[Dict[str, object]]:
    """Apply a simple HGVS cDNA variant to a reference sequence.

    Supported forms (case-insensitive, matched as a prefix of ``hgvs``):

    * substitution        ``c.123A>T``
    * deletion-insertion  ``c.123_125delinsAT`` / ``c.123delinsAT``
    * deletion            ``c.68_69delAG`` / ``c.68del``
    * duplication         ``c.5266dupC`` / ``c.5266dup`` / ``c.20_23dup``
    * insertion           ``c.123_124insAT``

    Positions are 1-based and are applied directly as indices into
    ``ref_seq``. Bases written after ``del`` and the reference base of a
    substitution are not checked against ``ref_seq``.

    Args:
        ref_seq: Reference sequence the coordinates refer to.
        hgvs: Variant description.

    Returns:
        ``None`` if the description is not recognised or a position is out
        of range, otherwise a dict with:
            wildtype: the unchanged reference sequence
            mutant:   the sequence with the variant applied
            type:     variant type string
            position: 0-based index of the first affected position
    """
    hgvs = hgvs.strip()
    seq_len = len(ref_seq)

    def to_rna(bases: str) -> str:
        # Inserted bases are written in the same alphabet as the model input.
        return bases.upper().replace("T", "U")

    def outcome(mutant: str, kind: str, position: int) -> Dict[str, object]:
        return {
            "wildtype": ref_seq,
            "mutant": mutant,
            "type": kind,
            "position": position,
        }

    def span(match) -> tuple:
        # 1-based inclusive "start[_end]" -> 0-based half-open [start, end)
        start = int(match.group(1)) - 1
        end = int(match.group(2)) if match.group(2) else start + 1
        return start, end

    # Substitution: c.123A>T
    sub_match = re.match(r"c\.(\d+)([ACGTU])>([ACGTU])", hgvs, re.IGNORECASE)
    if sub_match:
        pos = int(sub_match.group(1)) - 1
        if 0 <= pos < seq_len:
            alt_base = to_rna(sub_match.group(3))
            return outcome(
                ref_seq[:pos] + alt_base + ref_seq[pos + 1:], "Substitution", pos
            )

    # Deletion-insertion: c.123_125delinsAT. Checked before plain deletions,
    # which would otherwise match the "del" prefix and drop the inserted bases.
    delins_match = re.match(
        r"c\.(\d+)(?:_(\d+))?del[ACGTU]*ins([ACGTU]+)", hgvs, re.IGNORECASE
    )
    if delins_match:
        start, end = span(delins_match)
        if 0 <= start < end <= seq_len:
            inserted = to_rna(delins_match.group(3))
            return outcome(
                ref_seq[:start] + inserted + ref_seq[end:],
                "Deletion-Insertion",
                start,
            )

    # Deletion: c.68_69delAG or c.68delA (the lookahead excludes delins)
    del_match = re.match(
        r"c\.(\d+)(?:_(\d+))?del(?![ACGTU]*ins)", hgvs, re.IGNORECASE
    )
    if del_match:
        start, end = span(del_match)
        if 0 <= start < end <= seq_len:
            return outcome(ref_seq[:start] + ref_seq[end:], "Deletion", start)

    # Duplication: c.5266dupC, c.5266dup or c.20_23dup. When no bases are
    # written, the duplicated text is the reference span itself.
    dup_match = re.match(
        r"c\.(\d+)(?:_(\d+))?dup([ACGTU]*)", hgvs, re.IGNORECASE
    )
    if dup_match:
        start, end = span(dup_match)
        if 0 <= start < end <= seq_len:
            dup_bases = to_rna(dup_match.group(3)) or ref_seq[start:end]
            return outcome(
                ref_seq[:end] + dup_bases + ref_seq[end:], "Duplication", start
            )

    # Insertion: c.123_124insAT (bases go right after the first position)
    ins_match = re.match(r"c\.(\d+)_(\d+)ins([ACGTU]+)", hgvs, re.IGNORECASE)
    if ins_match:
        start = int(ins_match.group(1)) - 1
        if 0 <= start < seq_len:
            inserted = to_rna(ins_match.group(3))
            return outcome(
                ref_seq[: start + 1] + inserted + ref_seq[start + 1:],
                "Insertion",
                start,
            )

    return None
def make_window(seq: str, center: int, window_size: int = 512) -> str:
    """Return a slice of ``seq`` of length ``window_size`` around ``center``.

    ``center`` is a 0-based index and is clamped into the sequence. A
    sequence no longer than ``window_size`` is returned whole. Near the
    right edge the window is shifted left so that it stays inside ``seq``.
    """
    total = len(seq)
    if total <= window_size:
        return seq

    clamped = max(0, min(center, total - 1))
    lo = max(0, clamped - window_size // 2)
    hi = lo + window_size
    if hi > total:
        # Ran past the end: anchor the window to the right edge.
        hi = total
        lo = max(0, hi - window_size)
    return seq[lo:hi]
@dataclass
class VariantPrediction:
    """Result record produced by ``BRCAAnalyzer.predict_variant``."""

    variant_id: str  # caller-supplied label for the variant
    pathogenicity_score: float  # logistic score derived from structural_impact
    interpretation: str  # category label chosen from the score bands
    confidence: float  # grows with the score's distance from 0.5
    delta_stability: float  # half the wild-type minus mutant perplexity
    structural_impact: float  # divergence between structure-head outputs
class BRCAAnalyzer:
    """Loads a Genesis RNA checkpoint and scores wild-type/mutant sequence pairs."""

    # Parameters of the logistic that maps structural impact to a score.
    # NOTE(review): heuristic values; re-check them against labelled variants.
    PATHOGENICITY_SLOPE = 3.5
    PATHOGENICITY_CENTER = 0.8  # structural impact at which the score is 0.5

    def __init__(self, model_path: str, device: str = "cpu"):
        """Load the model from ``model_path`` onto ``device`` in eval mode.

        The checkpoint must provide ``checkpoint["config"]["model"]`` (a
        mapping of hyperparameters) and ``checkpoint["model_state_dict"]``.
        """
        self.device = device
        self.tokenizer = RNATokenizer()
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python
        # objects, so loading a tampered checkpoint can execute code. It is
        # kept because the checkpoint also stores a config object; switch to
        # weights_only=True if the file is known to hold only tensors and
        # plain containers.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config_dict = checkpoint["config"]["model"]
        self.config = GenesisRNAConfig.from_dict(config_dict)
        self.model = GenesisRNAModel(self.config)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(device)
        self.model.eval()
        print(f"✓ Model loaded: {self.config.d_model}D, {self.config.n_layers} layers")

    @staticmethod
    def js_divergence(
        p: torch.Tensor, q: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        """Jensen-Shannon divergence 0.5*KL(p||m) + 0.5*KL(q||m), m = (p+q)/2.

        ``p`` and ``q`` hold probabilities along the last axis. The value is
        summed over every other axis and divided by the size of the first
        axis (``reduction="batchmean"``).
        """
        mixture = 0.5 * (p + q)
        log_mixture = torch.log(mixture + eps)
        # F.kl_div(input, target) is KL(target || exp(input)), so the
        # mixture goes in as the log-space input and p / q as the targets.
        kl_p = F.kl_div(log_mixture, p, reduction="batchmean")
        kl_q = F.kl_div(log_mixture, q, reduction="batchmean")
        return 0.5 * (kl_p + kl_q)

    def _perplexity(self, mlm_logits: torch.Tensor, token_ids: torch.Tensor) -> float:
        """exp(mean cross-entropy) of the MLM logits against the input ids, skipping PAD."""
        loss = F.cross_entropy(
            mlm_logits.reshape(-1, self.config.vocab_size),
            token_ids.reshape(-1),
            reduction="mean",
            ignore_index=0,  # id of '[PAD]'
        )
        return torch.exp(loss).item()

    def predict_variant(
        self,
        gene: str,
        wild_type_rna: str,
        mutant_rna: str,
        variant_id: str,
    ) -> "VariantPrediction":
        """Score a variant from its wild-type and mutant sequences.

        Computes:
        - ΔStability: half the difference in language-model perplexity
          (wild type minus mutant)
        - Structural impact: JS divergence between the structure-head
          distributions of the two sequences
        - Pathogenicity: logistic mapping of the structural impact
        - Confidence: how far the score is from the decision boundary (0.5)

        Args:
            gene: Gene symbol; accepted for interface symmetry, not used in
                the computation.
            wild_type_rna: Reference sequence fed to the model.
            mutant_rna: Variant sequence fed to the model.
            variant_id: Label copied into the result.

        Returns:
            A ``VariantPrediction`` with the values above.
        """
        with torch.no_grad():
            wt_ids = self.tokenizer.encode(wild_type_rna).unsqueeze(0).to(self.device)
            mut_ids = self.tokenizer.encode(mutant_rna).unsqueeze(0).to(self.device)
            wt_output = self.model(wt_ids)
            mut_output = self.model(mut_ids)

            # Perplexity-based "stability"
            wt_perplexity = self._perplexity(wt_output["mlm_logits"], wt_ids)
            mut_perplexity = self._perplexity(mut_output["mlm_logits"], mut_ids)
            delta_stability = (wt_perplexity - mut_perplexity) * 0.5

            # Structural distributions, compared position by position
            wt_struct = F.softmax(wt_output["struct_logits"], dim=-1)
            mut_struct = F.softmax(mut_output["struct_logits"], dim=-1)
            structural_impact = float(self.js_divergence(wt_struct, mut_struct).item())

        # Pathogenicity score (heuristic logistic mapping of the divergence)
        exponent = -self.PATHOGENICITY_SLOPE * (
            structural_impact - self.PATHOGENICITY_CENTER
        )
        pathogenicity = 1.0 / (1.0 + np.exp(exponent))

        # Interpretation bands
        if pathogenicity > 0.8:
            interpretation = "Pathogenic"
        elif pathogenicity > 0.6:
            interpretation = "Likely Pathogenic"
        elif pathogenicity > 0.4:
            interpretation = "VUS"
        elif pathogenicity > 0.2:
            interpretation = "Likely Benign"
        else:
            interpretation = "Benign"

        # Confidence: 0.5 at the decision boundary, up to 0.95 at the extremes
        margin = abs(pathogenicity - 0.5)
        confidence = float(max(0.5, min(0.95, 0.5 + 0.9 * margin)))

        return VariantPrediction(
            variant_id=variant_id,
            pathogenicity_score=float(pathogenicity),
            interpretation=interpretation,
            confidence=confidence,
            delta_stability=float(delta_stability),
            structural_impact=structural_impact,
        )
# One analyzer instance, created at import time on the CPU and shared by
# every request handled by the UI callback.
analyzer = BRCAAnalyzer(model_path=model_path, device="cpu")
print("✓ Analyzer ready\n")
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
def analyze_variant(gene: str, hgvs_notation: str) -> str:
    """Run the full analysis for one variant and return the report as HTML.

    Steps: fetch the reference sequence, apply the variant, cut a 512-nt
    window around it from both sequences, score the pair with the global
    ``analyzer`` and render the result.

    Args:
        gene: Gene symbol selected in the UI.
        hgvs_notation: Variant description typed by the user.

    Returns:
        An HTML fragment: the report, or a red one-line message when input
        is missing, the sequence cannot be fetched, the variant cannot be
        parsed, or any step raises.
    """
    if not gene or not hgvs_notation:
        return "<p style='color: red;'>❌ Please provide both gene and HGVS notation</p>"

    # Both values come from the user and are echoed into HTML below.
    safe_gene = html.escape(str(gene))
    safe_hgvs = html.escape(str(hgvs_notation))

    try:
        ref_seq = fetch_transcript_sequence(gene)
        if not ref_seq:
            return f"<p style='color: red;'>❌ Could not fetch {safe_gene} sequence</p>"
        variant_result = apply_simple_variant(ref_seq, hgvs_notation)
        if not variant_result:
            return f"<p style='color: red;'>❌ Could not parse: {safe_hgvs}</p>"

        full_wt = variant_result["wildtype"]
        full_mut = variant_result["mutant"]
        var_pos = variant_result.get("position", 0)

        # 512-nt windows around the variant for the model
        wt_window = make_window(full_wt, var_pos, window_size=512)
        mut_window = make_window(full_mut, var_pos, window_size=512)
        prediction = analyzer.predict_variant(
            gene=gene,
            wild_type_rna=wt_window,
            mutant_rna=mut_window,
            variant_id=f"{gene}:{hgvs_notation}",
        )

        # Presentation per category. A VUS gets its own neutral wording: it
        # must not be reported with the "benign" text and colours.
        if prediction.interpretation in ("Pathogenic", "Likely Pathogenic"):
            bg_color, pred_color = "#ffebee", "#c62828"
            summary = (
                "This variant is predicted to be <strong>pathogenic</strong> and may "
                "disrupt normal DNA repair mechanisms, increasing breast/ovarian cancer risk."
            )
            recommendations = (
                "<li>Enhanced cancer screening recommended</li>"
                "<li>Genetic counseling for family planning</li>"
                "<li>Discuss risk-reducing strategies with healthcare provider</li>"
                "<li>Family cascade testing appropriate</li>"
            )
        elif prediction.interpretation == "VUS":
            bg_color, pred_color = "#fff8e1", "#ef6c00"
            summary = (
                "This variant is of <strong>uncertain significance</strong>: the score "
                "falls between the benign and pathogenic ranges, so no conclusion about "
                "cancer risk can be drawn from this prediction."
            )
            recommendations = (
                "<li>Do not base screening or treatment decisions on this result</li>"
                "<li>Consider genetic counseling to review personal and family history</li>"
                "<li>Re-evaluate as additional evidence becomes available</li>"
            )
        else:
            bg_color, pred_color = "#e8f5e9", "#2e7d32"
            summary = (
                "This variant is predicted to be <strong>benign</strong> and unlikely to "
                "significantly affect protein function or increase cancer risk."
            )
            recommendations = (
                "<li>Standard cancer screening guidelines</li>"
                "<li>No specific intervention required</li>"
                "<li>Routine follow-up as appropriate</li>"
            )

        transcript_len = len(full_wt)
        wt_len = len(wt_window)
        mut_len = len(mut_window)
        return f"""
<div style="padding: 20px; border-radius: 10px; background-color: {bg_color};">
<h2 style="margin-top: 0; color: {pred_color};">🧬 {safe_gene}:{safe_hgvs}</h2>
<h3 style="color: {pred_color}; font-size: 1.8em; margin: 10px 0;">
Prediction: {prediction.interpretation}
</h3>
<div style="margin: 20px 0;">
<p><strong>Variant Type:</strong> {variant_result['type']}</p>
<p><strong>Pathogenicity Score:</strong> {prediction.pathogenicity_score:.3f} / 1.000</p>
<p><strong>Confidence:</strong> {prediction.confidence:.1%}</p>
<p><strong>Sequence Source:</strong> Real Ensembl transcript ({BRCA_TRANSCRIPTS[gene.upper()]})</p>
</div>
<h3>📊 Computational Analysis</h3>
<p><strong>ΔStability (Perplexity):</strong> {prediction.delta_stability:.3f}</p>
<p><strong>Structural Impact (JS Divergence):</strong> {prediction.structural_impact:.4f}</p>
<h3>🧬 Sequence Context</h3>
<p><strong>Full transcript length:</strong> {transcript_len} nt</p>
<p><strong>Window analyzed around variant (WT / Mutant):</strong> {wt_len} nt / {mut_len} nt</p>
<p><strong>Approximate variant position (cDNA):</strong> {var_pos + 1}</p>
<h3>📋 Clinical Interpretation</h3>
<p>
{summary}
</p>
<h3>💡 Recommendations</h3>
<ul>
{recommendations}
</ul>
<h3>⚙️ Model Details</h3>
<p><strong>Architecture:</strong> Genesis RNA BASE (35M parameters)</p>
<p><strong>Training:</strong> 203,749 ncRNA + 54,943 BRCA variants</p>
<hr>
<p style="font-size: 0.9em; color: #666;">
⚠️ <strong>RESEARCH USE ONLY</strong> - Not for clinical diagnosis. Consult certified genetic counselors.
</p>
</div>
"""
    except Exception as e:
        # Exception text may contain user input, so it is escaped as well.
        return f"<p style='color: red;'>❌ Error: {html.escape(str(e))}</p>"
# Static page fragments, kept apart from the layout code below.
_PAGE_HEADER_HTML = """
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px;">
<h1 style="margin: 0; font-size: 2.5em;">🧬 Genesis RNA</h1>
<p style="font-size: 1.2em; margin: 10px 0;">BRCA Variant Classifier for Curing Breast Cancer</p>
</div>
<div style="background-color: #fff3cd; padding: 15px; border-radius: 5px; margin: 20px 0; border-left: 5px solid #ff9800;">
<strong>⚠️ RESEARCH USE ONLY</strong><br>
Not for clinical diagnosis. Consult certified genetic counselors.
</div>
<div style="padding: 15px; background-color: #f5f5f5; border-radius: 5px;">
<strong>Developer:</strong> Oluwafemi Idiakhoa | <strong>Model:</strong> BASE (35M params) | <strong>Training:</strong> 54,943 variants
</div>
"""

_EXAMPLES_MARKDOWN = """
### 💡 Examples
**BRCA1:**
- `c.5266dupC` - Founder mutation
- `c.181T>G` - Missense
- `c.68_69delAG` - Frameshift
**BRCA2:**
- `c.6275_6276delTT` - Frameshift
- `c.9097dupA` - Duplication
"""

_PAGE_FOOTER_HTML = """
<div style="margin-top: 30px; padding: 20px; background-color: #f8f9fa; border-radius: 10px;">
<h2>🎯 Mission: Cure Breast Cancer</h2>
<p><strong>GitHub:</strong> <a href="https://github.com/oluwafemidiakhoa/genesi_ai">github.com/oluwafemidiakhoa/genesi_ai</a></p>
<p><strong>HuggingFace:</strong> <a href="https://huggingface.co/mgbam/genesis-rna-base">huggingface.co/mgbam/genesis-rna-base</a></p>
</div>
<div style="text-align: center; padding: 15px; margin-top: 20px; background-color: #e3f2fd; border-radius: 10px;">
<strong>Genesis RNA</strong> | Oluwafemi Idiakhoa | Genesis AI Research | 2025
</div>
"""

# Layout: header, then one row with the inputs on the left and the report
# on the right, then the footer.
with gr.Blocks() as demo:
    gr.HTML(_PAGE_HEADER_HTML)

    with gr.Row():
        with gr.Column():
            gene_input = gr.Radio(
                choices=["BRCA1", "BRCA2"], label="🧬 Gene", value="BRCA1"
            )
            hgvs_input = gr.Textbox(
                label="📝 HGVS Notation", placeholder="c.5266dupC"
            )
            analyze_btn = gr.Button("🔬 Analyze Variant", variant="primary", size="lg")
            gr.Markdown(_EXAMPLES_MARKDOWN)
        with gr.Column(scale=2):
            output = gr.HTML()

    # The button sends both inputs to analyze_variant and shows its HTML.
    analyze_btn.click(
        fn=analyze_variant,
        inputs=[gene_input, hgvs_input],
        outputs=output,
    )

    gr.HTML(_PAGE_FOOTER_HTML)

if __name__ == "__main__":
    demo.launch(share=False, show_error=True)