"""
RNA Tokenization Module
Provides tokenization for RNA sequences with support for:
- Nucleotide encoding (A, C, G, U, N)
- Special tokens ([PAD], [MASK], [CLS], [SEP])
- Structure labels (STEM, LOOP, BULGE, HAIRPIN)
- Random masking for MLM pretraining
"""
import torch
from dataclasses import dataclass
from typing import List, Optional, Tuple
# RNA vocabulary
NUC_VOCAB = ["A", "C", "G", "U", "N"]
SPECIAL_TOKENS = ["[PAD]", "[MASK]", "[CLS]", "[SEP]"]
STRUCT_LABELS = ["NONE", "STEM", "LOOP", "BULGE", "HAIRPIN"]
@dataclass
class RNATokenizerConfig:
"""Configuration for RNA tokenizer"""
add_structure_tokens: bool = False
mlm_probability: float = 0.15
mask_token_probability: float = 0.8 # Prob of replacing with [MASK]
random_token_probability: float = 0.1 # Prob of replacing with random token
# remaining probability: keep original token
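    # With the defaults: 0.8 -> [MASK], 0.1 -> random nucleotide, and the
    # remaining 1 - 0.8 - 0.1 = 0.1 -> keep the original token.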
class RNATokenizer:
"""
Tokenizer for RNA sequences.
Converts RNA sequences to token IDs and provides masking utilities
for masked language modeling (MLM) pretraining.
Example:
>>> tokenizer = RNATokenizer()
>>> seq = "ACGUACGU"
>>> tokens = tokenizer.encode(seq, max_len=16)
>>> masked_tokens, labels = tokenizer.random_mask(tokens)
"""
    def __init__(self, cfg: Optional[RNATokenizerConfig] = None):
self.cfg = cfg or RNATokenizerConfig()
# Build vocabulary
self.vocab = SPECIAL_TOKENS + NUC_VOCAB
self.token_to_id = {t: i for i, t in enumerate(self.vocab)}
self.id_to_token = {i: t for t, i in self.token_to_id.items()}
# Special token IDs
self.pad_id = self.token_to_id["[PAD]"]
self.mask_id = self.token_to_id["[MASK]"]
self.cls_id = self.token_to_id["[CLS]"]
self.sep_id = self.token_to_id["[SEP]"]
# Nucleotide IDs for random replacement
self.nucleotide_ids = [self.token_to_id[nuc] for nuc in NUC_VOCAB[:4]] # A, C, G, U
self.vocab_size = len(self.vocab)
def encode(self, seq: str, max_len: int) -> torch.Tensor:
"""
Encode an RNA sequence to token IDs.
Args:
seq: RNA sequence string (e.g., "ACGUACGU")
max_len: Maximum sequence length (will pad or truncate)
Returns:
Token IDs tensor of shape [max_len]
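        Example (illustrative; IDs follow the vocabulary order defined above):
            >>> tokenizer = RNATokenizer()
            >>> tokenizer.encode("ACGU", max_len=8).tolist()
            [2, 4, 5, 6, 7, 3, 0, 0]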
"""
tokens = [self.cls_id]
# Convert sequence to tokens
for ch in seq.upper():
if ch in self.token_to_id:
tokens.append(self.token_to_id[ch])
else:
# Unknown nucleotide -> N
tokens.append(self.token_to_id["N"])
tokens.append(self.sep_id)
        # Truncate to max_len (note: truncation drops the trailing [SEP]) or pad with [PAD]
if len(tokens) > max_len:
tokens = tokens[:max_len]
else:
tokens += [self.pad_id] * (max_len - len(tokens))
return torch.tensor(tokens, dtype=torch.long)
def decode(self, token_ids: torch.Tensor) -> str:
"""
Decode token IDs back to RNA sequence.
Args:
token_ids: Token IDs tensor
Returns:
            RNA sequence string (special tokens are skipped)
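        Example (illustrative round trip using the encoder above):
            >>> tokenizer = RNATokenizer()
            >>> ids = tokenizer.encode("ACGU", max_len=8)
            >>> tokenizer.decode(ids)
            'ACGU'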
"""
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
tokens = []
for idx in token_ids:
token = self.id_to_token.get(idx, "N")
if token not in SPECIAL_TOKENS:
tokens.append(token)
return "".join(tokens)
def random_mask(
self,
input_ids: torch.Tensor,
        mlm_prob: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply random masking for MLM pretraining.
Strategy (following BERT):
- Select mlm_prob (default 15%) of tokens
- Of selected tokens:
- 80% -> replace with [MASK]
- 10% -> replace with random token
- 10% -> keep original
Args:
input_ids: Token IDs tensor [B, L] or [L]
mlm_prob: Probability of masking each token (default: from config)
Returns:
Tuple of (masked_input_ids, labels)
- masked_input_ids: Input with masked tokens
- labels: Original tokens for masked positions, -100 elsewhere
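        Example (masking is stochastic, so only the shapes are checked here):
            >>> tokenizer = RNATokenizer()
            >>> ids = tokenizer.encode("ACGUACGU", max_len=16)
            >>> masked, labels = tokenizer.random_mask(ids)
            >>> masked.shape == labels.shape == ids.shape
            True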
"""
if mlm_prob is None:
mlm_prob = self.cfg.mlm_probability
labels = input_ids.clone()
# Create mask for tokens to be masked (excluding special tokens)
probability_matrix = torch.full(input_ids.shape, mlm_prob)
# Don't mask special tokens
special_tokens_mask = (
(input_ids == self.pad_id) |
(input_ids == self.cls_id) |
(input_ids == self.sep_id) |
(input_ids == self.mask_id)
)
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
# Sample which tokens to mask
masked_indices = torch.bernoulli(probability_matrix).bool()
# Set labels to -100 for non-masked tokens (ignore in loss)
labels[~masked_indices] = -100
# Create masked input
input_ids_masked = input_ids.clone()
# 80% of the time: replace with [MASK]
mask_token_mask = (
torch.bernoulli(torch.full(input_ids.shape, self.cfg.mask_token_probability)).bool()
& masked_indices
)
input_ids_masked[mask_token_mask] = self.mask_id
        # 10% of the time: replace with a random nucleotide. Conditioned on not being
        # replaced with [MASK], this is 0.1 / (1 - 0.8) = 0.5 of the selected positions,
        # which is why the probability is rescaled below.
random_token_mask = (
torch.bernoulli(
torch.full(input_ids.shape, self.cfg.random_token_probability / (1 - self.cfg.mask_token_probability))
).bool()
& masked_indices
& ~mask_token_mask
)
random_tokens = torch.randint(
len(self.nucleotide_ids),
input_ids.shape,
dtype=torch.long
)
random_token_ids = torch.tensor(self.nucleotide_ids)[random_tokens]
input_ids_masked[random_token_mask] = random_token_ids[random_token_mask]
# Remaining 10%: keep original token (already in input_ids_masked)
return input_ids_masked, labels
def batch_encode(self, sequences: List[str], max_len: int) -> torch.Tensor:
"""
Encode a batch of RNA sequences.
Args:
sequences: List of RNA sequence strings
max_len: Maximum sequence length
Returns:
Tensor of shape [batch_size, max_len]
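        Example (illustrative; shows only the resulting shape):
            >>> tokenizer = RNATokenizer()
            >>> tokenizer.batch_encode(["ACGU", "GGAUCC"], max_len=12).shape
            torch.Size([2, 12])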
"""
return torch.stack([self.encode(seq, max_len) for seq in sequences])
def __len__(self) -> int:
"""Return vocabulary size"""
return self.vocab_size
def __repr__(self) -> str:
return f"RNATokenizer(vocab_size={self.vocab_size}, mlm_prob={self.cfg.mlm_probability})"
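

# Minimal usage sketch (illustrative only, not required by the module): encode a small
# batch of sequences, apply MLM masking, and decode one sequence back to a string.
if __name__ == "__main__":
    tokenizer = RNATokenizer()
    batch = tokenizer.batch_encode(["ACGUACGU", "GGAUCCUA"], max_len=16)
    masked, labels = tokenizer.random_mask(batch)
    print(tokenizer)                   # RNATokenizer(vocab_size=9, mlm_prob=0.15)
    print(batch.shape, masked.shape)   # torch.Size([2, 16]) torch.Size([2, 16])
    print(tokenizer.decode(batch[0]))  # ACGUACGU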