"""
Model Architecture for Genesis RNA
Implements a Transformer-based encoder with RNA-specific features:
- RNA-aware positional encodings
- Multi-head self-attention
- Support for secondary structure and base-pair prediction
"""
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
from .config import GenesisRNAConfig
from .heads import MLMHead, StructureHead, PairHead


class RNAEmbedding(nn.Module):
    """
    RNA sequence embeddings with positional encoding.

    Combines:
    - Token embeddings (nucleotides + special tokens)
    - Positional embeddings (learned; skipped when rotary embeddings are enabled)
    - Layer normalization and dropout
    """

    def __init__(self, cfg: GenesisRNAConfig):
        super().__init__()
        self.cfg = cfg

        # Token embeddings
        self.token_embeddings = nn.Embedding(cfg.vocab_size, cfg.d_model)

        # Positional embeddings
        if cfg.use_rotary_embeddings:
            # Rotary embeddings are expected to be applied inside attention.
            # Note: the nn.MultiheadAttention blocks below do not implement RoPE,
            # so this path adds no positional signal at the embedding stage.
            self.pos_embeddings = None
        else:
            # Standard learned positional embeddings
            self.pos_embeddings = nn.Embedding(cfg.max_position_embeddings, cfg.d_model)

        self.layer_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.dropout)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize embedding weights"""
        nn.init.normal_(self.token_embeddings.weight, std=self.cfg.initializer_range)
        if self.pos_embeddings is not None:
            nn.init.normal_(self.pos_embeddings.weight, std=self.cfg.initializer_range)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Token IDs [batch_size, seq_len]

        Returns:
            Embeddings [batch_size, seq_len, d_model]
        """
        batch_size, seq_len = input_ids.size()

        # Token embeddings
        token_embeds = self.token_embeddings(input_ids)

        # Add positional embeddings
        if self.pos_embeddings is not None:
            position_ids = torch.arange(
                seq_len,
                dtype=torch.long,
                device=input_ids.device
            ).unsqueeze(0).expand(batch_size, -1)
            pos_embeds = self.pos_embeddings(position_ids)
            embeddings = token_embeds + pos_embeds
        else:
            embeddings = token_embeds

        # Normalize and dropout
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
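
# Usage sketch for RNAEmbedding (illustrative, not part of the module API):
# it assumes GenesisRNAConfig() defaults define vocab_size, d_model, and
# max_position_embeddings, as in the GenesisRNAModel docstring example below.
#
#     cfg = GenesisRNAConfig()
#     emb = RNAEmbedding(cfg)
#     input_ids = torch.randint(0, cfg.vocab_size, (2, 64))  # [batch, seq_len]
#     embeddings = emb(input_ids)                            # [2, 64, cfg.d_model]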


class TransformerBlock(nn.Module):
    """
    Single Transformer encoder block.

    Architecture:
    1. Multi-head self-attention
    2. Add & Norm
    3. Feedforward network
    4. Add & Norm
    """

    def __init__(self, cfg: GenesisRNAConfig):
        super().__init__()
        self.cfg = cfg

        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            dropout=cfg.attention_dropout,
            batch_first=True,
        )

        # Layer norms
        self.norm1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.norm2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)

        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(cfg.d_model, cfg.dim_ff),
            nn.GELU(),
            nn.Dropout(cfg.activation_dropout),
            nn.Linear(cfg.dim_ff, cfg.d_model),
        )

        # Dropout
        self.dropout1 = nn.Dropout(cfg.dropout)
        self.dropout2 = nn.Dropout(cfg.dropout)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize layer weights"""
        for module in self.ffn.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=self.cfg.initializer_range)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor [batch_size, seq_len, d_model]
            attention_mask: Attention mask [seq_len, seq_len] (optional)
            key_padding_mask: Padding mask [batch_size, seq_len] (optional)
                True for positions to ignore

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Multi-head attention with residual
        attn_output, _ = self.attention(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        hidden_states = hidden_states + self.dropout1(attn_output)
        hidden_states = self.norm1(hidden_states)

        # Feedforward with residual
        ffn_output = self.ffn(hidden_states)
        hidden_states = hidden_states + self.dropout2(ffn_output)
        hidden_states = self.norm2(hidden_states)

        return hidden_states
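
# Usage sketch for a single TransformerBlock (illustrative): the padding mask marks
# positions to IGNORE with True, matching nn.MultiheadAttention's key_padding_mask.
#
#     block = TransformerBlock(cfg)                  # cfg as above
#     x = torch.randn(2, 64, cfg.d_model)
#     pad = torch.zeros(2, 64, dtype=torch.bool)     # True = padded position
#     pad[:, 50:] = True                             # e.g. last 14 tokens are padding
#     y = block(x, key_padding_mask=pad)             # [2, 64, cfg.d_model]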


class GenesisRNAEncoder(nn.Module):
    """
    RNA Transformer Encoder.

    Multi-layer Transformer encoder for processing RNA sequences.
    """

    def __init__(self, cfg: GenesisRNAConfig):
        super().__init__()
        self.cfg = cfg

        # Embeddings
        self.embeddings = RNAEmbedding(cfg)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(cfg) for _ in range(cfg.n_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through encoder.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Binary mask [batch_size, seq_len]
                1 for real tokens, 0 for padding

        Returns:
            Hidden states [batch_size, seq_len, d_model]
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Prepare padding mask for attention
        # PyTorch MultiheadAttention uses True for positions to IGNORE
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = (attention_mask == 0)

        # Pass through transformer layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                key_padding_mask=key_padding_mask
            )

        # Final normalization
        hidden_states = self.final_norm(hidden_states)

        return hidden_states
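
# Usage sketch for GenesisRNAEncoder (illustrative): the encoder takes a binary
# attention_mask (1 = real token, 0 = padding) and converts it internally into the
# boolean key_padding_mask expected by nn.MultiheadAttention.
#
#     encoder = GenesisRNAEncoder(cfg)
#     input_ids = torch.randint(0, cfg.vocab_size, (2, 64))
#     attention_mask = torch.ones(2, 64, dtype=torch.long)
#     attention_mask[:, 50:] = 0                     # mark trailing padding
#     hidden = encoder(input_ids, attention_mask)    # [2, 64, cfg.d_model]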


class GenesisRNAModel(nn.Module):
    """
    Complete Genesis RNA model with multi-task heads.

    Combines:
    1. GenesisRNAEncoder (Transformer backbone)
    2. MLMHead (masked language modeling)
    3. StructureHead (secondary structure prediction)
    4. PairHead (base-pair prediction)

    Example:
        >>> config = GenesisRNAConfig()
        >>> model = GenesisRNAModel(config)
        >>> input_ids = torch.randint(0, 9, (4, 128))
        >>> outputs = model(input_ids)
        >>> mlm_logits = outputs["mlm_logits"]  # [4, 128, vocab_size]
    """

    def __init__(self, cfg: GenesisRNAConfig):
        super().__init__()
        self.cfg = cfg

        # Encoder
        self.encoder = GenesisRNAEncoder(cfg)

        # Task-specific heads
        self.mlm_head = MLMHead(cfg.d_model, cfg.vocab_size)
        self.struct_head = StructureHead(cfg.d_model, cfg.structure_num_labels)
        self.pair_head = PairHead(cfg.d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_hidden_states: bool = False,
    ) -> dict:
        """
        Forward pass with multi-task outputs.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Binary attention mask [batch_size, seq_len]
            return_hidden_states: Whether to return encoder hidden states

        Returns:
            Dictionary with keys:
            - mlm_logits: MLM predictions [batch_size, seq_len, vocab_size]
            - struct_logits: Structure predictions [batch_size, seq_len, num_struct_labels]
            - pair_logits: Pair predictions [batch_size, seq_len, seq_len]
            - hidden_states: Encoder outputs [batch_size, seq_len, d_model] (if requested)
        """
        # Encode
        hidden_states = self.encoder(input_ids, attention_mask)

        # Compute task-specific predictions
        mlm_logits = self.mlm_head(hidden_states)
        struct_logits = self.struct_head(hidden_states)
        pair_logits = self.pair_head(hidden_states)

        outputs = {
            "mlm_logits": mlm_logits,
            "struct_logits": struct_logits,
            "pair_logits": pair_logits,
        }

        if return_hidden_states:
            outputs["hidden_states"] = hidden_states

        return outputs

    def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Get sequence embeddings (for downstream tasks).

        Args:
            input_ids: Token IDs [batch_size, seq_len]

        Returns:
            Hidden states [batch_size, seq_len, d_model]
        """
        return self.encoder(input_ids)

    def save_pretrained(self, save_path: str):
        """Save model weights and config"""
        # Save config as dict for better serialization
        config_dict = self.cfg.to_dict() if hasattr(self.cfg, 'to_dict') else self.cfg
        torch.save({
            'model_state_dict': self.state_dict(),
            'config': config_dict,
        }, save_path)

    @classmethod
    def from_pretrained(cls, load_path: str, device: str = 'cpu'):
        """Load model from checkpoint"""
        checkpoint = torch.load(load_path, map_location=device)

        # Handle config - convert dict to GenesisRNAConfig if needed
        config = checkpoint['config']
        if isinstance(config, dict):
            # If config has 'model' key, extract it (nested structure from training)
            if 'model' in config:
                config = config['model']
            config = GenesisRNAConfig.from_dict(config)

        model = cls(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model.to(device)
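

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the training pipeline).
    # It assumes GenesisRNAConfig() provides usable defaults, as in the docstring
    # example above, and reuses the same toy input shape (batch of 4, length 128).
    # Because this module uses relative imports, run it as a module
    # (python -m <package>.<this_module>) rather than as a standalone script.
    config = GenesisRNAConfig()
    model = GenesisRNAModel(config)

    input_ids = torch.randint(0, 9, (4, 128))
    attention_mask = torch.ones(4, 128, dtype=torch.long)

    outputs = model(input_ids, attention_mask, return_hidden_states=True)
    for name, tensor in outputs.items():
        print(f"{name}: {tuple(tensor.shape)}")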