"""Configuration file with all training, model, and data parameters.

Every setting is a plain module-level constant. Several names are kept as
aliases of one another for backward compatibility; the alias is always
assigned from the primary name so the two cannot drift apart.
"""

import os
from pathlib import Path

import torch

# ============================================================================
# Project Settings
# ============================================================================
# Two directory levels above this file.
_project_root = Path(__file__).parent.parent

# ============================================================================
# Device Settings
# ============================================================================
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================================
# Training Parameters
# ============================================================================
# Learning rate
lr = 1e-5  # Original ResShift setting
lr_min = 1e-5
lr_schedule = None
learning_rate = lr  # Alias for backward compatibility
# Linear warmup from 0 to base_lr; 100 of the 3200 `iterations` below (~3%).
warmup_iterations = 100

# Dataloader
batch = [64, 64]  # Original ResShift: adjust based on your GPU memory
batch_size = batch[0]  # Use first value from batch list
# NOTE(review): microbatch (100) is larger than batch_size (64) -- confirm
# the training loop handles a microbatch that exceeds the batch.
microbatch = 100
num_workers = 4
prefetch_factor = 2

# Optimization settings
weight_decay = 0
ema_rate = 0.999
# DIV2K: 800 images / 64 batch_size = 12.5 batches per epoch,
# so 3200 iterations = 256 epochs.
iterations = 3200

# Save logging
save_freq = 200
log_freq = [50, 100]  # [training loss, training images]
local_logging = True
tf_logging = False

# Validation settings
use_ema_val = True
val_freq = 100  # Run validation every 100 iterations
val_y_channel = True
val_resolution = 64  # model.params.lq_size
val_padding_mode = "reflect"

# Training setting
use_amp = True  # Mixed precision training
seed = 123456
global_seeding = False

# Model compile
compile_flag = True
compile_mode = "reduce-overhead"

# ============================================================================
# Diffusion/Noise Schedule Parameters
# ============================================================================
sf = 4
schedule_name = "exponential"
schedule_power = 0.3  # Original ResShift setting
etas_end = 0.99  # Original ResShift setting
T = 15  # Original ResShift: 15 timesteps
min_noise_level = 0.04  # Original ResShift setting
eta_1 = min_noise_level  # Alias for backward compatibility
eta_T = etas_end  # Alias for backward compatibility
p = schedule_power  # Alias for backward compatibility
kappa = 2.0
k = kappa  # Alias for backward compatibility
weighted_mse = False
predict_type = "xstart"  # Predict x0, not noise (key difference!)
timestep_respacing = None
scale_factor = 1.0
normalize_input = True
latent_flag = True  # Working in latent space

# ============================================================================
# Model Architecture Parameters
# ============================================================================
# ResShift model architecture based on model_channels and channel_mult
# Initial Conv: 3 → 160
# Encoder Stage 1: 160 → 320 (downsample to 128x128)
# Encoder Stage 2: 320 → 320 (downsample to 64x64)
# Encoder Stage 3: 320 → 640 (downsample to 32x32)
# Encoder Stage 4: 640 (no downsampling, stays 32x32)
# Decoder Stage 1: 640 → 320 (upsample to 64x64)
# Decoder Stage 2: 320 → 320 (upsample to 128x128)
# Decoder Stage 3: 320 → 160 (upsample to 256x256)
# Decoder Stage 4: 160 → 3 (final output)
# NOTE(review): the spatial sizes above look like they assume a 256x256
# input, while image_size below is 64 -- confirm which one is intended.

# Model params from ResShift configuration
image_size = 64  # Latent space: 64×64 (not 256×256 pixel space)
in_channels = 3
model_channels = 160  # Original ResShift: base channels
out_channels = 3
attention_resolutions = [64, 32, 16, 8]  # Latent space resolutions
dropout = 0
channel_mult = [1, 2, 2, 4]  # Original ResShift: 160, 320, 320, 640 channels
num_res_blocks = [2, 2, 2, 2]
conv_resample = True
dims = 2
use_fp16 = False
num_head_channels = 32
use_scale_shift_norm = True
resblock_updown = False
swin_depth = 2
swin_embed_dim = 192  # Original ResShift setting
window_size = 8  # Original ResShift setting (not 7)
mlp_ratio = 2.0  # Original ResShift uses 2.0, not 4
cond_lq = True  # Enable LR conditioning
lq_size = 64  # LR latent size (same as image_size)

# U-Net architecture parameters based on ResShift configuration
# Initial conv: 3 → model_channels * channel_mult[0] = 160
initial_conv_out_channels = model_channels * channel_mult[0]  # 160

# Encoder stage channels (based on channel_mult progression)
es1_in_channels = initial_conv_out_channels  # 160
es1_out_channels = model_channels * channel_mult[1]  # 320
es2_in_channels = es1_out_channels  # 320
es2_out_channels = model_channels * channel_mult[2]  # 320
es3_in_channels = es2_out_channels  # 320
es3_out_channels = model_channels * channel_mult[3]  # 640
es4_in_channels = es3_out_channels  # 640
es4_out_channels = es3_out_channels  # 640 (no downsampling)

# Decoder stage channels (reverse of encoder)
ds1_in_channels = es4_out_channels  # 640
ds1_out_channels = es2_out_channels  # 320
ds2_in_channels = ds1_out_channels  # 320
ds2_out_channels = es2_out_channels  # 320
ds3_in_channels = ds2_out_channels  # 320
# Decoder stage 3 mirrors encoder stage 1 (160 → 320), so it emits that
# stage's *input* width. This used to read es1_out_channels, which is 320
# and contradicted the documented 320 → 160 layout. Checkpoints saved with
# the old 320-wide stage may not load after this correction.
ds3_out_channels = es1_in_channels  # 160
ds4_in_channels = ds3_out_channels  # 160
ds4_out_channels = initial_conv_out_channels  # 160

# Other model parameters
n_groupnorm_groups = 8  # Standard value
# Shift size for shifted window attention
# (should be window_size // 2, not swin_depth)
shift_size = window_size // 2
timestep_embed_dim = model_channels * 4  # Original ResShift: 160 * 4 = 640
# Note: config has num_head_channels, but we need num_heads.
# NOTE(review): this reuses a channels-per-head value (32) as a head count;
# confirm that is what the attention layers expect.
num_heads = num_head_channels

# ============================================================================
# Autoencoder Parameters (from YAML, for reference)
# ============================================================================
autoencoder_ckpt_path = "pretrained_weights/autoencoder_vq_f4.pth"
# Temporarily disabled for CPU testing (FP16 is slow/hangs on CPU)
autoencoder_use_fp16 = False
autoencoder_embed_dim = 3
autoencoder_n_embed = 8192
autoencoder_double_z = False
autoencoder_z_channels = 3
autoencoder_resolution = 256
autoencoder_in_channels = 3
autoencoder_out_ch = 3
autoencoder_ch = 128
autoencoder_ch_mult = [1, 2, 4]
autoencoder_num_res_blocks = 2
autoencoder_attn_resolutions = []
autoencoder_dropout = 0.0
autoencoder_padding_mode = "zeros"

# ============================================================================
# Degradation Parameters (used by realesrgan.py)
# ============================================================================
# Blur kernel settings (used for both first and second degradation)
blur_kernel_size = 21
kernel_list = [
    "iso",
    "aniso",
    "generalized_iso",
    "generalized_aniso",
    "plateau_iso",
    "plateau_aniso",
]
# One probability per kernel_list entry, in the same order.
kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]

# First degradation stage
resize_prob = [0.2, 0.7, 0.1]  # up, down, keep
resize_range = [0.15, 1.5]
gaussian_noise_prob = 0.5
noise_range = [1, 30]
poisson_scale_range = [0.05, 3.0]
gray_noise_prob = 0.4
jpeg_range = [30, 95]
data_train_blur_sigma = [0.2, 3.0]
data_train_betag_range = [0.5, 4.0]
data_train_betap_range = [1, 2.0]
data_train_sinc_prob = 0.1

# Second degradation stage
second_order_prob = 0.5
second_blur_prob = 0.8
resize_prob2 = [0.3, 0.4, 0.3]  # up, down, keep
resize_range2 = [0.3, 1.2]
gaussian_noise_prob2 = 0.5
noise_range2 = [1, 25]
poisson_scale_range2 = [0.05, 2.5]
gray_noise_prob2 = 0.4
jpeg_range2 = [30, 95]
data_train_blur_kernel_size2 = 15
data_train_blur_sigma2 = [0.2, 1.5]
data_train_betag_range2 = [0.5, 4.0]
data_train_betap_range2 = [1, 2.0]
data_train_sinc_prob2 = 0.1

# Final sinc filter
data_train_final_sinc_prob = 0.8
final_sinc_prob = data_train_final_sinc_prob  # Alias for backward compatibility

# Other degradation settings
gt_size = 256
resize_back = False
use_sharp = False

# ============================================================================
# Data Parameters
# ============================================================================
# Data paths - using defaults based on project structure
dir_HR = str(_project_root / "data" / "DIV2K_train_HR")
dir_LR = str(_project_root / "data" / "DIV2K_train_LR_bicubic" / "X4")
dir_valid_HR = str(_project_root / "data" / "DIV2K_valid_HR")
dir_valid_LR = str(_project_root / "data" / "DIV2K_valid_LR_bicubic" / "X4")

# Patch size (used by dataset)
patch_size = gt_size  # 256

# Scale factor (from degradation.sf)
scale = sf  # 4