"""

Configuration Module for Genesis RNA



Defines model architecture and training hyperparameters.

"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class GenesisRNAConfig:
    """

    Configuration for Genesis RNA model architecture.



    This config defines the Transformer-based encoder architecture for

    RNA sequence modeling with support for multi-task learning.



    Attributes:

        vocab_size: Size of the token vocabulary

        d_model: Hidden dimension size

        n_heads: Number of attention heads

        n_layers: Number of transformer layers

        dim_ff: Feedforward dimension

        max_len: Maximum sequence length

        dropout: Dropout probability

        structure_num_labels: Number of structure labels (e.g., STEM, LOOP, BULGE)

        use_rotary_embeddings: Whether to use rotary positional embeddings

        attention_type: Type of attention mechanism ('standard', 'linear', 'flash')

        layer_norm_eps: Epsilon for layer normalization

        initializer_range: Standard deviation for weight initialization

    """

    # Vocabulary
    vocab_size: int = 9  # 4 nucleotides + N + 4 special tokens

    # Architecture dimensions
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512

    # Regularization
    dropout: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0

    # Task-specific
    structure_num_labels: int = 5  # NONE, STEM, LOOP, BULGE, HAIRPIN

    # Positional encoding
    use_rotary_embeddings: bool = False
    max_position_embeddings: int = 512

    # Attention settings
    attention_type: str = "standard"  # 'standard', 'linear', 'flash'

    # Normalization
    layer_norm_eps: float = 1e-12

    # Initialization
    initializer_range: float = 0.02

    # Model type identifier
    model_type: str = "genesis_rna"

    def __post_init__(self):
        """Validate configuration"""
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.vocab_size > 0, "vocab_size must be positive"
        assert self.n_layers > 0, "n_layers must be positive"
        assert 0 <= self.dropout <= 1, "dropout must be in [0, 1]"

    @property
    def head_dim(self) -> int:
        """Dimension per attention head"""
        return self.d_model // self.n_heads

    def to_dict(self):
        """Convert config to dictionary"""
        return {
            f.name: getattr(self, f.name)
            for f in self.__dataclass_fields__.values()
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Create config from dictionary"""
        return cls(**config_dict)

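# Illustrative usage sketch (not referenced elsewhere in this module): shows how
# the config above is instantiated, how the derived head_dim property behaves,
# and the to_dict()/from_dict() round trip. The helper name
# _example_config_round_trip is purely hypothetical.
def _example_config_round_trip() -> GenesisRNAConfig:
    config = GenesisRNAConfig(d_model=256, n_heads=4)  # passes the divisibility check
    assert config.head_dim == 64  # 256 / 4
    restored = GenesisRNAConfig.from_dict(config.to_dict())
    assert restored == config  # dataclass equality compares every field
    return restored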

@dataclass
class TrainingConfig:
    """

    Training configuration for Genesis RNA pretraining.



    Attributes:

        batch_size: Training batch size

        learning_rate: Peak learning rate

        num_epochs: Number of training epochs

        warmup_steps: Number of warmup steps

        max_steps: Maximum training steps (overrides num_epochs if set)

        weight_decay: Weight decay coefficient

        gradient_clip_norm: Maximum gradient norm for clipping

        mlm_probability: Probability of masking tokens for MLM

        save_steps: Save checkpoint every N steps

        eval_steps: Evaluate every N steps

        logging_steps: Log metrics every N steps

        output_dir: Directory for saving checkpoints and logs

    """

    # Batch and optimization
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_steps: int = 10000
    max_steps: Optional[int] = None
    weight_decay: float = 0.01
    gradient_clip_norm: float = 1.0

    # Learning rate scheduling
    lr_scheduler_type: str = "cosine"  # 'linear', 'cosine', 'constant'
    min_lr_ratio: float = 0.1  # Minimum LR as ratio of peak LR

    # MLM settings
    mlm_probability: float = 0.15

    # Multi-task loss weights
    mlm_loss_weight: float = 1.0
    structure_loss_weight: float = 0.0  # Set to 0 - no structure annotations in ncRNA data
    pair_loss_weight: float = 0.0  # Set to 0 - no pair annotations in ncRNA data

    # Focal loss settings for pair prediction (handles class imbalance)
    use_focal_loss_for_pairs: bool = True
    focal_alpha: float = 0.75  # Weight for positive pairs
    focal_gamma: float = 2.0  # Focusing parameter

    # Pair prediction threshold (optimal for imbalanced data with focal loss)
    pair_prediction_threshold: float = 0.35  # Lower than 0.5 to capture more true pairs

    # Checkpointing and logging
    save_steps: int = 5000
    eval_steps: int = 1000
    logging_steps: int = 100
    output_dir: str = "./checkpoints"

    # Device and mixed precision
    device: str = "cuda"
    mixed_precision: bool = True
    fp16: bool = True

    # Data loading
    num_workers: int = 2  # Reduced from 4 to avoid DataLoader warnings
    prefetch_factor: int = 2

    # AST settings
    use_ast: bool = True
    ast_target_activation: float = 0.4
    ast_controller_kp: float = 0.01
    ast_controller_ki: float = 0.001

    # Reproducibility
    seed: int = 42

    def __post_init__(self):
        """Validate training configuration"""
        assert self.batch_size > 0, "batch_size must be positive"
        assert self.learning_rate > 0, "learning_rate must be positive"
        assert 0 < self.mlm_probability < 1, "mlm_probability must be in (0, 1)"

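# Illustrative sketch of the schedule the fields above describe, under one
# plausible reading: linear warmup to the peak learning_rate over warmup_steps,
# then decay to min_lr_ratio * learning_rate. The project's actual scheduler is
# not shown in this module, so treat this as an assumption, not the reference
# implementation.
def _example_lr_at_step(cfg: TrainingConfig, step: int, total_steps: int) -> float:
    import math

    peak = cfg.learning_rate
    floor = cfg.min_lr_ratio * peak
    if step < cfg.warmup_steps:
        return peak * step / max(1, cfg.warmup_steps)
    if cfg.lr_scheduler_type == "constant":
        return peak
    progress = min(1.0, (step - cfg.warmup_steps) / max(1, total_steps - cfg.warmup_steps))
    if cfg.lr_scheduler_type == "linear":
        return floor + (peak - floor) * (1.0 - progress)
    # "cosine" (default): smooth decay from the peak down to the floor
    return floor + (peak - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))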

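# Illustrative sketch of binary focal loss (Lin et al., 2017) for a single
# base-pair probability, using focal_alpha, focal_gamma and
# pair_prediction_threshold as defined above. The actual pair-prediction loss
# used in training is not shown in this module; this is an assumed reading of
# those fields.
def _example_pair_focal_loss(cfg: TrainingConfig, p_pair: float, is_pair: bool) -> float:
    import math

    eps = 1e-7  # clamp to avoid log(0)
    p_pair = min(max(p_pair, eps), 1.0 - eps)
    p_true = p_pair if is_pair else 1.0 - p_pair  # probability of the true class
    alpha = cfg.focal_alpha if is_pair else 1.0 - cfg.focal_alpha
    return -alpha * (1.0 - p_true) ** cfg.focal_gamma * math.log(p_true)


def _example_pair_decision(cfg: TrainingConfig, p_pair: float) -> bool:
    # A position pair is called "paired" when its probability clears the
    # deliberately low threshold chosen for imbalanced data.
    return p_pair >= cfg.pair_prediction_threshold
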
@dataclass
class GenesisRNAConfigSmall(GenesisRNAConfig):
    """Small model configuration for testing and development"""
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    dim_ff: int = 1024
    max_len: int = 512


@dataclass
class GenesisRNAConfigBase(GenesisRNAConfig):
    """Base model configuration (default)"""
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    dim_ff: int = 2048
    max_len: int = 512


@dataclass
class GenesisRNAConfigLarge(GenesisRNAConfig):
    """Large model configuration for high-capacity training"""
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12
    dim_ff: int = 3072
    max_len: int = 1024
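
# Purely illustrative smoke test of the preset configurations; running this
# module directly prints the main dimensions of each variant. It is not part
# of the training pipeline.
if __name__ == "__main__":
    for preset in (GenesisRNAConfigSmall(), GenesisRNAConfigBase(), GenesisRNAConfigLarge()):
        print(
            f"{type(preset).__name__}: d_model={preset.d_model}, "
            f"n_heads={preset.n_heads}, n_layers={preset.n_layers}, "
            f"head_dim={preset.head_dim}, dim_ff={preset.dim_ff}, max_len={preset.max_len}"
        )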