| """ | |
| Configuration Module for Genesis RNA | |
| Defines model architecture and training hyperparameters. | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |


@dataclass
class GenesisRNAConfig:
| """ | |
| Configuration for Genesis RNA model architecture. | |
| This config defines the Transformer-based encoder architecture for | |
| RNA sequence modeling with support for multi-task learning. | |
| Attributes: | |
| vocab_size: Size of the token vocabulary | |
| d_model: Hidden dimension size | |
| n_heads: Number of attention heads | |
| n_layers: Number of transformer layers | |
| dim_ff: Feedforward dimension | |
| max_len: Maximum sequence length | |
| dropout: Dropout probability | |
| structure_num_labels: Number of structure labels (e.g., STEM, LOOP, BULGE) | |
| use_rotary_embeddings: Whether to use rotary positional embeddings | |
| attention_type: Type of attention mechanism ('standard', 'linear', 'flash') | |
| layer_norm_eps: Epsilon for layer normalization | |
| initializer_range: Standard deviation for weight initialization | |
| """ | |
| # Vocabulary | |
| vocab_size: int = 9 # 4 nucleotides + N + 4 special tokens | |
| # Architecture dimensions | |
| d_model: int = 512 | |
| n_heads: int = 8 | |
| n_layers: int = 8 | |
| dim_ff: int = 2048 | |
| max_len: int = 512 | |
| # Regularization | |
| dropout: float = 0.1 | |
| attention_dropout: float = 0.1 | |
| activation_dropout: float = 0.0 | |
| # Task-specific | |
| structure_num_labels: int = 5 # NONE, STEM, LOOP, BULGE, HAIRPIN | |
| # Positional encoding | |
| use_rotary_embeddings: bool = False | |
| max_position_embeddings: int = 512 | |
| # Attention settings | |
| attention_type: str = "standard" # 'standard', 'linear', 'flash' | |
| # Normalization | |
| layer_norm_eps: float = 1e-12 | |
| # Initialization | |
| initializer_range: float = 0.02 | |
| # Model type identifier | |
| model_type: str = "genesis_rna" | |

    def __post_init__(self):
        """Validate configuration"""
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.vocab_size > 0, "vocab_size must be positive"
        assert self.n_layers > 0, "n_layers must be positive"
        assert 0 <= self.dropout <= 1, "dropout must be in [0, 1]"

    def head_dim(self) -> int:
        """Dimension per attention head"""
        return self.d_model // self.n_heads

    def to_dict(self):
        """Convert config to dictionary"""
        return {
            f.name: getattr(self, f.name)
            for f in self.__dataclass_fields__.values()
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Create config from dictionary"""
        return cls(**config_dict)
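

# Usage sketch (illustrative only; relies on the API defined above):
#
#   cfg = GenesisRNAConfig()
#   cfg.head_dim()                                        # 512 // 8 == 64
#   restored = GenesisRNAConfig.from_dict(cfg.to_dict())  # serialization round-trip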


@dataclass
class TrainingConfig:
| """ | |
| Training configuration for Genesis RNA pretraining. | |
| Attributes: | |
| batch_size: Training batch size | |
| learning_rate: Peak learning rate | |
| num_epochs: Number of training epochs | |
| warmup_steps: Number of warmup steps | |
| max_steps: Maximum training steps (overrides num_epochs if set) | |
| weight_decay: Weight decay coefficient | |
| gradient_clip_norm: Maximum gradient norm for clipping | |
| mlm_probability: Probability of masking tokens for MLM | |
| save_steps: Save checkpoint every N steps | |
| eval_steps: Evaluate every N steps | |
| logging_steps: Log metrics every N steps | |
| output_dir: Directory for saving checkpoints and logs | |
| """ | |
    # Batch and optimization
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_steps: int = 10000
    max_steps: Optional[int] = None
    weight_decay: float = 0.01
    gradient_clip_norm: float = 1.0

    # Learning rate scheduling
    lr_scheduler_type: str = "cosine"  # 'linear', 'cosine', 'constant'
    min_lr_ratio: float = 0.1  # Minimum LR as ratio of peak LR
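
    # How these fields combine with warmup_steps (a sketch, assuming linear
    # warmup followed by cosine decay to min_lr_ratio * learning_rate; the
    # actual scheduler is presumably built in the training code):
    #
    #   if step < warmup_steps:
    #       lr = learning_rate * step / warmup_steps
    #   else:
    #       t = (step - warmup_steps) / (total_steps - warmup_steps)
    #       floor = min_lr_ratio * learning_rate
    #       lr = floor + 0.5 * (learning_rate - floor) * (1 + cos(pi * t))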

    # MLM settings
    mlm_probability: float = 0.15
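    # Masking sketch (assuming the standard BERT 80/10/10 recipe; the actual
    # data collator may differ): ~15% of positions are selected, of which
    # 80% -> [MASK], 10% -> a random token, 10% -> left unchanged.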

    # Multi-task loss weights
    mlm_loss_weight: float = 1.0
    structure_loss_weight: float = 0.0  # Set to 0 - no structure annotations in ncRNA data
    pair_loss_weight: float = 0.0  # Set to 0 - no pair annotations in ncRNA data

    # Focal loss settings for pair prediction (handles class imbalance)
    use_focal_loss_for_pairs: bool = True
    focal_alpha: float = 0.75  # Weight for positive pairs
    focal_gamma: float = 2.0  # Focusing parameter

    # Pair prediction threshold (tuned for imbalanced data with focal loss)
    pair_prediction_threshold: float = 0.35  # Lower than 0.5 to capture more true pairs
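
    # Sketch of the binary focal loss these knobs parameterize (the standard
    # Lin et al. form; illustrative only, not necessarily the exact
    # implementation used here). For a predicted pair probability p:
    #
    #   p_t   = p if target == 1 else 1 - p
    #   alpha = focal_alpha if target == 1 else 1 - focal_alpha
    #   loss  = -alpha * (1 - p_t) ** focal_gamma * log(p_t)
    #
    # Pairs are then called wherever p > pair_prediction_threshold.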

    # Checkpointing and logging
    save_steps: int = 5000
    eval_steps: int = 1000
    logging_steps: int = 100
    output_dir: str = "./checkpoints"

    # Device and mixed precision
    device: str = "cuda"
    mixed_precision: bool = True
    fp16: bool = True

    # Data loading
    num_workers: int = 2  # Reduced from 4 to avoid DataLoader warnings
    prefetch_factor: int = 2

    # AST settings
    use_ast: bool = True
    ast_target_activation: float = 0.4
    ast_controller_kp: float = 0.01
    ast_controller_ki: float = 0.001
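
    # The kp/ki naming suggests a PI controller driving some activation
    # statistic toward ast_target_activation (a sketch under that assumption;
    # the actual controller would live in the training code):
    #
    #   error     = ast_target_activation - observed_activation
    #   integral += error
    #   control   = ast_controller_kp * error + ast_controller_ki * integral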

    # Reproducibility
    seed: int = 42

    def __post_init__(self):
        """Validate training configuration"""
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.learning_rate > 0, "learning_rate must be positive"
        assert 0 < self.mlm_probability < 1, "mlm_probability must be in (0, 1)"
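

# Usage sketch (illustrative; per the docstring above, max_steps overrides
# num_epochs when set):
#
#   train_cfg = TrainingConfig(batch_size=32, max_steps=200_000)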


@dataclass
class GenesisRNAConfigSmall(GenesisRNAConfig):
    """Small model configuration for testing and development"""
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    dim_ff: int = 1024
    max_len: int = 512


@dataclass
class GenesisRNAConfigBase(GenesisRNAConfig):
    """Base model configuration (default)"""
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512


@dataclass
class GenesisRNAConfigLarge(GenesisRNAConfig):
    """Large model configuration for high-capacity training"""
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12
    dim_ff: int = 3072
    max_len: int = 1024
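

if __name__ == "__main__":
    # Illustrative self-check when the module is run directly: instantiating
    # each preset exercises the __post_init__ validation above.
    for preset in (GenesisRNAConfigSmall, GenesisRNAConfigBase, GenesisRNAConfigLarge):
        cfg = preset()
        print(f"{preset.__name__}: d_model={cfg.d_model}, "
              f"n_heads={cfg.n_heads}, head_dim={cfg.head_dim()}")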