import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import numpy as np
from tqdm import tqdm
import json
import os
import argparse
import time
from torch.cuda.amp import autocast, GradScaler
import wandb

from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset
from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding
from cfg_dataset import CFGFlowDataset, create_cfg_dataloader
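# Note: torch.cuda.amp.autocast / GradScaler are the CUDA-specific AMP entry points;
# recent PyTorch releases expose the same functionality under torch.amp (e.g.
# torch.amp.autocast(device_type="cuda")). The older names used here still work.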

# Embedding / compression dimensions
ESM_DIM = 1280
COMP_RATIO = 16
COMP_DIM = ESM_DIM // COMP_RATIO
MAX_SEQ_LEN = 50

# Core training hyperparameters
BATCH_SIZE = 512
EPOCHS = 2000
BASE_LR = 8e-4
LR_MIN = 4e-4
WARMUP_STEPS = 4000
GPU_ID = 0

# Optimization / logging settings
USE_MIXED_PRECISION = True
GRADIENT_CLIP_NORM = 0.5
WEIGHT_DECAY = 0.01
VALIDATION_INTERVAL = 5000   # in training steps
CHECKPOINT_INTERVAL = 300    # in epochs
NUM_WORKERS = 32

# Probability of dropping the class label during training, for classifier-free guidance (CFG)
CFG_DROPOUT_RATE = 0.15
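# With ESM_DIM = 1280 and COMP_RATIO = 16, the compressed per-position dimension is
# COMP_DIM = 1280 // 16 = 80, so the flow model learns in an 80-dimensional latent
# space per position instead of the full 1280-dimensional ESM embedding space.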

class AMPFlowTrainerSingleGPUFullData:
    """
    Optimized single-GPU training pipeline for AMP generation using the ProtFlow methodology.
    Uses ALL available data with H100-optimized settings for overnight training.
    """

    def __init__(self, embeddings_path, cfg_data_path, use_wandb=False):
        self.device = torch.device(f'cuda:{GPU_ID}')
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_wandb = use_wandb

        # Allow TF32 matmuls/convolutions for faster training on Ampere/Hopper GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        print(f"Using GPU {GPU_ID} for optimized H100 training")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Target epochs: {EPOCHS}")
        print(f"Learning rate: {BASE_LR} -> {LR_MIN}")

        if USE_MIXED_PRECISION:
            self.scaler = GradScaler()
            print("✓ Mixed precision training enabled (BF16)")
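        # Note: the autocast context in the training loop uses bfloat16, which has the
        # same exponent range as float32, so loss scaling is not strictly required;
        # GradScaler is kept here but is mainly useful when autocasting to float16.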

        if self.use_wandb:
            wandb.init(
                project="amp-flow-training",
                config={
                    "batch_size": BATCH_SIZE,
                    "epochs": EPOCHS,
                    "base_lr": BASE_LR,
                    "lr_min": LR_MIN,
                    "warmup_steps": WARMUP_STEPS,
                    "mixed_precision": USE_MIXED_PRECISION,
                    "gradient_clip": GRADIENT_CLIP_NORM,
                    "weight_decay": WEIGHT_DECAY
                }
            )

        print(f"Loading ALL AMP embeddings from {embeddings_path}...")
        self._load_all_embeddings()

        print("Computing preprocessing statistics...")
        self._compute_preprocessing_stats()

        self._initialize_models()
        self._initialize_data()
        self._initialize_optimizer()

        print("✓ Optimized Single GPU training setup complete with FULL DATA!")

    def _load_all_embeddings(self):
        """Load ALL peptide embeddings from the combined file, falling back to individual files."""
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            self.embeddings = torch.load(combined_path, map_location=self.device)
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")
            import glob

            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            embedding_files = [
                f for f in embedding_files
                if not f.endswith('metadata.json')
                and not f.endswith('sequence_ids.json')
                and not f.endswith('all_peptide_embeddings.pt')
            ]

            print(f"Found {len(embedding_files)} individual embedding files")

            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location=self.device)
                    if embedding.dim() == 2:
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")
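    # The rest of the pipeline assumes `self.embeddings` is a 3-D tensor of per-residue
    # ESM embeddings, one (seq_len, ESM_DIM) block per peptide stacked to
    # (num_peptides, seq_len, ESM_DIM); `_compute_preprocessing_stats` flattens it to
    # (-1, ESM_DIM) when computing feature-wise statistics.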

    def _compute_preprocessing_stats(self):
        """Compute normalization statistics for embeddings."""
        flat_embeddings = self.embeddings.reshape(-1, ESM_DIM)

        # Per-dimension mean/std, plus global min/max over all values
        mean = flat_embeddings.mean(dim=0)
        std = flat_embeddings.std(dim=0)
        min_val = flat_embeddings.min()
        max_val = flat_embeddings.max()

        self.stats = {
            'mean': mean,
            'std': std,
            'min': min_val,
            'max': max_val
        }

        # Persist the stats so downstream code can undo the normalization
        torch.save(self.stats, 'normalization_stats.pt')
        print("✓ Statistics computed and saved:")
        print(f"  Total embeddings: {len(self.embeddings):,}")
        print(f"  Mean: {mean.mean():.4f} ± {mean.std():.4f}")
        print(f"  Std: {std.mean():.4f} ± {std.std():.4f}")
        print(f"  Range: [{min_val:.4f}, {max_val:.4f}]")

    def _initialize_models(self):
        """Initialize compressor, decompressor, and flow model."""
        print("Initializing models...")

        # Pretrained compressor/decompressor (loaded from checkpoints; not updated during flow training)
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device))

        # Flow matching model operating in the compressed embedding space
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=COMP_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=25,
            use_cfg=True
        ).to(self.device)

        try:
            self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead")
            print("✓ Model compiled with torch.compile for speedup")
        except Exception as e:
            print(f"⚠️ Model compilation failed: {e}")

        # The compressor/decompressor are never updated (they are only called under
        # torch.no_grad()), so keep them in eval mode; only the flow model is trained.
        self.compressor.eval()
        self.decompressor.eval()
        self.flow_model.train()

        print("✓ Models initialized:")
        print(f"  Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}")
        print(f"  Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}")
        print(f"  Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")

    def _initialize_data(self):
        """Initialize datasets and dataloaders with FULL data."""
        print("Initializing datasets with FULL data...")

        self.cfg_dataset = CFGFlowDataset(
            embeddings_path=self.embeddings_path,
            cfg_data_path=self.cfg_data_path,
            use_masked_labels=True,
            max_seq_len=MAX_SEQ_LEN,
            device=self.device
        )

        self.dataloader = create_cfg_dataloader(
            self.cfg_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS
        )

        self.total_steps = len(self.dataloader) * EPOCHS
        self.validation_steps = VALIDATION_INTERVAL

        print("✓ Dataset initialized with FULL data:")
        print(f"  Total samples: {len(self.cfg_dataset):,}")
        print(f"  Batch size: {BATCH_SIZE}")
        print(f"  Batches per epoch: {len(self.dataloader):,}")
        print(f"  Total training steps: {self.total_steps:,}")
        print(f"  Validation every: {self.validation_steps:,} steps")

    def _initialize_optimizer(self):
        """Initialize optimizer and learning rate scheduler."""
        print("Initializing optimizer and scheduler...")

        self.optimizer = optim.AdamW(
            self.flow_model.parameters(),
            lr=BASE_LR,
            weight_decay=WEIGHT_DECAY,
            betas=(0.9, 0.98),
            eps=1e-6
        )

        # Linear warmup from 0.1 * BASE_LR up to BASE_LR over WARMUP_STEPS optimizer steps
        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=WARMUP_STEPS
        )

        # Cosine decay from BASE_LR down to LR_MIN over the remaining steps
        main_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.total_steps - WARMUP_STEPS,
            eta_min=LR_MIN
        )

        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[WARMUP_STEPS]
        )

        print("✓ Optimizer initialized:")
        print(f"  Base LR: {BASE_LR}")
        print(f"  Min LR: {LR_MIN}")
        print(f"  Warmup steps: {WARMUP_STEPS}")
        print(f"  Weight decay: {WEIGHT_DECAY}")
        print(f"  Gradient clip norm: {GRADIENT_CLIP_NORM}")
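    # The scheduler is stepped once per batch in train_flow_matching(), so WARMUP_STEPS
    # and T_max above are counted in optimizer steps (batches), not epochs.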

    def _preprocess_batch(self, batch):
        """Preprocess a batch of data for training."""
        embeddings = batch['embeddings'].to(self.device)
        labels = batch['labels'].to(self.device)

        m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device)
        mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device)

        # Two-step normalization: per-dimension z-scoring followed by a global min-max
        # rescaling (note that mn/mx were computed on the raw, un-standardized embeddings)
        embeddings = (embeddings - m) / (s + 1e-8)
        embeddings = (embeddings - mn) / (mx - mn + 1e-8)

        # Map the normalized ESM embeddings into the compressed latent space; the
        # compressor is frozen, so no gradients are needed here
        with torch.no_grad():
            compressed = self.compressor(embeddings)

        return compressed, labels
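    # Generation-time code (not part of this file) would have to invert the preprocessing
    # above in reverse order before decompressed embeddings can be compared with raw ESM
    # output. A minimal sketch, assuming `x` is a tensor in the normalized space and
    # `stats` is the dict saved to normalization_stats.pt:
    #
    #     x = x * (stats['max'] - stats['min'] + 1e-8) + stats['min']   # undo min-max
    #     x = x * (stats['std'] + 1e-8) + stats['mean']                 # undo z-scoring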

    def _compute_validation_metrics(self):
        """Compute validation metrics on a subset of data."""
        self.flow_model.eval()
        val_losses = []

        # Evaluate on a random subset of up to 1,000 samples
        val_samples = min(1000, len(self.cfg_dataset))
        val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples].tolist()

        with torch.no_grad():
            for i in range(0, val_samples, BATCH_SIZE):
                batch_indices = val_indices[i:i + BATCH_SIZE]
                batch_data = [self.cfg_dataset[idx] for idx in batch_indices]

                embeddings = torch.stack([item['embedding'] for item in batch_data])
                labels = torch.stack([item['label'] for item in batch_data])

                compressed, labels = self._preprocess_batch({
                    'embeddings': embeddings,
                    'labels': labels
                })

                B, L, D = compressed.shape

                # Same flow matching objective as in training: sample a time t, form the
                # noisy interpolation x_t, and regress the predicted velocity onto eps - x0
                t = torch.rand(B, device=self.device)
                eps = torch.randn_like(compressed)
                xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps

                vt_pred = self.flow_model(xt, t, labels=labels)
                vt_target = eps - compressed
                loss = F.mse_loss(vt_pred, vt_target)
                val_losses.append(loss.item())

        self.flow_model.train()
        return np.mean(val_losses)
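    # Note: these "validation" batches are drawn from the same dataset used for training
    # (there is no held-out split), so the reported validation loss tracks how well the
    # model fits the training distribution rather than generalization to unseen peptides.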

    def train_flow_matching(self):
        """Train the flow matching model with FULL data and optimizations."""
        print("🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA")
        print(f"GPU: {GPU_ID}")
        print(f"Total epochs: {EPOCHS}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Total samples: {len(self.cfg_dataset):,}")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print("Estimated time: ~8-10 hours (overnight training with ALL data)")
        print("=" * 60)

        best_loss = float('inf')
        losses = []
        val_losses = []
        global_step = 0
        start_time = time.time()

        for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"):
            epoch_losses = []
            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(self.dataloader):
                # Normalize and compress the batch of ESM embeddings
                compressed, labels = self._preprocess_batch(batch)
                B, L, D = compressed.shape

                # Classifier-free guidance: with probability CFG_DROPOUT_RATE, drop the
                # conditioning for this batch by replacing every label with -1 (the
                # unconditional label)
                if torch.rand(1).item() < CFG_DROPOUT_RATE:
                    labels = torch.full_like(labels, fill_value=-1)
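                # Note: this drops the labels for an entire batch at once. Classifier-free
                # guidance is more commonly applied per example; a per-example variant
                # (assuming `labels` has shape (B,)) would look roughly like:
                #
                #     drop = torch.rand(B, device=labels.device) < CFG_DROPOUT_RATE
                #     labels = torch.where(drop, torch.full_like(labels, -1), labels)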

                # Flow matching setup: sample a time t ~ U(0, 1) per example, sample
                # Gaussian noise, and build the linear interpolation between data (t = 0)
                # and noise (t = 1)
                t = torch.rand(B, device=self.device)
                eps = torch.randn_like(compressed)
                xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps
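                # With x_t = (1 - t) * x_0 + t * eps, the time derivative is
                # d x_t / dt = eps - x_0, which is exactly the vt_target regressed below.
                # At sampling time the learned velocity field is integrated from pure
                # noise at t = 1 back to t = 0 to produce new compressed embeddings.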

                if USE_MIXED_PRECISION:
                    with autocast(dtype=torch.bfloat16):
                        vt_pred = self.flow_model(xt, t, labels=labels)
                        vt_target = eps - compressed
                        loss = F.mse_loss(vt_pred, vt_target)

                    self.optimizer.zero_grad()
                    self.scaler.scale(loss).backward()

                    # Unscale before clipping so the threshold applies to the true gradients
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    vt_pred = self.flow_model(xt, t, labels=labels)
                    vt_target = eps - compressed
                    loss = F.mse_loss(vt_pred, vt_target)

                    self.optimizer.zero_grad()
                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
                    self.optimizer.step()

                # Step the LR schedule once per batch
                self.scheduler.step()

                epoch_losses.append(loss.item())
                global_step += 1

                # Periodic console / wandb logging
                if batch_idx % 100 == 0:
                    current_lr = self.scheduler.get_last_lr()[0]
                    elapsed_time = time.time() - start_time
                    steps_per_sec = global_step / elapsed_time
                    eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600

                    print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | "
                          f"Loss: {loss.item():.6f} | LR: {current_lr:.2e} | "
                          f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h")

                    if self.use_wandb:
                        wandb.log({
                            'train/loss': loss.item(),
                            'train/learning_rate': current_lr,
                            'train/steps_per_sec': steps_per_sec,
                            'train/global_step': global_step
                        })

                # Periodic validation and best-model checkpointing
                if global_step % self.validation_steps == 0:
                    val_loss = self._compute_validation_metrics()
                    val_losses.append(val_loss)

                    print(f"Validation at step {global_step}: Loss = {val_loss:.6f}")

                    if self.use_wandb:
                        wandb.log({
                            'val/loss': val_loss,
                            'val/global_step': global_step
                        })

                    if val_loss < best_loss:
                        best_loss = val_loss
                        self._save_checkpoint(epoch, val_loss, global_step, is_final=False, is_best=True)

            # End-of-epoch bookkeeping
            avg_loss = np.mean(epoch_losses)
            losses.append(avg_loss)
            epoch_time = time.time() - epoch_start_time

            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | "
                  f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                  f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}")

            # Periodic checkpoints get their own step-named files; only the post-training
            # save below writes the "final" checkpoint
            if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
                self._save_checkpoint(epoch, avg_loss, global_step, is_final=False)

        # Final checkpoint after the last epoch
        self._save_checkpoint(EPOCHS - 1, losses[-1], global_step, is_final=True)

        total_time = time.time() - start_time
        print("=" * 60)
        print("🎉 Optimized Training Complete with FULL DATA!")
        print(f"Best validation loss: {best_loss:.6f}")
        print(f"Total training time: {total_time/3600:.1f} hours")
        print(f"Total samples used: {len(self.cfg_dataset):,}")
        print("Final model saved as: amp_flow_model_final_optimized.pth")

        return losses, val_losses

    def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False):
        """Save model checkpoint."""
        output_dir = '/data2/edwardsun/flow_checkpoints'
        os.makedirs(output_dir, exist_ok=True)

        if is_best:
            filename = os.path.join(output_dir, 'amp_flow_model_best_optimized.pth')
        elif is_final:
            filename = os.path.join(output_dir, 'amp_flow_model_final_optimized.pth')
        else:
            filename = os.path.join(output_dir, f'amp_flow_checkpoint_optimized_step_{step:04d}.pth')

        checkpoint = {
            'step': step,
            'global_step': global_step,
            'loss': loss,
            'flow_model_state_dict': self.flow_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'stats': self.stats,
            'total_samples': len(self.cfg_dataset),
            'config': {
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'base_lr': BASE_LR,
                'lr_min': LR_MIN,
                'warmup_steps': WARMUP_STEPS,
                'mixed_precision': USE_MIXED_PRECISION,
                'gradient_clip': GRADIENT_CLIP_NORM,
                'weight_decay': WEIGHT_DECAY
            }
        }

        torch.save(checkpoint, filename)
        print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})")
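    # Note on reloading: because the flow model is wrapped by torch.compile (when
    # compilation succeeds), its state_dict keys typically carry an "_orig_mod." prefix.
    # Loading code should either compile a freshly constructed model before calling
    # load_state_dict, or strip the prefix. A minimal sketch, reusing the constructor
    # arguments from _initialize_models:
    #
    #     ckpt = torch.load('amp_flow_model_best_optimized.pth', map_location='cpu')
    #     state = {k.replace('_orig_mod.', '', 1): v
    #              for k, v in ckpt['flow_model_state_dict'].items()}
    #     model = AMPFlowMatcherCFGConcat(hidden_dim=480, compressed_dim=COMP_DIM,
    #                                     n_layers=12, n_heads=16, dim_ff=3072,
    #                                     max_seq_len=25, use_cfg=True)
    #     model.load_state_dict(state)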

def main():
    """Main training function."""
    global BATCH_SIZE, EPOCHS

    parser = argparse.ArgumentParser(description='Optimized Single GPU AMP Flow Training with FULL DATA')
    parser.add_argument('--embeddings', default='/data2/edwardsun/flow_project/peptide_embeddings/',
                        help='Path to peptide embeddings directory')
    parser.add_argument('--cfg_data', default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
                        help='Path to FULL CFG data file')
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of training epochs')

    args = parser.parse_args()

    # Allow the command line to override the module-level defaults
    if args.batch_size != BATCH_SIZE:
        BATCH_SIZE = args.batch_size
    if args.epochs != EPOCHS:
        EPOCHS = args.epochs

    print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}")

    trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb)
    losses, val_losses = trainer.train_flow_matching()

    print("Optimized training completed successfully with FULL DATA!")


if __name__ == "__main__":
    main()
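# Example invocation (the script filename below is illustrative; substitute the actual filename):
#   python train_amp_flow_single_gpu.py --use_wandb --batch_size 512 --epochs 2000
#   python train_amp_flow_single_gpu.py --embeddings /path/to/peptide_embeddings/ \
#       --cfg_data /path/to/uniprot_processed_data.json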