"""
Configuration Module for Genesis RNA
Defines model architecture and training hyperparameters.
"""
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenesisRNAConfig:
    """
    Configuration for the Genesis RNA model architecture.

    This config defines the Transformer-based encoder architecture for
    RNA sequence modeling with support for multi-task learning.

    Attributes:
        vocab_size: Size of the token vocabulary
        d_model: Hidden dimension size
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dim_ff: Feedforward dimension
        max_len: Maximum sequence length
        dropout: Dropout probability
        attention_dropout: Dropout probability on attention weights
        activation_dropout: Dropout probability after the feedforward activation
        structure_num_labels: Number of structure labels (e.g., STEM, LOOP, BULGE)
        use_rotary_embeddings: Whether to use rotary positional embeddings
        max_position_embeddings: Maximum number of positional embeddings
        attention_type: Type of attention mechanism ('standard', 'linear', 'flash')
        layer_norm_eps: Epsilon for layer normalization
        initializer_range: Standard deviation for weight initialization
        model_type: Identifier string for this model family
    """

    # Vocabulary
    vocab_size: int = 9  # 4 nucleotides + N + 4 special tokens

    # Architecture dimensions
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512

    # Regularization
    dropout: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0

    # Task-specific
    structure_num_labels: int = 5  # NONE, STEM, LOOP, BULGE, HAIRPIN

    # Positional encoding
    use_rotary_embeddings: bool = False
    max_position_embeddings: int = 512

    # Attention settings
    attention_type: str = "standard"  # 'standard', 'linear', 'flash'

    # Normalization
    layer_norm_eps: float = 1e-12

    # Initialization
    initializer_range: float = 0.02

    # Model type identifier
    model_type: str = "genesis_rna"

    def __post_init__(self):
        """Validate configuration."""
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.vocab_size > 0, "vocab_size must be positive"
        assert self.n_layers > 0, "n_layers must be positive"
        assert 0 <= self.dropout <= 1, "dropout must be in [0, 1]"
        assert self.attention_type in ("standard", "linear", "flash"), \
            f"unknown attention_type: {self.attention_type!r}"

    @property
    def head_dim(self) -> int:
        """Dimension per attention head."""
        return self.d_model // self.n_heads

    def to_dict(self) -> dict:
        """Convert config to a plain dictionary."""
        # Iterate over field names directly to avoid shadowing dataclasses.field
        return {name: getattr(self, name) for name in self.__dataclass_fields__}

    @classmethod
    def from_dict(cls, config_dict: dict) -> "GenesisRNAConfig":
        """Create a config from a dictionary."""
        return cls(**config_dict)

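
# Hedged usage sketch (illustrative, not part of the original API): round-trips
# a config through to_dict()/from_dict() via JSON, e.g. for checkpoint
# metadata. The function name and file path are hypothetical; every field is a
# plain int/float/str/bool, so the dict is directly JSON-serializable.
def example_config_roundtrip(path: str = "genesis_rna_config.json") -> GenesisRNAConfig:
    import json

    config = GenesisRNAConfig(d_model=256, n_heads=4)
    with open(path, "w") as f:
        json.dump(config.to_dict(), f, indent=2)
    with open(path) as f:
        restored = GenesisRNAConfig.from_dict(json.load(f))
    assert restored == config  # dataclass equality compares all fields
    return restored
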
@dataclass
class TrainingConfig:
    """
    Training configuration for Genesis RNA pretraining.

    Attributes:
        batch_size: Training batch size
        learning_rate: Peak learning rate
        num_epochs: Number of training epochs
        warmup_steps: Number of warmup steps
        max_steps: Maximum training steps (overrides num_epochs if set)
        weight_decay: Weight decay coefficient
        gradient_clip_norm: Maximum gradient norm for clipping
        lr_scheduler_type: Learning rate schedule ('linear', 'cosine', 'constant')
        min_lr_ratio: Minimum learning rate as a ratio of the peak rate
        mlm_probability: Probability of masking tokens for MLM
        mlm_loss_weight: Weight of the MLM loss in the multi-task objective
        structure_loss_weight: Weight of the structure-prediction loss
        pair_loss_weight: Weight of the base-pair-prediction loss
        save_steps: Save checkpoint every N steps
        eval_steps: Evaluate every N steps
        logging_steps: Log metrics every N steps
        output_dir: Directory for saving checkpoints and logs
        seed: Random seed for reproducibility
    """

    # Batch and optimization
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_steps: int = 10000
    max_steps: Optional[int] = None
    weight_decay: float = 0.01
    gradient_clip_norm: float = 1.0

    # Learning rate scheduling (one possible interpretation is sketched after this class)
    lr_scheduler_type: str = "cosine"  # 'linear', 'cosine', 'constant'
    min_lr_ratio: float = 0.1  # Minimum LR as a ratio of the peak LR

    # MLM settings
    mlm_probability: float = 0.15

    # Multi-task loss weights
    mlm_loss_weight: float = 1.0
    structure_loss_weight: float = 0.0  # Set to 0 - no structure annotations in ncRNA data
    pair_loss_weight: float = 0.0  # Set to 0 - no pair annotations in ncRNA data

    # Focal loss settings for pair prediction, handling class imbalance
    # (a hedged sketch of the loss follows this class)
    use_focal_loss_for_pairs: bool = True
    focal_alpha: float = 0.75  # Weight for positive pairs
    focal_gamma: float = 2.0  # Focusing parameter

    # Pair prediction threshold (tuned for imbalanced data with focal loss)
    pair_prediction_threshold: float = 0.35  # Lower than 0.5 to capture more true pairs

    # Checkpointing and logging
    save_steps: int = 5000
    eval_steps: int = 1000
    logging_steps: int = 100
    output_dir: str = "./checkpoints"

    # Device and mixed precision
    device: str = "cuda"
    mixed_precision: bool = True
    fp16: bool = True

    # Data loading
    num_workers: int = 2  # Reduced from 4 to avoid DataLoader warnings
    prefetch_factor: int = 2

    # AST settings (the kp/ki names suggest a PI controller driving
    # activation toward ast_target_activation)
    use_ast: bool = True
    ast_target_activation: float = 0.4
    ast_controller_kp: float = 0.01  # proportional gain
    ast_controller_ki: float = 0.001  # integral gain

    # Reproducibility
    seed: int = 42

    def __post_init__(self):
        """Validate training configuration."""
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.learning_rate > 0, "learning_rate must be positive"
        assert 0 < self.mlm_probability < 1, "mlm_probability must be in (0, 1)"
        assert self.lr_scheduler_type in ("linear", "cosine", "constant"), \
            f"unknown lr_scheduler_type: {self.lr_scheduler_type!r}"

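
# Hedged sketches (assumptions, not the repository's trainer code). They show
# one plausible reading of the scheduling and focal-loss fields above; the
# function names are illustrative only.

def lr_multiplier(config: TrainingConfig, step: int, total_steps: int) -> float:
    """Per-step factor on config.learning_rate: linear warmup over
    warmup_steps, then decay toward min_lr_ratio * learning_rate according
    to lr_scheduler_type."""
    import math

    if step < config.warmup_steps:
        return step / max(1, config.warmup_steps)
    if config.lr_scheduler_type == "constant":
        return 1.0
    span = max(1, total_steps - config.warmup_steps)
    progress = min(1.0, (step - config.warmup_steps) / span)
    if config.lr_scheduler_type == "cosine":
        decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    else:  # 'linear'
        decay = 1.0 - progress
    return config.min_lr_ratio + (1.0 - config.min_lr_ratio) * decay


def pair_focal_loss(p: float, is_pair: bool, config: TrainingConfig) -> float:
    """Binary focal loss (Lin et al., 2017) for a single predicted pairing
    probability p: focal_alpha up-weights the rare positive class and
    focal_gamma down-weights easy, high-confidence examples."""
    import math

    eps = 1e-9  # numerical floor so log() stays finite
    if is_pair:
        return -config.focal_alpha * (1.0 - p) ** config.focal_gamma * math.log(p + eps)
    return -(1.0 - config.focal_alpha) * p ** config.focal_gamma * math.log(1.0 - p + eps)
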
@dataclass
class GenesisRNAConfigSmall(GenesisRNAConfig):
    """Small model configuration for testing and development."""
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    dim_ff: int = 1024
    max_len: int = 512


@dataclass
class GenesisRNAConfigBase(GenesisRNAConfig):
    """Base model configuration (the default sizes)."""
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512


@dataclass
class GenesisRNAConfigLarge(GenesisRNAConfig):
    """Large model configuration for high-capacity training."""
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12
    dim_ff: int = 3072
    max_len: int = 1024

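
# Hedged smoke test (illustrative addition, not part of the original module):
# instantiates each preset and prints head_dim plus a rough parameter estimate
# from standard Transformer sizing (token + position embeddings, and per layer
# the four d_model x d_model attention projections and the two feedforward
# matrices; biases, LayerNorms, and task heads are ignored).
def _approx_param_count(config: GenesisRNAConfig) -> int:
    embeddings = (config.vocab_size + config.max_position_embeddings) * config.d_model
    per_layer = 4 * config.d_model ** 2 + 2 * config.d_model * config.dim_ff
    return embeddings + config.n_layers * per_layer


if __name__ == "__main__":
    for preset in (GenesisRNAConfigSmall, GenesisRNAConfigBase, GenesisRNAConfigLarge):
        config = preset()
        print(f"{preset.__name__}: head_dim={config.head_dim}, "
              f"~{_approx_param_count(config) / 1e6:.1f}M params (rough estimate)")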