"""Training script for ResShift diffusion super-resolution.

Trains FullUNET to predict the Gaussian noise injected by the ResShift
forward process in VQGAN latent space (64x64 latents), logging the loss
and decoded sample images to Weights & Biases.
"""

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dotenv import load_dotenv
import wandb

from model import FullUNET
from noiseControl import resshift_schedule
from data import mini_dataset, train_dataset, get_vqgan_model
from config import (batch_size, device, learning_rate, iterations,
                    weight_decay, T, k, _project_root)

# Load environment variables (e.g. WANDB_API_KEY) from .env in the project root.
load_dotenv(os.path.join(_project_root, '.env'))

wandb.init(
    project="diffusionsr",
    name="reshift_training",
    config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "steps": iterations,
        "model": "ResShift",
        "T": T,
        "k": k,
        "optimizer": "Adam",
        "betas": (0.9, 0.999),
        "grad_clip": 1.0,
        "criterion": "MSE",
        "device": str(device),
        "training_space": "latent_64x64",
    },
)

# VQGAN is used only to decode latents back to pixel space for visualization.
vqgan = get_vqgan_model()

train_dl = DataLoader(mini_dataset, batch_size=batch_size, shuffle=True)

# Single fixed batch of 64x64 latents (mini-dataset overfitting setup).
hr_latent, lr_latent = next(iter(train_dl))
hr_latent = hr_latent.to(device)  # (B, C, 64, 64) - HR latent
lr_latent = lr_latent.to(device)  # (B, C, 64, 64) - LR latent

# BUG FIX: if the dataset holds fewer samples than `batch_size`, the batch
# is smaller than `batch_size` and sampling timesteps with the configured
# size would break broadcasting below — use the actual batch size instead.
actual_batch = hr_latent.size(0)

eta = resshift_schedule().to(device)
eta = eta[:, None, None, None]  # (T, 1, 1, 1) so eta[t] broadcasts over (B, C, H, W)

residual = lr_latent - hr_latent  # residual shift (y - x0) in latent space

model = FullUNET().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                       betas=(0.9, 0.999), weight_decay=weight_decay)
steps = iterations

# Watch model for gradients/parameters.
wandb.watch(model, log="all", log_freq=10)

# No eval() toggle anywhere in the loop, so train() needs to be set only once.
model.train()

for step in range(steps):
    # Random timestep per sample (0 .. T-1); the same latent patch is seen at
    # different timestamps across steps.
    t = torch.randint(0, T, (actual_batch,), device=device)

    # ResShift forward process in latent space:
    #   x_t = x0 + eta_t * (y - x0) + k * sqrt(eta_t) * eps
    epsilon = torch.randn_like(hr_latent)  # noise in latent space
    eta_t = eta[t]  # (B, 1, 1, 1)
    x_t = hr_latent + eta_t * residual + k * torch.sqrt(eta_t) * epsilon

    # lr_latent is the low-resolution latent used for conditioning.
    pred = model(x_t, t, lq=lr_latent)

    optimizer.zero_grad()
    loss = criterion(pred, epsilon)  # epsilon-prediction objective

    wandb.log({
        "loss": loss.item(),
        "step": step,
        "learning_rate": optimizer.param_groups[0]['lr'],
    })

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 50 == 0:
        # Decode latents to pixel space for visualization.
        with torch.no_grad():
            hr_pixel = vqgan.decode(hr_latent[0:1])  # (1, 3, 256, 256)
            lr_pixel = vqgan.decode(lr_latent[0:1])  # (1, 3, 256, 256)
            # BUG FIX: previously x_t (the *noisy model input*) was decoded
            # and logged as "pred_sample". Reconstruct the model's x0 estimate
            # by inverting the forward process with the predicted noise.
            x0_pred = (x_t[0:1] - eta_t[0:1] * residual[0:1]
                       - k * torch.sqrt(eta_t[0:1]) * pred[0:1])
            pred_pixel = vqgan.decode(x0_pred)  # (1, 3, 256, 256)
        # NOTE(review): clamp(0, 1) assumes the decoder emits [0, 1] images —
        # confirm against the VQGAN's output range (often [-1, 1]).
        wandb.log({
            "hr_sample": wandb.Image(hr_pixel[0].cpu().clamp(0, 1)),
            "lr_sample": wandb.Image(lr_pixel[0].cpu().clamp(0, 1)),
            "pred_sample": wandb.Image(pred_pixel[0].cpu().clamp(0, 1)),
        })

    # BUG FIX: printed step was off by one relative to the wandb-logged step;
    # also print the scalar value rather than the tensor repr.
    print(f'loss at step {step} is {loss.item()}')

wandb.finish()