"""

RNA Tokenization Module



Provides tokenization for RNA sequences with support for:

- Nucleotide encoding (A, C, G, U, N)

- Special tokens ([PAD], [MASK], [CLS], [SEP])

- Structure labels (STEM, LOOP, BULGE, HAIRPIN)

- Random masking for MLM pretraining

"""

import torch
from dataclasses import dataclass
from typing import List, Optional, Tuple

# RNA vocabulary
NUC_VOCAB = ["A", "C", "G", "U", "N"]
SPECIAL_TOKENS = ["[PAD]", "[MASK]", "[CLS]", "[SEP]"]
STRUCT_LABELS = ["NONE", "STEM", "LOOP", "BULGE", "HAIRPIN"]


@dataclass
class RNATokenizerConfig:
    """Configuration for RNA tokenizer"""
    add_structure_tokens: bool = False
    mlm_probability: float = 0.15
    mask_token_probability: float = 0.8  # Prob of replacing with [MASK]
    random_token_probability: float = 0.1  # Prob of replacing with random token
    # remaining probability: keep original token


class RNATokenizer:
    """

    Tokenizer for RNA sequences.



    Converts RNA sequences to token IDs and provides masking utilities

    for masked language modeling (MLM) pretraining.



    Example:

        >>> tokenizer = RNATokenizer()

        >>> seq = "ACGUACGU"

        >>> tokens = tokenizer.encode(seq, max_len=16)

        >>> masked_tokens, labels = tokenizer.random_mask(tokens)

    """

    def __init__(self, cfg: Optional[RNATokenizerConfig] = None):
        self.cfg = cfg or RNATokenizerConfig()

        # Build vocabulary
        self.vocab = SPECIAL_TOKENS + NUC_VOCAB
        self.token_to_id = {t: i for i, t in enumerate(self.vocab)}
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}

        # Special token IDs
        self.pad_id = self.token_to_id["[PAD]"]
        self.mask_id = self.token_to_id["[MASK]"]
        self.cls_id = self.token_to_id["[CLS]"]
        self.sep_id = self.token_to_id["[SEP]"]

        # Nucleotide IDs for random replacement
        self.nucleotide_ids = [self.token_to_id[nuc] for nuc in NUC_VOCAB[:4]]  # A, C, G, U

        self.vocab_size = len(self.vocab)

    def encode(self, seq: str, max_len: int) -> torch.Tensor:
        """

        Encode an RNA sequence to token IDs.



        Args:

            seq: RNA sequence string (e.g., "ACGUACGU")

            max_len: Maximum sequence length (will pad or truncate)



        Returns:

            Token IDs tensor of shape [max_len]

        """
        tokens = [self.cls_id]

        # Convert sequence to tokens
        for ch in seq.upper():
            if ch in self.token_to_id:
                tokens.append(self.token_to_id[ch])
            else:
                # Unknown nucleotide -> N
                tokens.append(self.token_to_id["N"])

        tokens.append(self.sep_id)

        # Truncate or pad
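        # Note: if the input is longer than max_len - 2, the truncation below
        # also drops the trailing [SEP] token.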
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens += [self.pad_id] * (max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def decode(self, token_ids: torch.Tensor) -> str:
        """

        Decode token IDs back to RNA sequence.



        Args:

            token_ids: Token IDs tensor



        Returns:

            RNA sequence string

        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        tokens = []
        for idx in token_ids:
            token = self.id_to_token.get(idx, "N")
            if token not in SPECIAL_TOKENS:
                tokens.append(token)

        return "".join(tokens)

    def random_mask(
        self,
        input_ids: torch.Tensor,
        mlm_prob: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Apply random masking for MLM pretraining.



        Strategy (following BERT):

        - Select mlm_prob (default 15%) of tokens

        - Of selected tokens:

            - 80% -> replace with [MASK]

            - 10% -> replace with random token

            - 10% -> keep original



        Args:

            input_ids: Token IDs tensor [B, L] or [L]

            mlm_prob: Probability of masking each token (default: from config)



        Returns:

            Tuple of (masked_input_ids, labels)

            - masked_input_ids: Input with masked tokens

            - labels: Original tokens for masked positions, -100 elsewhere

        """
        if mlm_prob is None:
            mlm_prob = self.cfg.mlm_probability

        labels = input_ids.clone()

        # Create mask for tokens to be masked (excluding special tokens)
        probability_matrix = torch.full(input_ids.shape, mlm_prob)

        # Don't mask special tokens
        special_tokens_mask = (
            (input_ids == self.pad_id) |
            (input_ids == self.cls_id) |
            (input_ids == self.sep_id) |
            (input_ids == self.mask_id)
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample which tokens to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Set labels to -100 for non-masked tokens (ignore in loss)
        labels[~masked_indices] = -100

        # Create masked input
        input_ids_masked = input_ids.clone()

        # 80% of the time: replace with [MASK]
        mask_token_mask = (
            torch.bernoulli(torch.full(input_ids.shape, self.cfg.mask_token_probability)).bool()
            & masked_indices
        )
        input_ids_masked[mask_token_mask] = self.mask_id

        # 10% of the time: replace with random nucleotide
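        # The probability is conditioned on the token not already having been
        # replaced with [MASK], hence the division: with the defaults this is
        # 0.1 / (1 - 0.8) = 0.5 of the remaining selected tokens.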
        random_token_mask = (
            torch.bernoulli(
                torch.full(input_ids.shape, self.cfg.random_token_probability / (1 - self.cfg.mask_token_probability))
            ).bool()
            & masked_indices
            & ~mask_token_mask
        )
        random_tokens = torch.randint(
            len(self.nucleotide_ids),
            input_ids.shape,
            dtype=torch.long
        )
        random_token_ids = torch.tensor(self.nucleotide_ids)[random_tokens]
        input_ids_masked[random_token_mask] = random_token_ids[random_token_mask]

        # Remaining 10%: keep original token (already in input_ids_masked)

        return input_ids_masked, labels

    def batch_encode(self, sequences: List[str], max_len: int) -> torch.Tensor:
        """

        Encode a batch of RNA sequences.



        Args:

            sequences: List of RNA sequence strings

            max_len: Maximum sequence length



        Returns:

            Tensor of shape [batch_size, max_len]

        """
        return torch.stack([self.encode(seq, max_len) for seq in sequences])

    def __len__(self) -> int:
        """Return vocabulary size"""
        return self.vocab_size

    def __repr__(self) -> str:
        return f"RNATokenizer(vocab_size={self.vocab_size}, mlm_prob={self.cfg.mlm_probability})"