| | import torch |
| | import torchvision |
| | from torchvision.utils import save_image, make_grid |
| | import os |
| | from config import Config |
| | from model import SmoothDiffusionUNet |
| | from noise_scheduler import FrequencyAwareNoise |
| | from sample import frequency_aware_sample |
| | import numpy as np |
| |
|
def _parse_checkpoint_epoch(filename):
    """Return the epoch number encoded in a checkpoint filename, or None.

    A checkpoint is a file named ``model_epoch_<token>...pth``; the epoch is
    the third ``_``-separated field up to its first ``.``. Files that do not
    follow the naming scheme, or whose epoch token is not an integer
    (e.g. ``model_epoch_best.pth``), yield ``None`` instead of raising.
    """
    if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
        return None
    token = filename.split('_')[2].split('.')[0]
    try:
        return int(token)
    except ValueError:
        return None


def _find_latest_checkpoint(logs_root='./logs'):
    """Locate the highest-epoch checkpoint in the last-sorted run directory.

    The run directory is the lexicographically greatest sub-directory of
    ``logs_root``. Prints a diagnostic and returns ``None`` when there is no
    run directory or the chosen one holds no usable checkpoint; otherwise
    returns the checkpoint's path.
    """
    log_dirs = []
    # isdir (rather than exists) so a stray regular file named "logs" is
    # reported as "no log directories" instead of crashing os.listdir.
    if os.path.isdir(logs_root):
        log_dirs = [
            item for item in os.listdir(logs_root)
            if os.path.isdir(os.path.join(logs_root, item))
        ]

    if not log_dirs:
        print("No log directories found!")
        return None

    log_path = os.path.join(logs_root, max(log_dirs))

    checkpoint_files = []
    for name in os.listdir(log_path):
        epoch = _parse_checkpoint_epoch(name)
        if epoch is not None:
            checkpoint_files.append((epoch, name))

    if not checkpoint_files:
        print("No checkpoint files found!")
        return None

    # Tuples compare by epoch first, then filename, matching sort()[-1].
    _, latest_file = max(checkpoint_files)
    return os.path.join(log_path, latest_file)


def _check_timestep_predictions(model, device, timesteps=(0, 50, 100, 250, 499)):
    """Run the model on one random input at several timesteps and print stats."""
    with torch.no_grad():
        x_test = torch.randn(1, 3, 64, 64, device=device)

        # The input never changes across timesteps, so compute its stats once.
        in_min, in_max = x_test.min().item(), x_test.max().item()
        in_mean, in_std = x_test.mean().item(), x_test.std().item()

        for t_val in timesteps:
            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)
            pred_noise = model(x_test, t_tensor)
            pred_std = pred_noise.std().item()

            print(f"\nTimestep {t_val}:")
            print(f" Input range: [{in_min:.3f}, {in_max:.3f}]")
            print(f" Input mean/std: {in_mean:.3f} / {in_std:.3f}")
            print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
            print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_std:.3f}")

            # NaN is tested first: a NaN std would fail both threshold
            # comparisons and otherwise be reported as "reasonable".
            if torch.isnan(pred_noise).any():
                print(" ❌ NaN detected in predictions!")
            elif pred_std < 0.01:
                print(" ⚠️ Very low variance - model might be collapsed")
            elif pred_std > 10:
                print(" ⚠️ Very high variance - model might be unstable")
            else:
                print(" ✓ Prediction variance looks reasonable")


def _simulate_training_step(model, noise_scheduler, device):
    """Noise a synthetic image, run the model on it, and print the MSE to the target."""
    with torch.no_grad():
        # Synthetic stand-in for a clean image: Gaussian noise scaled by 0.5.
        x0 = torch.randn(1, 3, 64, 64, device=device) * 0.5
        t = torch.randint(100, 400, (1,), device=device)

        xt, noise_target = noise_scheduler.apply_noise(x0, t)
        pred_noise = model(xt, t)

        print("\nTraining simulation:")
        print(f" Clean image range: [{x0.min().item():.3f}, {x0.max().item():.3f}]")
        print(f" Noisy image range: [{xt.min().item():.3f}, {xt.max().item():.3f}]")
        print(f" Target noise range: [{noise_target.min().item():.3f}, {noise_target.max().item():.3f}]")
        print(f" Target noise mean/std: {noise_target.mean().item():.3f} / {noise_target.std().item():.3f}")
        print(f" Predicted noise range: [{pred_noise.min().item():.3f}, {pred_noise.max().item():.3f}]")
        print(f" Predicted noise mean/std: {pred_noise.mean().item():.3f} / {pred_noise.std().item():.3f}")

        mse = torch.mean((pred_noise - noise_target) ** 2).item()
        print(f" MSE between prediction and target: {mse:.6f}")

        if mse > 1.0:
            print(" ⚠️ High MSE suggests poor training")
        elif mse < 0.001:
            print(" ✓ Very low MSE - model learned well")
        else:
            print(" ✓ Reasonable MSE")


def _attempt_sampling(model, noise_scheduler, device):
    """Generate four samples, save the grid to debug_samples.png, and print stats."""
    # Broad catch is deliberate: this is a diagnostic script, so a sampling
    # failure is reported with its traceback rather than aborting the run.
    try:
        samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=4)
        save_image(grid, "debug_samples.png", normalize=False)
        print("Samples saved to debug_samples.png")

        print("Sample statistics:")
        print(f" Range: [{samples.min().item():.3f}, {samples.max().item():.3f}]")
        print(f" Mean: {samples.mean().item():.3f}")
        print(f" Std: {samples.std().item():.3f}")
    except Exception as e:
        print(f"Sampling failed: {e}")
        import traceback
        traceback.print_exc()


def debug_model_predictions():
    """Debug what the model is actually predicting.

    Loads the newest checkpoint found under ``./logs``, then prints three
    diagnostics: model output statistics at fixed timesteps, a simulated
    training step with its MSE against the scheduler's noise target, and an
    attempted sampling run whose grid is written to ``debug_samples.png``.

    Returns None. Exits early, after printing a message, when no log
    directory or no usable checkpoint file exists.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    checkpoint_path = _find_latest_checkpoint()
    if checkpoint_path is None:
        return

    print(f"Loading {os.path.basename(checkpoint_path)}")

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints you trust. Newer PyTorch releases default to
    # weights_only=True, which may reject a checkpoint that stores a pickled
    # Config object - confirm against the installed version.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('config', Config())

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    # Accept both a wrapped checkpoint dict and a bare state dict.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()

    print("\n=== DEBUGGING MODEL PREDICTIONS ===")
    _check_timestep_predictions(model, device)

    print("\n=== TESTING TRAINING DATA SIMULATION ===")
    _simulate_training_step(model, noise_scheduler, device)

    print("\n=== ATTEMPTING CORRECTED SAMPLING ===")
    _attempt_sampling(model, noise_scheduler, device)
| |
|
# Script entry point: run the diagnostics only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    debug_model_predictions()
| |
|