| | import torch |
| | import torchvision |
| | from torchvision.utils import save_image |
| | import os |
| | import numpy as np |
| | from scipy.fftpack import dctn, idctn |
| | from config import Config |
| |
|
def frequency_aware_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Sample images with a strided DDPM reverse process on a quadratic schedule.

    Starts from Gaussian noise scaled by 0.4 and denoises over the unique
    timesteps ``int(300 * (1 - i / 100) ** 2)`` for ``i`` in ``range(100)``,
    visited in descending order and ending at 0.  Every step except the last
    draws from the DDPM posterior between consecutive schedule entries (with
    the injected noise damped to 30% of the posterior standard deviation);
    the last step returns the model's x0 estimate directly.

    Args:
        model: Network called as ``model(x, t)`` with ``t`` a ``(n_samples,)``
            long tensor; its output is treated as the predicted noise.
        noise_scheduler: Object whose ``alpha_bars`` can be indexed by an int
            timestep (up to 300) and yields a 0-d tensor (``.item()`` is used).
        device: Device the samples are created on.
        epoch: When not ``None`` the grid is saved to
            ``samples/epoch_{epoch}.png`` (and logged if ``writer`` is given).
        writer: Optional logger; ``writer.add_image('Samples', grid, epoch)``
            is called when ``epoch`` is also given.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: the ``(n_samples, 3, S, S)`` samples clamped to
        ``[-1, 1]`` (``S`` is ``Config().image_size``) and the image grid
        built from them after rescaling to ``[0, 1]``.
    """
    config = Config()
    # Remember the caller's mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()

    try:
        with torch.no_grad():
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.4

            print(f"Starting optimized frequency-aware sampling for {n_samples} samples...")
            print(f"Initial moderate noise range: [{x.min().item():.3f}, {x.max().item():.3f}]")

            total_steps = 100
            # Quadratic spacing: large jumps at high noise, unit steps near t=0.
            timesteps = sorted(
                {max(int(300 * (1 - i / total_steps) ** 2), 0) for i in range(total_steps)},
                reverse=True,
            )

            print(f"Using {len(timesteps)} adaptive timesteps: {timesteps[:10]}...{timesteps[-5:]}")

            last_step = len(timesteps) - 1
            for step, t in enumerate(timesteps):
                if step % 20 == 0:
                    print(f" Step {step}/{len(timesteps)}, t={t}, range: [{x.min().item():.3f}, {x.max().item():.3f}]")

                t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)
                predicted_noise = model(x, t_tensor)
                alpha_bar_t = noise_scheduler.alpha_bars[t].item()

                if step == last_step:
                    # Final step: return the x0 estimate implied by the noise prediction.
                    x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                else:
                    next_t = timesteps[step + 1]
                    alpha_bar_prev = noise_scheduler.alpha_bars[next_t].item()

                    pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                    pred_x0 = torch.clamp(pred_x0, -1.2, 1.2)

                    # The schedule skips timesteps, so the single-step alphas[t]
                    # and betas[t] do not describe the t -> next_t jump. Use the
                    # jump's effective alpha/beta derived from alpha_bars
                    # (identical to alphas[t]/betas[t] for a unit step when
                    # alpha_bars is the cumulative product of alphas).
                    alpha_jump = alpha_bar_t / alpha_bar_prev
                    beta_jump = 1 - alpha_jump

                    coeff1 = np.sqrt(alpha_jump) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    coeff2 = np.sqrt(alpha_bar_prev) * beta_jump / (1 - alpha_bar_t)
                    posterior_mean = coeff1 * x + coeff2 * pred_x0

                    if next_t > 0:
                        posterior_variance = beta_jump * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                        noise = torch.randn_like(x)
                        # Damp the stochastic term to 30% of the posterior std.
                        noise_scale = np.sqrt(posterior_variance) * 0.3
                        x = posterior_mean + noise_scale * noise
                    else:
                        x = posterior_mean

                # Loose per-step clamp; the tight [-1, 1] clamp is applied once at the end.
                x = torch.clamp(x, -1.3, 1.3)

            x = torch.clamp(x, -1, 1)

            print("Final samples statistics:")
            print(f" Range: [{x.min().item():.3f}, {x.max().item():.3f}]")
            print(f" Mean: {x.mean().item():.3f}, Std: {x.std().item():.3f}")

            # Coarse diversity check: number of distinct values at 0.01 resolution.
            unique_vals = len(torch.unique(torch.round(x * 100) / 100))
            print(f" Unique values (x100): {unique_vals}")

            std = x.std().item()
            if unique_vals < 20:
                print(" ⚠️ Low diversity - might be collapsed")
            elif std < 0.05:
                print(" ⚠️ Very low variance - uniform output")
            elif std > 0.9:
                print(" ⚠️ High variance - might still be noisy")
            else:
                print(" ✅ Good sample diversity and range!")

            # Map [-1, 1] -> [0, 1] for display.
            x_display = torch.clamp((x + 1.0) / 2.0, 0, 1)
            grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

            if writer and epoch is not None:
                writer.add_image('Samples', grid, epoch)

            if epoch is not None:
                os.makedirs("samples", exist_ok=True)
                save_image(grid, f"samples/epoch_{epoch}.png")

            return x, grid
    finally:
        if was_training:
            model.train()
| |
|
| | |
def progressive_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Sample with a short fixed schedule that blends toward the x0 estimate.

    This is a heuristic update, not a DDPM posterior step. Starting from
    Gaussian noise scaled by 0.4, at each of 14 fixed timesteps the sample is
    moved 30% of the way toward the clamped x0 prediction and then re-noised
    to the next timestep's level (``sqrt(alpha_bar_next) * x`` plus noise
    damped to 20% of ``sqrt(1 - alpha_bar_next)``). The last step returns the
    x0 prediction.

    Args:
        model: Network called as ``model(x, t)`` with ``t`` a ``(n_samples,)``
            long tensor; its output is treated as the predicted noise.
        noise_scheduler: Object whose ``alpha_bars`` can be indexed by an int
            timestep (up to 300) and yields a 0-d tensor.
        device: Device the samples are created on.
        epoch: When not ``None`` the grid is saved to
            ``samples/progressive_epoch_{epoch}.png`` (and logged if
            ``writer`` is given).
        writer: Optional logger; ``add_image('Progressive_Samples', ...)`` is
            called when ``epoch`` is also given.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: samples clamped to ``[-1, 1]`` and the image grid
        built from them after rescaling to ``[0, 1]``.
    """
    config = Config()
    # Remember the caller's mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()

    try:
        with torch.no_grad():
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.4

            print(f"Starting progressive frequency sampling for {n_samples} samples...")

            timesteps = [300, 250, 200, 150, 120, 90, 70, 50, 35, 25, 15, 8, 3, 1]
            last_step = len(timesteps) - 1
            blend_factor = 0.3

            for i, t_val in enumerate(timesteps):
                print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

                t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)
                predicted_noise = model(x, t_tensor)
                alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1, 1)

                if i < last_step:
                    alpha_bar_next = noise_scheduler.alpha_bars[timesteps[i + 1]].item()

                    # Partial move toward the clean estimate...
                    x = (1 - blend_factor) * x + blend_factor * pred_x0

                    # ...then re-noise to the next level with damped noise.
                    noise_scale = np.sqrt(1 - alpha_bar_next) * 0.2
                    noise = torch.randn_like(x)
                    x = np.sqrt(alpha_bar_next) * x + noise_scale * noise
                else:
                    x = pred_x0

                x = torch.clamp(x, -1.2, 1.2)

            x = torch.clamp(x, -1, 1)

            print(f"Progressive samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

            # Map [-1, 1] -> [0, 1] for display.
            x_display = torch.clamp((x + 1) / 2, 0, 1)
            grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

            if writer and epoch is not None:
                writer.add_image('Progressive_Samples', grid, epoch)

            if epoch is not None:
                os.makedirs("samples", exist_ok=True)
                save_image(grid, f"samples/progressive_epoch_{epoch}.png")

            return x, grid
    finally:
        if was_training:
            model.train()
| |
|
def optimized_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Sample with a strided DDPM reverse process on a three-phase schedule.

    Starts from Gaussian noise scaled by 0.5 and visits t = 400..225 (step
    25), 200..65 (step 15) and 50..5 (step 5). Each step draws from the DDPM
    posterior toward the next scheduled timestep; after the last scheduled
    timestep the target is clean data (``alpha_bar = 1``), so the final step
    yields the clamped x0 prediction. Noise is injected only while t > 5,
    at full posterior std for t > 100 and half std otherwise.

    Args:
        model: Network called as ``model(x, t)`` with ``t`` a ``(n_samples,)``
            long tensor; its output is treated as the predicted noise.
        noise_scheduler: Object whose ``alpha_bars`` can be indexed by an int
            timestep (up to 400) and yields a 0-d tensor.
        device: Device the samples are created on.
        epoch: When not ``None`` the grid is saved to
            ``samples/optimized_epoch_{epoch}.png`` (and logged if ``writer``
            is given).
        writer: Optional logger; ``add_image('Optimized_Samples', ...)`` is
            called when ``epoch`` is also given.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: samples clamped to ``[-1, 1]`` and the image grid
        built from them after rescaling to ``[0, 1]``.
    """
    config = Config()
    # Remember the caller's mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()

    try:
        with torch.no_grad():
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.5

            print(f"Starting optimized frequency sampling for {n_samples} samples...")

            # Coarse jumps at high noise, progressively finer ones near the data.
            timesteps = (
                list(range(400, 200, -25))
                + list(range(200, 50, -15))
                + list(range(50, 0, -5))
            )
            n_steps = len(timesteps)

            for i, t_val in enumerate(timesteps):
                if i % 10 == 0:
                    print(f"Step {i+1}/{len(timesteps)}, t={t_val}")

                t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)
                predicted_noise = model(x, t_tensor)
                alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

                # Target of this jump: the next scheduled timestep, or clean
                # data (t=0, alpha_bar = 1) after the last one. Previously the
                # last step targeted itself, so sampling never reached x0.
                next_t = timesteps[i + 1] if i + 1 < n_steps else 0
                alpha_bar_prev = noise_scheduler.alpha_bars[next_t].item() if next_t > 0 else 1.0

                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1, 1)

                # Effective alpha/beta of the t -> next_t jump. The single-step
                # alphas[t]/betas[t] are only valid for unit steps; this form
                # matches them in that case (when alpha_bars is the cumulative
                # product of alphas) and is also correct for skipped steps.
                alpha_jump = alpha_bar_t / alpha_bar_prev
                beta_jump = 1 - alpha_jump

                coeff1 = np.sqrt(alpha_jump) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                coeff2 = np.sqrt(alpha_bar_prev) * beta_jump / (1 - alpha_bar_t)
                mean = coeff1 * x + coeff2 * pred_x0

                if t_val > 5:
                    posterior_variance = beta_jump * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                    # Full posterior noise early on, halved once t <= 100.
                    noise_scale = 1.0 if t_val > 100 else 0.5
                    noise = torch.randn_like(x)
                    x = mean + np.sqrt(posterior_variance) * noise * noise_scale
                else:
                    x = mean

                # Clamp range tightens as sampling approaches the data.
                clamp_range = 2.0 if t_val > 200 else 1.5 if t_val > 50 else 1.2
                x = torch.clamp(x, -clamp_range, clamp_range)

            x = torch.clamp(x, -1, 1)

            print(f"Optimized samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

            # Coarse diversity check: number of distinct values at 0.01 resolution.
            unique_vals = len(torch.unique(torch.round(x * 100) / 100))
            if unique_vals > 50:
                print("✅ Good diversity in generated samples")
            else:
                print("⚠️ Low diversity - samples might be collapsed")

            # Map [-1, 1] -> [0, 1] for display.
            x_display = torch.clamp((x + 1) / 2, 0, 1)
            grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

            if writer and epoch is not None:
                writer.add_image('Optimized_Samples', grid, epoch)

            if epoch is not None:
                os.makedirs("samples", exist_ok=True)
                save_image(grid, f"samples/optimized_epoch_{epoch}.png")

            return x, grid
    finally:
        if was_training:
            model.train()
| |
|
| | |
def aggressive_frequency_sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Sample with a short fixed schedule that trusts the x0 prediction heavily.

    This is a heuristic update, not a DDPM posterior step. Starting from
    Gaussian noise scaled by 0.8, each of 15 fixed timesteps does:

    * all but the last two steps: move 60% (t > 100) or 80% (otherwise) of
      the way toward the clamped x0 prediction and, while t > 10, re-noise to
      the next timestep's level with noise damped to 40% of
      ``sqrt(1 - alpha_bar_next)``;
    * second-to-last step: move 80% toward the x0 prediction and add
      ``0.05 * randn`` noise;
    * last step: take the x0 prediction.

    Intermediate samples are clamped to ``[-1.5, 1.5]`` and the result to
    ``[-1, 1]``.

    Args:
        model: Network called as ``model(x, t)`` with ``t`` a ``(n_samples,)``
            long tensor; its output is treated as the predicted noise.
        noise_scheduler: Object whose ``alpha_bars`` can be indexed by an int
            timestep (up to 350) and yields a 0-d tensor.
        device: Device the samples are created on.
        epoch: When not ``None`` the grid is saved to
            ``samples/aggressive_epoch_{epoch}.png`` (and logged if
            ``writer`` is given).
        writer: Optional logger; ``add_image('Aggressive_Samples', ...)`` is
            called when ``epoch`` is also given.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(x, grid)``: samples clamped to ``[-1, 1]`` and the image grid
        built from them after rescaling to ``[0, 1]``.
    """
    config = Config()
    # Remember the caller's mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()

    try:
        with torch.no_grad():
            x = torch.randn(n_samples, 3, config.image_size, config.image_size, device=device) * 0.8

            print(f"Starting aggressive frequency sampling for {n_samples} samples...")
            print(f"Initial noise range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

            timesteps = [350, 280, 220, 170, 130, 100, 75, 55, 40, 28, 18, 10, 5, 2, 1]
            last_step = len(timesteps) - 1

            for i, t_val in enumerate(timesteps):
                t_tensor = torch.full((n_samples,), t_val, device=device, dtype=torch.long)
                predicted_noise = model(x, t_tensor)
                alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()

                pred_x0 = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
                pred_x0 = torch.clamp(pred_x0, -1, 1)

                if i < last_step - 1:
                    trust_factor = 0.6 if t_val > 100 else 0.8
                    x = (1 - trust_factor) * x + trust_factor * pred_x0

                    if t_val > 10:
                        # timesteps[i + 1] always exists here since i < last_step - 1.
                        alpha_bar_next = noise_scheduler.alpha_bars[timesteps[i + 1]].item()
                        noise_strength = np.sqrt(1 - alpha_bar_next) * 0.4
                        fresh_noise = torch.randn_like(x)
                        x = np.sqrt(alpha_bar_next) * x + noise_strength * fresh_noise
                elif i == last_step - 1:
                    x = 0.2 * x + 0.8 * pred_x0
                    tiny_noise = torch.randn_like(x) * 0.05
                    x = x + tiny_noise
                else:
                    x = pred_x0

                x = torch.clamp(x, -1.5, 1.5)

                if i % 3 == 0:
                    print(f" Step {i+1}/{len(timesteps)}, t={t_val}, range: [{x.min():.3f}, {x.max():.3f}], std: {x.std():.3f}")

            x = torch.clamp(x, -1, 1)

            print(f"Aggressive samples - Range: [{x.min():.3f}, {x.max():.3f}], Mean: {x.mean():.3f}, Std: {x.std():.3f}")

            # Coarse diversity check: number of distinct values at 0.005 resolution.
            unique_vals = len(torch.unique(torch.round(x * 200) / 200))
            print(f"Unique values (x200): {unique_vals}")

            std = x.std().item()
            if std < 0.05:
                print("❌ Very low variance - output collapsed")
            elif std < 0.15:
                print("⚠️ Low variance - output may be too smooth")
            elif std > 0.6:
                print("⚠️ High variance - output may be noisy")
            else:
                print("✅ Good variance - output looks promising")

            if unique_vals < 20:
                print("❌ Very low diversity")
            elif unique_vals < 100:
                print("⚠️ Moderate diversity")
            else:
                print("✅ Good diversity")

            # Map [-1, 1] -> [0, 1] for display.
            x_display = torch.clamp((x + 1) / 2, 0, 1)
            grid = torchvision.utils.make_grid(x_display, nrow=2, normalize=False, pad_value=1.0)

            if writer and epoch is not None:
                writer.add_image('Aggressive_Samples', grid, epoch)

            if epoch is not None:
                os.makedirs("samples", exist_ok=True)
                save_image(grid, f"samples/aggressive_epoch_{epoch}.png")

            return x, grid
    finally:
        if was_training:
            model.train()
| |
|
| | |
def sample(model, noise_scheduler, device, epoch=None, writer=None, n_samples=4):
    """Default sampling entry point.

    Delegates to ``frequency_aware_sample`` with every argument forwarded
    unchanged, and returns its ``(x, grid)`` result as-is.
    """
    result = frequency_aware_sample(
        model,
        noise_scheduler,
        device,
        epoch=epoch,
        writer=writer,
        n_samples=n_samples,
    )
    return result