"""
RNA Tokenization Module

Provides tokenization for RNA sequences with support for:
- Nucleotide encoding (A, C, G, U, N)
- Special tokens ([PAD], [MASK], [CLS], [SEP])
- Structure labels (STEM, LOOP, BULGE, HAIRPIN)
- Random masking for MLM pretraining
"""
import torch
from dataclasses import dataclass
from typing import List, Tuple

# RNA vocabulary
NUC_VOCAB = ["A", "C", "G", "U", "N"]  # four nucleotides plus "N" (ambiguous/unknown base)
SPECIAL_TOKENS = ["[PAD]", "[MASK]", "[CLS]", "[SEP]"]  # BERT-style control tokens
# Secondary-structure label set. NOTE(review): not referenced anywhere in this
# chunk (RNATokenizerConfig.add_structure_tokens is also unused here) —
# presumably consumed by a structure-aware model elsewhere; verify against callers.
STRUCT_LABELS = ["NONE", "STEM", "LOOP", "BULGE", "HAIRPIN"]
@dataclass
class RNATokenizerConfig:
    """Configuration for the RNA tokenizer's vocabulary and MLM masking.

    FIX: `dataclass` is imported at module top but was never applied, so
    per-instance overrides (e.g. ``RNATokenizerConfig(mlm_probability=0.2)``)
    raised ``TypeError``. Applying ``@dataclass`` generates the keyword
    constructor while keeping default access (``cfg.mlm_probability``)
    byte-compatible for existing callers.

    Attributes:
        add_structure_tokens: Reserved flag for structure-aware tokenization.
            NOTE(review): not consumed anywhere in this module — confirm usage.
        mlm_probability: Fraction of (non-special) tokens selected for masking.
        mask_token_probability: Of the selected tokens, probability of
            replacing with [MASK].
        random_token_probability: Of the selected tokens, probability of
            replacing with a random nucleotide.
    """
    add_structure_tokens: bool = False
    mlm_probability: float = 0.15
    mask_token_probability: float = 0.8  # Prob of replacing with [MASK]
    random_token_probability: float = 0.1  # Prob of replacing with random token
    # remaining probability (1 - 0.8 - 0.1 = 0.1): keep original token
class RNATokenizer:
    """
    Tokenizer for RNA sequences.

    Converts RNA sequences to token IDs and provides masking utilities
    for masked language modeling (MLM) pretraining.

    Example:
        >>> tokenizer = RNATokenizer()
        >>> seq = "ACGUACGU"
        >>> tokens = tokenizer.encode(seq, max_len=16)
        >>> masked_tokens, labels = tokenizer.random_mask(tokens)
    """

    def __init__(self, cfg: RNATokenizerConfig = None):
        """
        Args:
            cfg: Tokenizer configuration; a default RNATokenizerConfig is
                used when None.
        """
        self.cfg = cfg or RNATokenizerConfig()
        # Special tokens come first so their IDs are small and stable
        # (pad_id == 0, which matches common embedding-padding conventions).
        self.vocab = SPECIAL_TOKENS + NUC_VOCAB
        self.token_to_id = {t: i for i, t in enumerate(self.vocab)}
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}
        # Cached special-token IDs.
        self.pad_id = self.token_to_id["[PAD]"]
        self.mask_id = self.token_to_id["[MASK]"]
        self.cls_id = self.token_to_id["[CLS]"]
        self.sep_id = self.token_to_id["[SEP]"]
        # Concrete nucleotide IDs (A, C, G, U) used for the "random token"
        # branch of MLM masking; "N" (unknown) is deliberately excluded.
        self.nucleotide_ids = [self.token_to_id[nuc] for nuc in NUC_VOCAB[:4]]
        self.vocab_size = len(self.vocab)

    def encode(self, seq: str, max_len: int) -> torch.Tensor:
        """
        Encode an RNA sequence to token IDs as ``[CLS] seq [SEP] [PAD]...``.

        Args:
            seq: RNA sequence string (e.g., "ACGUACGU"); case-insensitive.
            max_len: Fixed output length (pads or truncates to this size).

        Returns:
            Token IDs tensor of shape [max_len], dtype long.
        """
        tokens = [self.cls_id]
        for ch in seq.upper():
            # Unknown characters map to the ambiguous nucleotide "N".
            tokens.append(self.token_to_id.get(ch, self.token_to_id["N"]))
        tokens.append(self.sep_id)
        if len(tokens) > max_len:
            # FIX: keep the [SEP] terminator when truncating; previously the
            # plain `tokens[:max_len]` silently dropped it for sequences
            # longer than max_len - 2.
            tokens = tokens[: max_len - 1] + [self.sep_id]
        else:
            tokens += [self.pad_id] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def decode(self, token_ids: torch.Tensor) -> str:
        """
        Decode token IDs back to an RNA sequence, dropping special tokens.

        Args:
            token_ids: Token IDs tensor (or list of ints).

        Returns:
            RNA sequence string.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        # Out-of-vocabulary IDs decode defensively to "N".
        decoded = (self.id_to_token.get(idx, "N") for idx in token_ids)
        return "".join(tok for tok in decoded if tok not in SPECIAL_TOKENS)

    def random_mask(
        self,
        input_ids: torch.Tensor,
        mlm_prob: float = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply random masking for MLM pretraining.

        Strategy (following BERT):
        - Select mlm_prob (default 15%) of non-special tokens
        - Of selected tokens:
            - 80% -> replace with [MASK]
            - 10% -> replace with random nucleotide
            - 10% -> keep original

        FIX: all auxiliary tensors are now allocated on ``input_ids.device``;
        the previous CPU-only ``torch.full``/``torch.randint``/``torch.tensor``
        calls crashed as soon as input_ids lived on a GPU. Also guards the
        ZeroDivisionError that occurred when mask_token_probability == 1.0.

        Args:
            input_ids: Token IDs tensor [B, L] or [L].
            mlm_prob: Probability of masking each token (default: from config).

        Returns:
            Tuple of (masked_input_ids, labels)
            - masked_input_ids: Input with masked tokens
            - labels: Original tokens for masked positions, -100 elsewhere
        """
        if mlm_prob is None:
            mlm_prob = self.cfg.mlm_probability
        device = input_ids.device
        labels = input_ids.clone()

        # Per-position masking probability; zeroed at special tokens so
        # [PAD]/[CLS]/[SEP]/[MASK] are never selected.
        probability_matrix = torch.full(input_ids.shape, mlm_prob, device=device)
        special_tokens_mask = (
            (input_ids == self.pad_id)
            | (input_ids == self.cls_id)
            | (input_ids == self.sep_id)
            | (input_ids == self.mask_id)
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample which tokens to mask; non-masked labels become -100 so the
        # loss function ignores them.
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        input_ids_masked = input_ids.clone()

        # 80% of selected tokens: replace with [MASK].
        mask_token_mask = (
            torch.bernoulli(
                torch.full(input_ids.shape, self.cfg.mask_token_probability, device=device)
            ).bool()
            & masked_indices
        )
        input_ids_masked[mask_token_mask] = self.mask_id

        # 10% of selected tokens: replace with a random nucleotide. The
        # conditional probability is random_prob / (1 - mask_prob) because we
        # only sample among tokens NOT already replaced with [MASK].
        keep_frac = 1.0 - self.cfg.mask_token_probability
        if keep_frac > 0.0:
            rand_prob = min(1.0, self.cfg.random_token_probability / keep_frac)
            random_token_mask = (
                torch.bernoulli(
                    torch.full(input_ids.shape, rand_prob, device=device)
                ).bool()
                & masked_indices
                & ~mask_token_mask
            )
            random_choices = torch.randint(
                len(self.nucleotide_ids),
                input_ids.shape,
                dtype=torch.long,
                device=device,
            )
            random_token_ids = torch.tensor(
                self.nucleotide_ids, device=device
            )[random_choices]
            input_ids_masked[random_token_mask] = random_token_ids[random_token_mask]
        # Remaining 10%: keep original token (already in input_ids_masked).

        return input_ids_masked, labels

    def batch_encode(self, sequences: List[str], max_len: int) -> torch.Tensor:
        """
        Encode a batch of RNA sequences.

        Args:
            sequences: List of RNA sequence strings.
            max_len: Maximum sequence length.

        Returns:
            Tensor of shape [batch_size, max_len].
        """
        return torch.stack([self.encode(seq, max_len) for seq in sequences])

    def __len__(self) -> int:
        """Return vocabulary size."""
        return self.vocab_size

    def __repr__(self) -> str:
        return f"RNATokenizer(vocab_size={self.vocab_size}, mlm_prob={self.cfg.mlm_probability})"