"""
RNA Tokenization Module
Provides tokenization for RNA sequences with support for:
- Nucleotide encoding (A, C, G, U, N)
- Special tokens ([PAD], [MASK], [CLS], [SEP])
- Structure labels (STEM, LOOP, BULGE, HAIRPIN)
- Random masking for MLM pretraining
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

# RNA vocabulary
NUC_VOCAB = ["A", "C", "G", "U", "N"]
SPECIAL_TOKENS = ["[PAD]", "[MASK]", "[CLS]", "[SEP]"]
STRUCT_LABELS = ["NONE", "STEM", "LOOP", "BULGE", "HAIRPIN"]
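# With this ordering the token IDs are: [PAD]=0, [MASK]=1, [CLS]=2, [SEP]=3,
# A=4, C=5, G=6, U=7, N=8 (vocab_size = 9). STRUCT_LABELS are not part of
# the token vocabulary built by RNATokenizer.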


@dataclass
class RNATokenizerConfig:
    """Configuration for RNA tokenizer."""
    add_structure_tokens: bool = False
    mlm_probability: float = 0.15
    mask_token_probability: float = 0.8    # prob. of replacing a selected token with [MASK]
    random_token_probability: float = 0.1  # prob. of replacing a selected token with a random token
    # remaining probability: keep the original token
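
# With the defaults above, 15% of the non-special tokens are selected per
# call; of those, 80% become [MASK], 10% become a random nucleotide, and the
# remaining 10% are left unchanged (the standard BERT masking recipe).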


class RNATokenizer:
    """
    Tokenizer for RNA sequences.

    Converts RNA sequences to token IDs and provides masking utilities
    for masked language modeling (MLM) pretraining.

    Example:
        >>> tokenizer = RNATokenizer()
        >>> seq = "ACGUACGU"
        >>> tokens = tokenizer.encode(seq, max_len=16)
        >>> masked_tokens, labels = tokenizer.random_mask(tokens)
    """

    def __init__(self, cfg: Optional[RNATokenizerConfig] = None):
        self.cfg = cfg or RNATokenizerConfig()
        # Build vocabulary: special tokens first, so [PAD] gets ID 0
        self.vocab = SPECIAL_TOKENS + NUC_VOCAB
        self.token_to_id = {t: i for i, t in enumerate(self.vocab)}
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}
        # Special token IDs
        self.pad_id = self.token_to_id["[PAD]"]
        self.mask_id = self.token_to_id["[MASK]"]
        self.cls_id = self.token_to_id["[CLS]"]
        self.sep_id = self.token_to_id["[SEP]"]
        # Nucleotide IDs for random replacement (A, C, G, U; N is excluded)
        self.nucleotide_ids = [self.token_to_id[nuc] for nuc in NUC_VOCAB[:4]]
        self.vocab_size = len(self.vocab)

    def encode(self, seq: str, max_len: int) -> torch.Tensor:
        """
        Encode an RNA sequence to token IDs.

        Args:
            seq: RNA sequence string (e.g., "ACGUACGU")
            max_len: Maximum sequence length (the result is padded or truncated to this length)

        Returns:
            Token IDs tensor of shape [max_len]
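
        Example (IDs under the default vocabulary):
            >>> RNATokenizer().encode("ACGU", max_len=8)
            tensor([2, 4, 5, 6, 7, 3, 0, 0])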
"""
        tokens = [self.cls_id]
        # Convert sequence characters to token IDs
        for ch in seq.upper():
            if ch in self.token_to_id:
                tokens.append(self.token_to_id[ch])
            else:
                # Unknown nucleotide -> N
                tokens.append(self.token_to_id["N"])
        tokens.append(self.sep_id)
        # Truncate (keeping the trailing [SEP]) or pad
        if len(tokens) > max_len:
            tokens = tokens[:max_len - 1] + [self.sep_id]
        else:
            tokens += [self.pad_id] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def decode(self, token_ids: torch.Tensor) -> str:
        """
        Decode token IDs back to an RNA sequence.

        Args:
            token_ids: Token IDs tensor

        Returns:
            RNA sequence string (special tokens are dropped)
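
        Example:
            >>> tok = RNATokenizer()
            >>> tok.decode(tok.encode("ACGU", max_len=8))
            'ACGU'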
"""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        tokens = []
        for idx in token_ids:
            token = self.id_to_token.get(idx, "N")
            if token not in SPECIAL_TOKENS:
                tokens.append(token)
        return "".join(tokens)

    def random_mask(
        self,
        input_ids: torch.Tensor,
        mlm_prob: Optional[float] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply random masking for MLM pretraining.

        Strategy (following BERT):
        - Select mlm_prob (default 15%) of tokens
        - Of the selected tokens:
            - 80% -> replace with [MASK]
            - 10% -> replace with a random nucleotide
            - 10% -> keep the original

        Args:
            input_ids: Token IDs tensor of shape [B, L] or [L]
            mlm_prob: Probability of selecting each token (default: from config)

        Returns:
            Tuple of (masked_input_ids, labels):
            - masked_input_ids: Input with masked tokens
            - labels: Original tokens at masked positions, -100 elsewhere
        """
        if mlm_prob is None:
            mlm_prob = self.cfg.mlm_probability
        labels = input_ids.clone()
        # Per-token selection probability (zeroed out for special tokens below)
        probability_matrix = torch.full(input_ids.shape, mlm_prob)
        # Never mask special tokens
        special_tokens_mask = (
            (input_ids == self.pad_id)
            | (input_ids == self.cls_id)
            | (input_ids == self.sep_id)
            | (input_ids == self.mask_id)
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        # Sample which tokens to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # Set labels to -100 for non-masked tokens (ignored by the loss)
        labels[~masked_indices] = -100
        # Create the masked input
        input_ids_masked = input_ids.clone()
        # 80% of selected tokens: replace with [MASK]
        mask_token_mask = (
            torch.bernoulli(torch.full(input_ids.shape, self.cfg.mask_token_probability)).bool()
            & masked_indices
        )
        input_ids_masked[mask_token_mask] = self.mask_id
        # 10% of selected tokens: replace with a random nucleotide
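        # Note: the random replacement is sampled only among selected tokens
        # that did NOT become [MASK], so the Bernoulli rate is rescaled to
        # random_token_probability / (1 - mask_token_probability); with the
        # defaults that is 0.1 / 0.2 = 0.5, i.e. half of the remaining 20%,
        # which recovers the intended 10% of all selected tokens.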
        random_token_mask = (
            torch.bernoulli(
                torch.full(
                    input_ids.shape,
                    self.cfg.random_token_probability / (1 - self.cfg.mask_token_probability),
                )
            ).bool()
            & masked_indices
            & ~mask_token_mask
        )
        random_tokens = torch.randint(len(self.nucleotide_ids), input_ids.shape, dtype=torch.long)
        random_token_ids = torch.tensor(self.nucleotide_ids)[random_tokens]
        input_ids_masked[random_token_mask] = random_token_ids[random_token_mask]
        # Remaining 10% of selected tokens: keep the original (already present in input_ids_masked)
        return input_ids_masked, labels

    def batch_encode(self, sequences: List[str], max_len: int) -> torch.Tensor:
        """
        Encode a batch of RNA sequences.

        Args:
            sequences: List of RNA sequence strings
            max_len: Maximum sequence length

        Returns:
            Tensor of shape [batch_size, max_len]
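
        Example:
            >>> RNATokenizer().batch_encode(["ACGU", "GG"], max_len=8).shape
            torch.Size([2, 8])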
"""
return torch.stack([self.encode(seq, max_len) for seq in sequences])

    def __len__(self) -> int:
        """Return the vocabulary size."""
        return self.vocab_size

    def __repr__(self) -> str:
        return f"RNATokenizer(vocab_size={self.vocab_size}, mlm_prob={self.cfg.mlm_probability})"
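

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative demo, not part of the module API):
    # encode a small batch, apply MLM masking, and round-trip one sequence.
    tokenizer = RNATokenizer()
    batch = tokenizer.batch_encode(["ACGUACGU", "GGAUCC"], max_len=16)
    print(batch.shape)                 # torch.Size([2, 16])
    masked, labels = tokenizer.random_mask(batch)
    print(masked.shape, labels.shape)  # both torch.Size([2, 16])
    print(tokenizer.decode(batch[0]))  # ACGUACGU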