""" Main training script for ResShift diffusion model. This script initializes the Trainer class and runs the main training loop. """ import multiprocessing # Fix CUDA multiprocessing: Set start method to 'spawn' for compatibility with CUDA # This is required when using DataLoader with num_workers > 0 on systems where # CUDA is initialized before worker processes are created (Colab, some Linux setups) # Must be set before any CUDA initialization or DataLoader creation try: multiprocessing.set_start_method('spawn', force=True) except RuntimeError: # Start method already set (e.g., in another module), ignore pass from trainer import Trainer from config import ( iterations, batch_size, microbatch, learning_rate, warmup_iterations, save_freq, log_freq, T, k, val_freq ) import torch import wandb def train(resume_ckpt=None): """ Main training loop that integrates all components. Training flow: 1. Build model and dataloader 2. Setup optimization 3. Training loop: - Get batch from dataloader - Training step (forward, backward, optimizer step) - Adjust learning rate - Log metrics and images - Save checkpoints Args: resume_ckpt: Path to checkpoint file to resume from (optional) """ # Initialize trainer trainer = Trainer(resume_ckpt=resume_ckpt) print("=" * 100) if resume_ckpt: print("Resuming Training") else: print("Starting Training") print("=" * 100) # Build model (Component 2) trainer.build_model() # Resume from checkpoint if provided (must be after model is built) if resume_ckpt: trainer.resume_from_ckpt(resume_ckpt) # Setup optimization (Component 1) trainer.setup_optimization() # Build dataloader (Component 3) trainer.build_dataloader() # Initialize training trainer.model.train() train_iter = iter(trainer.dataloaders['train']) print(f"\nTraining Configuration:") print(f" - Total iterations: {iterations}") print(f" - Batch size: {batch_size}") print(f" - Micro-batch size: {microbatch}") print(f" - Learning rate: {learning_rate}") print(f" - Warmup iterations: {warmup_iterations}") print(f" - Save frequency: {save_freq}") print(f" - Log frequency: {log_freq}") print(f" - Device: {trainer.device}") print("=" * 100) print("\nStarting training loop...\n") # Training loop for step in range(trainer.iters_start, iterations): trainer.current_iters = step + 1 # Get batch from dataloader try: hr_latent, lr_latent = next(train_iter) except StopIteration: # Restart iterator if exhausted (shouldn't happen with infinite cycle, but safety) train_iter = iter(trainer.dataloaders['train']) hr_latent, lr_latent = next(train_iter) # Move to device hr_latent = hr_latent.to(trainer.device) lr_latent = lr_latent.to(trainer.device) # Training step (Component 5) # This handles: forward pass, backward pass, optimizer step, gradient accumulation loss, timing_dict = trainer.training_step(hr_latent, lr_latent) # Adjust learning rate (Component 6) trainer.adjust_lr() # Run validation (Component 9) if 'val' in trainer.dataloaders and trainer.current_iters % val_freq == 0: trainer.validation() # Store timing info for logging trainer._last_timing = timing_dict # Only recompute for logging if we're actually logging images # This avoids unnecessary computation when only logging loss if trainer.current_iters % log_freq[1] == 0: # Prepare data for logging (need x_t and pred for visualization) with torch.no_grad(): residual = (lr_latent - hr_latent) t_log = torch.randint(0, T, (hr_latent.shape[0],)).to(trainer.device) epsilon_log = torch.randn_like(hr_latent) eta_t_log = trainer.eta[t_log] x_t_log = hr_latent + eta_t_log * residual + k * torch.sqrt(eta_t_log) * epsilon_log trainer.model.eval() # Model predicts x0 (clean HR latent), not noise x0_pred_log = trainer.model(x_t_log[0:1], t_log[0:1], lq=lr_latent[0:1]) trainer.model.train() # Log training metrics and images (Component 8) trainer.log_step_train( loss=loss, hr_latent=hr_latent[0:1], lr_latent=lr_latent[0:1], x_t=x_t_log[0:1], pred=x0_pred_log, # x0 prediction (clean HR latent) phase='train' ) else: # Only log loss/metrics, no images trainer.log_step_train( loss=loss, hr_latent=hr_latent[0:1], lr_latent=lr_latent[0:1], x_t=None, # Not needed when not logging images pred=None, # Not needed when not logging images phase='train' ) # Save checkpoint (Component 7) if trainer.current_iters % save_freq == 0: trainer.save_ckpt() # Final checkpoint print("\n" + "=" * 100) print("Training completed!") print("=" * 100) trainer.save_ckpt() print(f"Final checkpoint saved at iteration {trainer.current_iters}") # Finish WandB wandb.finish() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Train ResShift diffusion model') parser.add_argument('--resume', type=str, default=None, help='Path to checkpoint file to resume from (e.g., checkpoints/ckpts/model_10000.pth)') args = parser.parse_args() train(resume_ckpt=args.resume)