# FlowFinal / src/amp_flow_training_single_gpu_full_data.py
# Author: esunAI
# Commit: "Add amp_flow_training_single_gpu_full_data.py" (321da93, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import numpy as np
from tqdm import tqdm
import json
import os
import argparse
import time
from torch.cuda.amp import autocast, GradScaler
import wandb # For logging (optional)
# Import your existing components
from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset
from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding
from cfg_dataset import CFGFlowDataset, create_cfg_dataloader
# ---------------------------------------------------------------------------
# Configuration (tuned for a single H100)
# ---------------------------------------------------------------------------

# Embedding geometry
ESM_DIM = 1_280                      # ESM-2 hidden size (esm2_t33_650M_UR50D)
COMP_RATIO = 16                      # compression factor
COMP_DIM = ESM_DIM // COMP_RATIO     # compressed latent width (1280 // 16 = 80)
MAX_SEQ_LEN = 50                     # sequence length used by final_sequence_encoder.py

# Optimization schedule
BATCH_SIZE = 512                     # large batch; noted as using ~70 GB on an H100
EPOCHS = 2_000                       # number of passes over the dataloader
BASE_LR = 8e-4                       # peak learning rate reached after warmup
LR_MIN = 4e-4                        # floor of the cosine-annealing phase
WARMUP_STEPS = 4_000                 # linear warmup length, in optimizer steps
GPU_ID = 0                           # CUDA device index

# Training optimizations
USE_MIXED_PRECISION = True           # bfloat16 autocast
GRADIENT_CLIP_NORM = 0.5             # max global grad norm (tight, for flow-matching stability)
WEIGHT_DECAY = 0.01                  # AdamW weight decay
VALIDATION_INTERVAL = 5_000          # validate every N optimizer steps
CHECKPOINT_INTERVAL = 300            # save a checkpoint every N epochs
NUM_WORKERS = 32                     # passed to create_cfg_dataloader as num_workers

# Classifier-free guidance
CFG_DROPOUT_RATE = 0.15              # fraction of batches trained unconditionally
class AMPFlowTrainerSingleGPUFullData:
    """
    Optimized Single GPU training pipeline for AMP generation using ProtFlow methodology.
    Uses ALL available data with H100-optimized settings for overnight training.

    Construction performs all setup eagerly: it loads embeddings, computes and saves
    normalization statistics, and builds the models, data pipeline, optimizer and
    scheduler. Call ``train_flow_matching()`` to run the training loop.

    Module-level configuration (GPU_ID, BATCH_SIZE, EPOCHS, BASE_LR, ...) is read
    when the methods run. ``main()`` can therefore override BATCH_SIZE/EPOCHS before
    constructing the trainer.
    """

    def __init__(self, embeddings_path, cfg_data_path, use_wandb=False):
        """
        Args:
            embeddings_path: Directory holding ``all_peptide_embeddings.pt`` or
                individual ``*.pt`` embedding files.
            cfg_data_path: Path to the CFG data file consumed by ``CFGFlowDataset``.
            use_wandb: If True, start a wandb run and log metrics to it.
        """
        self.device = torch.device(f'cuda:{GPU_ID}')
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_wandb = use_wandb

        # Enable TF32 matmuls / cuDNN convolutions (H100 optimization)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        print(f"Using GPU {GPU_ID} for optimized H100 training")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Target epochs: {EPOCHS}")
        print(f"Learning rate: {BASE_LR} -> {LR_MIN}")

        # Initialize mixed precision training.
        # NOTE: autocast runs in bfloat16, whose exponent range matches fp32, so loss
        # scaling is not strictly required. The GradScaler is kept as a harmless safeguard.
        if USE_MIXED_PRECISION:
            self.scaler = GradScaler()
            print("✓ Mixed precision training enabled (BF16)")

        # Initialize wandb if requested
        if self.use_wandb:
            wandb.init(
                project="amp-flow-training",
                config={
                    "batch_size": BATCH_SIZE,
                    "epochs": EPOCHS,
                    "base_lr": BASE_LR,
                    "lr_min": LR_MIN,
                    "warmup_steps": WARMUP_STEPS,
                    "mixed_precision": USE_MIXED_PRECISION,
                    "gradient_clip": GRADIENT_CLIP_NORM,
                    "weight_decay": WEIGHT_DECAY
                }
            )

        print(f"Loading ALL AMP embeddings from {embeddings_path}...")
        self._load_all_embeddings()

        print("Computing preprocessing statistics...")
        self._compute_preprocessing_stats()

        self._initialize_models()
        self._initialize_data()
        self._initialize_optimizer()

        print("✓ Optimized Single GPU training setup complete with FULL DATA!")

    @staticmethod
    def _unwrap_model(model):
        """Return the original module behind a ``torch.compile`` wrapper.

        ``torch.compile`` returns a wrapper that exposes the wrapped module as
        ``_orig_mod``. The wrapper's ``state_dict()`` keys carry an ``_orig_mod.``
        prefix. Saving the unwrapped module's state keeps checkpoints loadable into
        a plain, uncompiled model. Modules that are not wrapped are returned unchanged.
        """
        return getattr(model, '_orig_mod', model)

    @staticmethod
    def _make_flow_targets(x0, t, eps):
        """Build the linear-path interpolant and its velocity target.

        Computes ``x_t = (1 - t) * x0 + t * eps``, with ``t`` broadcast over every
        non-batch dimension, together with the constant velocity ``eps - x0``.

        Args:
            x0: Clean data tensor, batch dimension first.
            t: Per-sample times, shape ``(B,)``.
            eps: Noise tensor with the same shape as ``x0``.

        Returns:
            Tuple ``(x_t, v_target)``, both shaped like ``x0``.
        """
        t_b = t.view(-1, *([1] * (x0.dim() - 1)))
        xt = (1 - t_b) * x0 + t_b * eps
        return xt, eps - x0

    def _load_all_embeddings(self):
        """Load ALL peptide embeddings into ``self.embeddings``.

        Prefers the combined ``all_peptide_embeddings.pt`` file, loaded directly onto
        ``self.device``. If that file is missing, it loads every other ``*.pt`` file in
        the directory onto the CPU, keeps only 2-D tensors, and stacks them.

        Raises:
            ValueError: If the fallback finds no usable 2-D embedding tensor.
        """
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            self.embeddings = torch.load(combined_path, map_location=self.device)
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")
        import glob
        # Sorted so the load order does not depend on filesystem enumeration order.
        # The glob already restricts to *.pt, so only the combined file needs excluding.
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )
        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            # Best effort: unreadable files are reported and skipped, not fatal.
            try:
                embedding = torch.load(file_path)
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")
                continue
            if embedding.dim() == 2:  # (seq_len, hidden_dim)
                embeddings_list.append(embedding)
            else:
                print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # torch.stack requires every tensor to share one (seq_len, hidden_dim) shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _compute_preprocessing_stats(self):
        """Compute per-feature mean/std and global min/max; save them to disk.

        The statistics are stored in ``self.stats`` and written to
        ``normalization_stats.pt`` in the current working directory.
        """
        # Treat every (sequence, position) row as one ESM_DIM-dimensional sample.
        flat_embeddings = self.embeddings.reshape(-1, ESM_DIM)

        mean = flat_embeddings.mean(dim=0)
        std = flat_embeddings.std(dim=0)
        # NOTE(review): min/max are taken over the RAW embeddings, but
        # _preprocess_batch applies them AFTER z-scoring. This presumably mirrors how
        # the pre-trained compressor was fit; confirm before changing either side.
        min_val = flat_embeddings.min()
        max_val = flat_embeddings.max()

        self.stats = {
            'mean': mean,
            'std': std,
            'min': min_val,
            'max': max_val
        }

        torch.save(self.stats, 'normalization_stats.pt')

        print(f"✓ Statistics computed and saved:")
        print(f" Total embeddings: {len(self.embeddings):,}")
        print(f" Mean: {mean.mean():.4f} ± {mean.std():.4f}")
        print(f" Std: {std.mean():.4f} ± {std.std():.4f}")
        print(f" Range: [{min_val:.4f}, {max_val:.4f}]")

    def _initialize_models(self):
        """Load the frozen compressor/decompressor and build the (compiled) flow model."""
        print("Initializing models...")

        # Load pre-trained compressor and decompressor
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)
        self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device))

        # Initialize flow model with CFG
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=COMP_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=MAX_SEQ_LEN // 2,  # sequence length is halved by pooling
            use_cfg=True
        ).to(self.device)

        # Compile model for PyTorch 2.x speedup (if available). torch.compile is lazy,
        # so some compilation errors may only surface on the first forward pass.
        try:
            self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead")
            print("✓ Model compiled with torch.compile for speedup")
        except Exception as e:
            print(f"⚠️ Model compilation failed: {e}")

        # The compressor/decompressor are pre-trained and frozen (they are not in the
        # optimizer). Eval mode disables any dropout and stops BatchNorm running-stat
        # updates, so the latents are deterministic and the frozen weights stay intact.
        for frozen in (self.compressor, self.decompressor):
            frozen.eval()
            frozen.requires_grad_(False)
        self.flow_model.train()

        print(f"✓ Models initialized:")
        print(f" Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}")
        print(f" Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}")
        print(f" Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")

    def _initialize_data(self):
        """Build the CFG dataset/dataloader and derive the total step count."""
        print("Initializing datasets with FULL data...")

        self.cfg_dataset = CFGFlowDataset(
            embeddings_path=self.embeddings_path,
            cfg_data_path=self.cfg_data_path,
            use_masked_labels=True,
            max_seq_len=MAX_SEQ_LEN,
            device=self.device
        )

        self.dataloader = create_cfg_dataloader(
            self.cfg_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS
        )

        # One optimizer (and scheduler) step per batch.
        self.total_steps = len(self.dataloader) * EPOCHS
        self.validation_steps = VALIDATION_INTERVAL

        print(f"✓ Dataset initialized with FULL data:")
        print(f" Total samples: {len(self.cfg_dataset):,}")
        print(f" Batch size: {BATCH_SIZE}")
        print(f" Batches per epoch: {len(self.dataloader):,}")
        print(f" Total training steps: {self.total_steps:,}")
        print(f" Validation every: {self.validation_steps:,} steps")

    def _initialize_optimizer(self):
        """Create AdamW over the flow model plus a linear-warmup -> cosine LR schedule."""
        print("Initializing optimizer and scheduler...")

        # Optimizer for flow model only (compressor/decompressor are frozen)
        self.optimizer = optim.AdamW(
            self.flow_model.parameters(),
            lr=BASE_LR,
            weight_decay=WEIGHT_DECAY,
            betas=(0.9, 0.98),
            eps=1e-6
        )

        # Warmup ramps the LR from 0.1 * BASE_LR to BASE_LR over WARMUP_STEPS steps.
        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=WARMUP_STEPS
        )
        # Clamp T_max so a short run (total_steps <= WARMUP_STEPS) does not give the
        # cosine phase a zero or negative period.
        main_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, self.total_steps - WARMUP_STEPS),
            eta_min=LR_MIN
        )
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[WARMUP_STEPS]
        )

        print(f"✓ Optimizer initialized:")
        print(f" Base LR: {BASE_LR}")
        print(f" Min LR: {LR_MIN}")
        print(f" Warmup steps: {WARMUP_STEPS}")
        print(f" Weight decay: {WEIGHT_DECAY}")
        print(f" Gradient clip norm: {GRADIENT_CLIP_NORM}")

    def _preprocess_batch(self, batch):
        """Normalize a batch's embeddings and compress them with the frozen compressor.

        Args:
            batch: Mapping with 'embeddings' (assumed (B, L, ESM_DIM)) and 'labels'.

        Returns:
            Tuple ``(compressed, labels)``, both on ``self.device``.
        """
        embeddings = batch['embeddings'].to(self.device)
        labels = batch['labels'].to(self.device)

        # Z-score per feature, then rescale with the global min/max statistics.
        m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device)
        mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device)
        embeddings = (embeddings - m) / (s + 1e-8)
        embeddings = (embeddings - mn) / (mx - mn + 1e-8)

        with torch.no_grad():
            compressed = self.compressor(embeddings)

        return compressed, labels

    def _compute_validation_metrics(self):
        """Estimate the flow-matching loss on a random subset of the dataset.

        NOTE: samples come from ``self.cfg_dataset`` (the training data), so this is a
        monitoring loss on seen data, not a held-out validation loss.

        Returns:
            Mean MSE over up to 1000 randomly chosen samples.
        """
        val_losses = []
        val_samples = min(1000, len(self.cfg_dataset))
        # Plain ints, so dataset indexing never receives 0-d tensors.
        val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples].tolist()

        self.flow_model.eval()
        try:
            with torch.no_grad():
                for start in range(0, val_samples, BATCH_SIZE):
                    batch_data = [self.cfg_dataset[idx] for idx in val_indices[start:start + BATCH_SIZE]]

                    # NOTE(review): per-item keys ('embedding', 'label') differ from the
                    # collated training keys ('embeddings', 'labels'). Presumably the
                    # dataloader's collate step renames them; confirm.
                    embeddings = torch.stack([item['embedding'] for item in batch_data])
                    labels = torch.stack([item['label'] for item in batch_data])

                    compressed, labels = self._preprocess_batch({
                        'embeddings': embeddings,
                        'labels': labels
                    })

                    t = torch.rand(compressed.shape[0], device=self.device)
                    eps = torch.randn_like(compressed)
                    xt, vt_target = self._make_flow_targets(compressed, t, eps)

                    vt_pred = self.flow_model(xt, t, labels=labels)
                    val_losses.append(F.mse_loss(vt_pred, vt_target).item())
        finally:
            # Restore training mode even if validation raises.
            self.flow_model.train()

        return np.mean(val_losses)

    def _optimization_step(self, compressed, labels):
        """Run one flow-matching forward/backward pass and optimizer update.

        Draws per-sample times t ~ U[0, 1) and Gaussian noise, regresses the model's
        vector field onto ``eps - compressed`` with MSE, clips gradients to
        GRADIENT_CLIP_NORM, and steps the optimizer (through the GradScaler when mixed
        precision is on). The LR scheduler is NOT stepped here.

        Returns:
            The scalar loss tensor for this batch.
        """
        t = torch.rand(compressed.shape[0], device=self.device)
        eps = torch.randn_like(compressed)
        xt, vt_target = self._make_flow_targets(compressed, t, eps)

        self.optimizer.zero_grad()
        if USE_MIXED_PRECISION:
            with autocast(dtype=torch.bfloat16):
                vt_pred = self.flow_model(xt, t, labels=labels)
                loss = F.mse_loss(vt_pred, vt_target)
            self.scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            vt_pred = self.flow_model(xt, t, labels=labels)
            loss = F.mse_loss(vt_pred, vt_target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
            self.optimizer.step()
        return loss

    def train_flow_matching(self):
        """Train the flow matching model with FULL data and optimizations.

        Returns:
            Tuple ``(losses, val_losses)``: per-epoch mean training losses and the
            validation losses recorded every ``self.validation_steps`` steps.
        """
        print(f"🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA")
        print(f"GPU: {GPU_ID}")
        print(f"Total epochs: {EPOCHS}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Total samples: {len(self.cfg_dataset):,}")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Estimated time: ~8-10 hours (overnight training with ALL data)")
        print("=" * 60)

        best_loss = float('inf')
        losses = []
        val_losses = []
        global_step = 0
        start_time = time.time()

        for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"):
            epoch_losses = []
            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(self.dataloader):
                compressed, labels = self._preprocess_batch(batch)

                # CFG training: with probability CFG_DROPOUT_RATE the whole batch is
                # trained unconditionally, using label -1.
                if torch.rand(1).item() < CFG_DROPOUT_RATE:
                    labels = torch.full_like(labels, fill_value=-1)

                loss = self._optimization_step(compressed, labels)
                self.scheduler.step()

                loss_value = loss.item()
                epoch_losses.append(loss_value)
                global_step += 1

                if batch_idx % 100 == 0:
                    current_lr = self.scheduler.get_last_lr()[0]
                    elapsed_time = time.time() - start_time
                    steps_per_sec = global_step / elapsed_time
                    eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600

                    print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | "
                          f"Loss: {loss_value:.6f} | LR: {current_lr:.2e} | "
                          f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h")

                    if self.use_wandb:
                        wandb.log({
                            'train/loss': loss_value,
                            'train/learning_rate': current_lr,
                            'train/steps_per_sec': steps_per_sec,
                            'train/global_step': global_step
                        })

                if global_step % self.validation_steps == 0:
                    val_loss = self._compute_validation_metrics()
                    val_losses.append(val_loss)

                    print(f"Validation at step {global_step}: Loss = {val_loss:.6f}")

                    if self.use_wandb:
                        wandb.log({
                            'val/loss': val_loss,
                            'val/global_step': global_step
                        })

                    # Keep a copy of the best model so far (no early stopping is done).
                    if val_loss < best_loss:
                        best_loss = val_loss
                        self._save_checkpoint(epoch, val_loss, global_step, is_final=False, is_best=True)

            avg_loss = np.mean(epoch_losses)
            losses.append(avg_loss)
            epoch_time = time.time() - epoch_start_time

            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | "
                  f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                  f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}")

            # Periodic snapshot to its own epoch-numbered file. This keeps it from
            # overwriting the final model file.
            if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
                self._save_checkpoint(epoch, avg_loss, global_step)

        # Save final model (guard against an empty loss history when EPOCHS == 0).
        final_loss = losses[-1] if losses else float('nan')
        self._save_checkpoint(EPOCHS - 1, final_loss, global_step, is_final=True)

        total_time = time.time() - start_time
        print("=" * 60)
        print("🎉 Optimized Training Complete with FULL DATA!")
        print(f"Best validation loss: {best_loss:.6f}")
        print(f"Total training time: {total_time/3600:.1f} hours")
        print(f"Total samples used: {len(self.cfg_dataset):,}")
        print(f"Final model saved as: amp_flow_model_final_optimized.pth")

        return losses, val_losses

    def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False):
        """Save flow-model, optimizer and scheduler state plus run metadata.

        Files go to '/data2/edwardsun/flow_checkpoints'. ``is_best`` selects the
        "best" file. Otherwise ``is_final`` selects the "final" file. Otherwise the
        file is named after ``step`` (callers pass the epoch index here).

        Args:
            step: Epoch index, used in the periodic checkpoint filename.
            loss: Loss recorded in the checkpoint and in the log line.
            global_step: Optimizer step count at save time.
            is_final: Write the "final" model file.
            is_best: Write the "best" model file (takes precedence over ``is_final``).
        """
        output_dir = '/data2/edwardsun/flow_checkpoints'
        os.makedirs(output_dir, exist_ok=True)

        if is_best:
            filename = os.path.join(output_dir, 'amp_flow_model_best_optimized.pth')
        elif is_final:
            filename = os.path.join(output_dir, 'amp_flow_model_final_optimized.pth')
        else:
            filename = os.path.join(output_dir, f'amp_flow_checkpoint_optimized_step_{step:04d}.pth')

        checkpoint = {
            'step': step,
            'global_step': global_step,
            'loss': loss,
            # Unwrap torch.compile so the keys carry no '_orig_mod.' prefix.
            'flow_model_state_dict': self._unwrap_model(self.flow_model).state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'stats': self.stats,
            'total_samples': len(self.cfg_dataset),
            'config': {
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'base_lr': BASE_LR,
                'lr_min': LR_MIN,
                'warmup_steps': WARMUP_STEPS,
                'mixed_precision': USE_MIXED_PRECISION,
                'gradient_clip': GRADIENT_CLIP_NORM,
                'weight_decay': WEIGHT_DECAY
            }
        }

        torch.save(checkpoint, filename)
        print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})")
def main():
    """Parse CLI arguments, apply overrides, and run single-GPU flow training.

    ``--batch_size`` and ``--epochs`` overwrite the module-level BATCH_SIZE and EPOCHS
    globals. The trainer reads those globals when it is constructed and run.
    """
    global BATCH_SIZE, EPOCHS

    parser = argparse.ArgumentParser(description='Optimized Single GPU AMP Flow Training with FULL DATA')
    parser.add_argument('--embeddings', default='/data2/edwardsun/flow_project/peptide_embeddings/',
                        help='Path to peptide embeddings directory')
    parser.add_argument('--cfg_data', default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
                        help='Path to FULL CFG data file')
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of training epochs')
    args = parser.parse_args()

    # Reject non-positive values up front. Otherwise they fail much later: with zero
    # epochs the final checkpoint indexes an empty loss list, and the DataLoader
    # requires a positive batch size.
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')
    if args.epochs <= 0:
        parser.error('--epochs must be a positive integer')

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs

    print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}")

    trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb)
    trainer.train_flow_matching()

    print("Optimized training completed successfully with FULL DATA!")


if __name__ == "__main__":
    main()