import logging
import os
import random
from typing import Any

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence  # used by the commented-out collate_fn below
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from datasets import load_dataset
from huggingface_hub import PyTorchModelHubMixin, login
from dotenv import load_dotenv

timber = logging.getLogger()
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...
black = "\u001b[30m"
red = "\u001b[31m"
green = "\u001b[32m"
yellow = "\u001b[33m"
blue = "\u001b[34m"
magenta = "\u001b[35m"
cyan = "\u001b[36m"
white = "\u001b[37m"

FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"

def login_inside_huggingface_virtualmachine():
    # Load the .env file if one exists (useful on a laptop; a Hugging Face Space
    # injects secrets directly into the environment instead)
    if load_dotenv():
        print(".env file loaded successfully.")
    else:
        print("No .env file found; relying on environment variables.")

    # Try to get the token from environment variables
    try:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")
        # Log in to Hugging Face Hub
        login(token)
        print("Logged in to Hugging Face Hub successfully.")
    except Exception as e:
        print(f"Error during Hugging Face login: {e}")
        # Handle the error appropriately (e.g., exit or retry)

def one_hot_e(dna_seq: str) -> np.ndarray:
    mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
              'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
              'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
              'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
              'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
              'n': np.asarray([0.0, 0.0, 0.0, 0.0]), 'h': np.asarray([0.0, 0.0, 0.0, 0.0]),
              '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
    forward_list: list = [mydict[base] for base in dna_seq]
    encoded = np.asarray(forward_list)
    encoded_transposed = encoded.transpose()  # (seq_len, 4) -> (4, seq_len); todo: needs review
    return encoded_transposed
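# Illustration (hypothetical values): each column one-hot encodes one base, and
# unknown characters ('N', 'H', '-') map to all-zero columns.
#   >>> one_hot_e("ACGT").shape
#   (4, 4)
#   >>> one_hot_e("ACGT")[:, 0]  # first column encodes 'A'
#   array([1., 0., 0., 0.])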
def one_hot_e_column(column: pd.Series) -> np.ndarray:
    tmp_list: list = [one_hot_e(seq) for seq in column]
    encoded_column = np.asarray(tmp_list).astype(np.float32)
    return encoded_column

def reverse_dna_seq(dna_seq: str) -> str:
    return dna_seq[::-1]


def complement_dna_seq(dna_seq: str) -> str:
    comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
                "a": "t", "c": "g", "t": "a", "g": "c",
                "N": "N", "H": "H", "-": "-",
                "n": "n", "h": "h"
                }
    comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
    comp_dna_seq: str = "".join(comp_dna_seq_list)
    return comp_dna_seq


def reverse_complement_dna_seq(dna_seq: str) -> str:
    return reverse_dna_seq(complement_dna_seq(dna_seq))


def reverse_complement_column(column: pd.Series) -> list:
    rc_column: list = [reverse_complement_dna_seq(seq) for seq in column]
    return rc_column
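# Illustration: complement_dna_seq("ATCG") == "TAGC", so the reverse complement is
#   >>> reverse_complement_dna_seq("ATCG")
#   'CGAT'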
class TorchMetrics:
    def __init__(self, device=DEVICE):
        self.binary_accuracy = BinaryAccuracy().to(device)
        self.binary_auc = BinaryAUROC().to(device)
        self.binary_f1_score = BinaryF1Score().to(device)
        self.binary_precision = BinaryPrecision().to(device)
        self.binary_recall = BinaryRecall().to(device)
        pass

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
        self.binary_accuracy.update(preds=batch_predicted_labels, target=batch_actual_labels)
        self.binary_auc.update(preds=batch_predicted_labels, target=batch_actual_labels)
        self.binary_f1_score.update(preds=batch_predicted_labels, target=batch_actual_labels)
        self.binary_precision.update(preds=batch_predicted_labels, target=batch_actual_labels)
        self.binary_recall.update(preds=batch_predicted_labels, target=batch_actual_labels)
        pass

    def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
        b_accuracy = self.binary_accuracy.compute()
        b_auc = self.binary_auc.compute()
        b_f1_score = self.binary_f1_score.compute()
        b_precision = self.binary_precision.compute()
        b_recall = self.binary_recall.compute()
        timber.info(
            log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
        log(f"{log_prefix}_accuracy", b_accuracy)
        log(f"{log_prefix}_auc", b_auc)
        log(f"{log_prefix}_f1_score", b_f1_score)
        log(f"{log_prefix}_precision", b_precision)
        log(f"{log_prefix}_recall", b_recall)

        self.binary_accuracy.reset()
        self.binary_auc.reset()
        self.binary_f1_score.reset()
        self.binary_precision.reset()
        self.binary_recall.reset()
        pass
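# Typical usage inside a LightningModule (sketch; mirrors the calls further below):
#   self.train_metrics.update_on_each_step(preds.squeeze(), y)             # every training_step
#   self.train_metrics.compute_and_reset_on_epoch_end(self.log, "train")   # in on_train_epoch_end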
def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
    start = 0
    end = len(seq)
    rand_pos = random.randrange(start, (end - len(DEBUG_MOTIF)))
    random_end = rand_pos + len(DEBUG_MOTIF)
    output = seq[start: rand_pos] + DEBUG_MOTIF + seq[random_end: end]
    assert len(seq) == len(output)
    return output
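# Illustration: with seq = "A" * 20 and DEBUG_MOTIF = "ATCGCCTA", the output keeps the
# original length, e.g. "AAAAATCGCCTAAAAAAAAA" (motif spliced in at a random offset).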
class PagingMQTLDataset(IterableDataset):  # torch.utils.data.IterableDataset, so DataLoader can stream it
    def __init__(self,
                 m_dataset,
                 seq_len,
                 tokenizer,
                 max_length=512,
                 check_if_pipeline_is_ok_by_inserting_debug_motif=False):
        self.dataset = m_dataset
        self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
        self.debug_motif = "ATCGCCTA"
        self.seq_len = seq_len
        self.bert_tokenizer = tokenizer
        self.max_length = max_length
        pass

    def __iter__(self):
        for row in self.dataset:
            processed = self.preprocess(row)
            if processed is not None:
                yield processed

    def preprocess(self, row):
        sequence = row['sequence']  # fetch the 'sequence' column
        if len(sequence) != self.seq_len:
            return None  # skip problematic row!
        label = row['label']  # fetch the 'label' column (or whatever target you use)
        if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
            sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
        # tokenize the raw sequence string; the tokenizer expects text, not a one-hot tensor
        encoded_sequence_tokenized: BatchEncoding = self.bert_tokenizer(sequence, return_tensors="pt")
        input_ids = encoded_sequence_tokenized["input_ids"].squeeze(0)  # (seq_len,)
        return input_ids, label
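# Sketch of how this dataset is consumed (dataset_map / tokenizer as built in start_bert below):
#   train_ds = PagingMQTLDataset(dataset_map["train_binned_200"], seq_len=200, tokenizer=tokenizer)
#   loader = DataLoader(train_ds, batch_size=16)  # yields (input_ids, label) batches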
# def collate_fn(batch):
#     sequences, labels = zip(*batch)
#     ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
#     # pad sequences to the maximum length in this batch
#     padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
#     padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
#     # convert labels to a tensor
#     labels = torch.stack(labels)
#     return [padded_sequences, padded_sequences_rc], labels

class MqtlDataModule(LightningDataModule):
    def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
        super().__init__()
        self.batch_size = batch_size
        # shuffle must stay False: an IterableDataset stream cannot be shuffled by the DataLoader
        self.train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False,
                                       # collate_fn=collate_fn,
                                       num_workers=1,
                                       # persistent_workers=True
                                       )
        self.validate_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False,
                                          # collate_fn=collate_fn,
                                          num_workers=1,
                                          # persistent_workers=True
                                          )
        self.test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False,
                                      # collate_fn=collate_fn,
                                      num_workers=1,
                                      # persistent_workers=True
                                      )
        pass

    def prepare_data(self):
        pass

    def setup(self, stage: str) -> None:
        timber.info(f"inside setup: {stage = }")
        pass

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return self.train_loader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return self.validate_loader

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self.test_loader
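# Example wiring (sketch): the loaders built in __init__ are returned verbatim by the
# *_dataloader hooks, so a Trainer can simply do
#   dm = MqtlDataModule(train_ds, val_ds, test_ds, batch_size=16)
#   trainer.fit(model=classifier_module, datamodule=dm)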
class MQtlBertClassifierLightningModule(LightningModule):
    def __init__(self,
                 classifier: nn.Module,
                 criterion=None,  # e.g. nn.BCEWithLogitsLoss() / ReshapedBCEWithLogitsLoss()
                 regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both L1 and L2, else ignore / don't care
                 l1_lambda=0.001,
                 l2_weight_decay=0.001,
                 *args: Any,
                 **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.classifier = classifier
        self.criterion = criterion
        self.train_metrics = TorchMetrics()
        self.validate_metrics = TorchMetrics()
        self.test_metrics = TorchMetrics()

        self.regularization = regularization
        self.l1_lambda = l1_lambda
        self.l2_weight_decay = l2_weight_decay
        pass

    def forward(self, x, *args: Any, **kwargs: Any) -> Any:
        # accept either a BatchEncoding-style dict or a raw input_ids tensor
        input_ids: torch.Tensor = x["input_ids"] if isinstance(x, dict) else x
        return self.classifier.forward(input_ids)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # L2 regularization is applied through the optimizer's weight decay
        weight_decay = 0.0
        if self.regularization == 2 or self.regularization == 3:
            weight_decay = self.l2_weight_decay
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay)
    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # loss and metrics on training batch data
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)
        if self.regularization == 1 or self.regularization == 3:  # apply L1 regularization
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            loss += self.l1_lambda * l1_norm
        self.log("train_loss", loss)
        # update the running scores for this epoch
        self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
        return loss

    def on_train_epoch_end(self) -> None:
        self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
        pass

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # loss and metrics on validation batch data
        # print(f"debug { batch = }")
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)
        self.log("valid_loss", loss)
        self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
        return loss

    def on_validation_epoch_end(self) -> None:
        self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
        return None

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # loss and metrics on test batch data
        x, y = batch
        preds = self.forward(x)
        loss = self.criterion(preds, y)
        self.log("test_loss", loss)  # do we need this?
        self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
        return loss

    def on_test_epoch_end(self) -> None:
        self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
        return None

    pass
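# Example (sketch): any backbone that maps input_ids -> logits can be wrapped, e.g.
#   module = MQtlBertClassifierLightningModule(classifier=my_backbone,  # hypothetical model
#                                              criterion=ReshapedBCEWithLogitsLoss(),
#                                              regularization=2)  # L2 via optimizer weight decay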
def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
               is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
    file_suffix = ""
    if is_binned:
        file_suffix = "_binned"

    data_files = {
        # small samples
        "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
        "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
        "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
        # medium samples
        "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
        "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
        "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
        # large samples
        "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
        "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
        "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
    }

    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv")
    if is_my_laptop:
        dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
    else:
        dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)

    # load the character-level tokenizer that ships with the pretrained checkpoint
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
                                              trust_remote_code=True)

    train_dataset = PagingMQTLDataset(dataset_map[f"train_binned_{WINDOW}"],
                                      check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                                      tokenizer=tokenizer,
                                      seq_len=WINDOW
                                      )
    val_dataset = PagingMQTLDataset(dataset_map[f"validate_binned_{WINDOW}"],
                                    check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                                    tokenizer=tokenizer,
                                    seq_len=WINDOW)
    test_dataset = PagingMQTLDataset(dataset_map[f"test_binned_{WINDOW}"],
                                     check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                                     tokenizer=tokenizer,
                                     seq_len=WINDOW)

    data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset,
                                 batch_size=batch_size)

    try:
        # resume from a previously saved checkpoint, if one was pushed earlier
        classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
    except Exception as x:
        print(x)

    classifier_module = MQtlBertClassifierLightningModule(
        classifier=classifier_model,
        regularization=2, criterion=criterion)
    # if os.path.exists(model_save_path):
    #     classifier_module.load_state_dict(torch.load(model_save_path))

    trainer = Trainer(max_epochs=max_epochs, precision="32")
    trainer.fit(model=classifier_module, datamodule=data_module)
    timber.info("\n\n")
    trainer.test(model=classifier_module, datamodule=data_module)
    timber.info("\n\n")
    # torch.save(classifier_module.state_dict(), model_save_path)  # deprecated, use classifier_model.save_pretrained(model_subdirectory) instead

    # save locally
    model_subdirectory = classifier_model.model_repository_name
    classifier_model.save_pretrained(model_subdirectory)

    # push to the hub
    commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
    if is_my_laptop:
        commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"
    classifier_model.push_to_hub(
        repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
        # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/
        commit_message=commit_message
    )

    # reload
    # classifier_model = classifier_model.from_pretrained(f"fahimfarhan/{classifier_model.model_repository_name}")
    # classifier_model = classifier_model.from_pretrained(model_subdirectory)
    pass
class CommonAttentionLayer(nn.Module):
    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attention_linear = nn.Linear(hidden_size, 1)
        pass

    def forward(self, hidden_states):
        # apply linear layer
        attn_weights = self.attention_linear(hidden_states)
        # apply softmax to get attention scores
        attn_weights = torch.softmax(attn_weights, dim=1)
        # apply attention weights to hidden states
        context_vector = torch.sum(attn_weights * hidden_states, dim=1)
        return context_vector, attn_weights
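# Shape sketch: hidden_states (batch, seq_len, hidden) -> context_vector (batch, hidden)
# and attn_weights (batch, seq_len, 1); the softmax makes each example's weights sum
# to 1 along the sequence axis.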
class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    def forward(self, input, target):
        # squeeze (batch, 1) logits to (batch,) and cast integer labels to float
        return super().forward(input.squeeze(), target.float())
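# Illustration: logits of shape (batch, 1) and integer labels of shape (batch,) work directly:
#   loss = ReshapedBCEWithLogitsLoss()(torch.randn(8, 1), torch.randint(0, 2, (8,)))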
class HyenaDnaMQTLClassifier(nn.Module, PyTorchModelHubMixin):
    # PyTorchModelHubMixin provides the save_pretrained / push_to_hub / from_pretrained
    # methods that start_bert relies on
    def __init__(self,
                 seq_len: int, model_repository_name: str,
                 bert_model=None,
                 hidden_size=768,
                 num_classes=1,
                 *args,
                 **kwargs
                 ):
        super().__init__(*args, **kwargs)
        self.seq_len = seq_len
        self.model_repository_name = model_repository_name
        self.model_name = "HyenaDnaMQTLClassifier"

        # load the backbone lazily so importing this file does not download weights;
        # the HyenaDNA checkpoint needs trust_remote_code=True
        if bert_model is None:
            bert_model = AutoModel.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
        self.bert_model = bert_model
        # prefer the hidden size reported by the checkpoint config over the default
        hidden_size = getattr(self.bert_model.config, "hidden_size",
                              getattr(self.bert_model.config, "d_model", hidden_size))
        self.attention = CommonAttentionLayer(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)
        pass

    def forward(self, input_ids: torch.Tensor):
        # input_ids: (batch, seq_len) token ids produced by the tokenizer
        bert_output = self.bert_model(input_ids=input_ids)
        last_hidden_state = bert_output.last_hidden_state  # (batch, seq_len, hidden)
        context_vector, ignore_attention_weight = self.attention(last_hidden_state)
        y = self.classifier(context_vector)
        return y
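# Shape sketch: input_ids (batch, seq_len) -> last_hidden_state (batch, seq_len, hidden)
# -> context_vector (batch, hidden) -> logits (batch, num_classes).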
if __name__ == '__main__':
    login_inside_huggingface_virtualmachine()
    WINDOW = 1000
    some_model = HyenaDnaMQTLClassifier(seq_len=WINDOW,
                                        model_repository_name="hyenadna-sm-32k-mqtl-classifier")
    criterion = ReshapedBCEWithLogitsLoss()  # criterion must not be None: training_step calls it
    start_bert(
        classifier_model=some_model,
        criterion=criterion,
        WINDOW=WINDOW,
        is_debug=False,
        max_epochs=20,
        batch_size=16
    )
    pass