"""
Configuration Module for Genesis RNA
Defines model architecture and training hyperparameters.
"""
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GenesisRNAConfig:
    """
    Configuration for the Genesis RNA model architecture.

    This config defines the Transformer-based encoder architecture for
    RNA sequence modeling with support for multi-task learning.

    Attributes:
        vocab_size: Size of the token vocabulary
        d_model: Hidden dimension size
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        dim_ff: Feedforward dimension
        max_len: Maximum sequence length
        dropout: Dropout probability
        structure_num_labels: Number of structure labels (e.g., STEM, LOOP, BULGE)
        use_rotary_embeddings: Whether to use rotary positional embeddings
        attention_type: Type of attention mechanism ('standard', 'linear', 'flash')
        layer_norm_eps: Epsilon for layer normalization
        initializer_range: Standard deviation for weight initialization
    """

    # Vocabulary
    vocab_size: int = 9  # 4 nucleotides + N + 4 special tokens

    # Architecture dimensions
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512

    # Regularization
    dropout: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0

    # Task-specific
    structure_num_labels: int = 5  # NONE, STEM, LOOP, BULGE, HAIRPIN

    # Positional encoding
    use_rotary_embeddings: bool = False
    max_position_embeddings: int = 512

    # Attention settings
    attention_type: str = "standard"  # 'standard', 'linear', 'flash'

    # Normalization
    layer_norm_eps: float = 1e-12

    # Initialization
    initializer_range: float = 0.02

    # Model type identifier
    model_type: str = "genesis_rna"

    def __post_init__(self):
        """Validate configuration."""
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.vocab_size > 0, "vocab_size must be positive"
        assert self.n_layers > 0, "n_layers must be positive"
        assert 0 <= self.dropout <= 1, "dropout must be in [0, 1]"

    @property
    def head_dim(self) -> int:
        """Dimension per attention head."""
        return self.d_model // self.n_heads

    def to_dict(self):
        """Convert config to a plain dictionary."""
        return {
            f.name: getattr(self, f.name)
            for f in self.__dataclass_fields__.values()
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Create config from a dictionary."""
        return cls(**config_dict)
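
# Example (illustrative): round-tripping a config through a dict, e.g. for
# JSON serialization. With the defaults above, head_dim is 512 // 8 = 64,
# and dataclass equality makes the round trip checkable:
#
#     cfg = GenesisRNAConfig()
#     assert cfg.head_dim == 64
#     restored = GenesisRNAConfig.from_dict(cfg.to_dict())
#     assert restored == cfg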


@dataclass
class TrainingConfig:
    """
    Training configuration for Genesis RNA pretraining.

    Attributes:
        batch_size: Training batch size
        learning_rate: Peak learning rate
        num_epochs: Number of training epochs
        warmup_steps: Number of warmup steps
        max_steps: Maximum training steps (overrides num_epochs if set)
        weight_decay: Weight decay coefficient
        gradient_clip_norm: Maximum gradient norm for clipping
        mlm_probability: Probability of masking tokens for MLM
        save_steps: Save a checkpoint every N steps
        eval_steps: Evaluate every N steps
        logging_steps: Log metrics every N steps
        output_dir: Directory for saving checkpoints and logs
    """

    # Batch and optimization
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_steps: int = 10000
    max_steps: Optional[int] = None
    weight_decay: float = 0.01
    gradient_clip_norm: float = 1.0

    # Learning rate scheduling
    lr_scheduler_type: str = "cosine"  # 'linear', 'cosine', 'constant'
    min_lr_ratio: float = 0.1  # Minimum LR as a ratio of the peak LR

    # MLM settings
    mlm_probability: float = 0.15

    # Multi-task loss weights
    mlm_loss_weight: float = 1.0
    structure_loss_weight: float = 0.0  # 0: no structure annotations in the ncRNA data
    pair_loss_weight: float = 0.0  # 0: no pair annotations in the ncRNA data

    # Focal loss settings for pair prediction (handles class imbalance)
    use_focal_loss_for_pairs: bool = True
    focal_alpha: float = 0.75  # Weight for positive pairs
    focal_gamma: float = 2.0  # Focusing parameter
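    # Assuming a standard focal loss (Lin et al., 2017):
    #     FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t)
    # where gamma down-weights easy examples and alpha_t up-weights the rare
    # positive (paired) class.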

    # Pair prediction threshold: below 0.5 to capture more true pairs on
    # imbalanced data when training with focal loss
    pair_prediction_threshold: float = 0.35

    # Checkpointing and logging
    save_steps: int = 5000
    eval_steps: int = 1000
    logging_steps: int = 100
    output_dir: str = "./checkpoints"

    # Device and mixed precision
    device: str = "cuda"
    mixed_precision: bool = True
    fp16: bool = True

    # Data loading
    num_workers: int = 2  # Reduced from 4 to avoid DataLoader warnings
    prefetch_factor: int = 2

    # AST settings
    use_ast: bool = True
    ast_target_activation: float = 0.4
    ast_controller_kp: float = 0.01
    ast_controller_ki: float = 0.001

    # Reproducibility
    seed: int = 42

    def __post_init__(self):
        """Validate training configuration."""
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.learning_rate > 0, "learning_rate must be positive"
        assert 0 < self.mlm_probability < 1, "mlm_probability must be in (0, 1)"
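
# Example (illustrative): overriding defaults for a short debugging run.
#
#     train_cfg = TrainingConfig(batch_size=8, max_steps=1000, output_dir="./debug")
#
# With lr_scheduler_type="cosine", the implied schedule (implemented outside
# this module) warms up for warmup_steps, then decays from learning_rate down
# to learning_rate * min_lr_ratio.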


@dataclass
class GenesisRNAConfigSmall(GenesisRNAConfig):
    """Small model configuration for testing and development."""

    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    dim_ff: int = 1024
    max_len: int = 512


@dataclass
class GenesisRNAConfigBase(GenesisRNAConfig):
    """Base model configuration (the default)."""

    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512


@dataclass
class GenesisRNAConfigLarge(GenesisRNAConfig):
    """Large model configuration for high-capacity training."""

    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12
    dim_ff: int = 3072
    max_len: int = 1024
    max_position_embeddings: int = 1024  # Kept in sync with max_len
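

if __name__ == "__main__":
    # Minimal smoke test (illustrative): instantiate each preset, which runs
    # its __post_init__ validation, and print the derived head dimension.
    for preset in (GenesisRNAConfigSmall, GenesisRNAConfigBase, GenesisRNAConfigLarge):
        cfg = preset()
        print(f"{preset.__name__}: d_model={cfg.d_model}, "
              f"n_layers={cfg.n_layers}, head_dim={cfg.head_dim}")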