# Source snapshot metadata (viewer chrome removed): commit 3c45764, 8,434 bytes, 228 lines.
"""
Configuration file with all training, model, and data parameters.
"""
import os
import torch
from pathlib import Path
# ============================================================================
# Project Settings
# ============================================================================
_project_root = Path(__file__).parent.parent
# ============================================================================
# Device Settings
# ============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============================================================================
# Training Parameters
# ============================================================================
# Learning rate
lr = 1e-5 # Original ResShift setting
lr_min = 1e-5  # lower bound if a decaying schedule is used (equal to lr here)
lr_schedule = None  # None -> constant learning rate after warmup
learning_rate = lr # Alias for backward compatibility
warmup_iterations = 100 # linear warmup from 0 to base lr; NOTE(review): earlier comment said "~12.5% of 800 iterations", but `iterations` below is 3200 (~3.1%) -- confirm intended total
# Dataloader
batch = [64, 64] # Original ResShift: adjust based on your GPU memory
batch_size = batch[0] # Use first value from batch list
microbatch = 100  # NOTE(review): larger than batch_size (64) -- confirm the trainer clamps or ignores this
num_workers = 4
prefetch_factor = 2
# Optimization settings
weight_decay = 0  # no L2 regularization
ema_rate = 0.999  # decay for the exponential moving average of model weights
iterations = 3200 # total training iterations; NOTE(review): earlier comment said "64 epochs", but 3200 / 12.5 batches-per-epoch = 256 DIV2K epochs -- confirm
# Save logging
save_freq = 200  # checkpoint every 200 iterations
log_freq = [50, 100] # [training loss, training images]
local_logging = True
tf_logging = False  # TensorBoard logging disabled
# Validation settings
use_ema_val = True  # presumably validates with the EMA weights -- verify in trainer
val_freq = 100 # Run validation every 100 iterations
val_y_channel = True  # presumably metrics on the Y (luma) channel only -- verify in metric code
val_resolution = 64 # model.params.lq_size
val_padding_mode = "reflect"
# Training setting
use_amp = True # Mixed precision training
seed = 123456
global_seeding = False
# Model compile
compile_flag = True  # enable torch.compile on the model
compile_mode = "reduce-overhead"
# ============================================================================
# Diffusion/Noise Schedule Parameters
# ============================================================================
sf = 4  # super-resolution scale factor (x4)
schedule_name = "exponential"  # shape of the eta noise schedule
schedule_power = 0.3 # Original ResShift setting
etas_end = 0.99 # Original ResShift setting
T = 15 # Original ResShift: 15 timesteps
min_noise_level = 0.04 # Original ResShift setting
eta_1 = min_noise_level # Alias for backward compatibility
eta_T = etas_end # Alias for backward compatibility
p = schedule_power # Alias for backward compatibility
kappa = 2.0  # noise-magnitude hyperparameter of the schedule (kappa in the ResShift paper)
k = kappa # Alias for backward compatibility
weighted_mse = False  # plain (unweighted) MSE loss
predict_type = "xstart" # Predict x0, not noise (key difference!)
timestep_respacing = None  # None -> use all T steps at sampling time
scale_factor = 1.0
normalize_input = True
latent_flag = True # Working in latent space
# ============================================================================
# Model Architecture Parameters
# ============================================================================
# ResShift model architecture based on model_channels and channel_mult
# Initial Conv: 3 -> 160
# Encoder Stage 1: 160 -> 320 (downsample to 128x128)
# Encoder Stage 2: 320 -> 320 (downsample to 64x64)
# Encoder Stage 3: 320 -> 640 (downsample to 32x32)
# Encoder Stage 4: 640 (no downsampling, stays 32x32)
# Decoder Stage 1: 640 -> 320 (upsample to 64x64)
# Decoder Stage 2: 320 -> 320 (upsample to 128x128)
# Decoder Stage 3: 320 -> 160 (upsample to 256x256)
# Decoder Stage 4: 160 -> 3 (final output)
# NOTE(review): the spatial sizes sketched above look like 256x256 pixel space,
# while image_size/lq_size below say the U-Net runs on 64x64 latents -- confirm.
# Model params from ResShift configuration
image_size = 64 # Latent space: 64x64 (not 256x256 pixel space)
in_channels = 3
model_channels = 160 # Original ResShift: base channels
out_channels = 3
attention_resolutions = [64, 32, 16, 8] # Latent space resolutions
dropout = 0
channel_mult = [1, 2, 2, 4] # Original ResShift: 160, 320, 320, 640 channels
num_res_blocks = [2, 2, 2, 2]  # residual blocks per resolution level
conv_resample = True  # learned conv for up/downsampling rather than pooling
dims = 2  # 2-D convolutions
use_fp16 = False
num_head_channels = 32  # channels per attention head
use_scale_shift_norm = True
resblock_updown = False
swin_depth = 2
swin_embed_dim = 192 # Original ResShift setting
window_size = 8 # Original ResShift setting (not 7)
mlp_ratio = 2.0 # Original ResShift uses 2.0, not 4
cond_lq = True # Enable LR conditioning
lq_size = 64 # LR latent size (same as image_size)
# U-Net architecture parameters based on ResShift configuration
# Initial conv: 3 -> model_channels * channel_mult[0] = 160
initial_conv_out_channels = model_channels * channel_mult[0] # 160
# Encoder stage channels (based on channel_mult progression)
es1_in_channels = initial_conv_out_channels # 160
es1_out_channels = model_channels * channel_mult[1] # 320
es2_in_channels = es1_out_channels # 320
es2_out_channels = model_channels * channel_mult[2] # 320
es3_in_channels = es2_out_channels # 320
es3_out_channels = model_channels * channel_mult[3] # 640
es4_in_channels = es3_out_channels # 640
es4_out_channels = es3_out_channels # 640 (no downsampling)
# Decoder stage channels (reverse of encoder)
ds1_in_channels = es4_out_channels # 640
ds1_out_channels = es2_out_channels # 320
ds2_in_channels = ds1_out_channels # 320
ds2_out_channels = es2_out_channels # 320
ds3_in_channels = ds2_out_channels # 320
ds3_out_channels = es1_out_channels # 160
ds4_in_channels = ds3_out_channels # 160
ds4_out_channels = initial_conv_out_channels # 160
# Other model parameters
n_groupnorm_groups = 8 # Standard value
shift_size = window_size // 2 # Shift size for shifted window attention (should be window_size // 2, not swin_depth)
timestep_embed_dim = model_channels * 4 # Original ResShift: 160 * 4 = 640
num_heads = num_head_channels # NOTE(review): this sets the head COUNT to the per-head channel size (32) -- confirm the model expects a count here, not channels-per-head
# ============================================================================
# Autoencoder Parameters (from YAML, for reference)
# ============================================================================
# Presumably a VQ-f4 autoencoder mapping 256x256 pixels <-> 64x64x3 latents
# (checkpoint name and resolution/ch_mult suggest f4) -- verify against loader.
autoencoder_ckpt_path = "pretrained_weights/autoencoder_vq_f4.pth"
autoencoder_use_fp16 = False # Temporarily disabled for CPU testing (FP16 is slow/hangs on CPU)
autoencoder_embed_dim = 3  # latent channels
autoencoder_n_embed = 8192  # VQ codebook size
autoencoder_double_z = False
autoencoder_z_channels = 3
autoencoder_resolution = 256  # pixel-space input resolution
autoencoder_in_channels = 3
autoencoder_out_ch = 3
autoencoder_ch = 128  # base channel width
autoencoder_ch_mult = [1, 2, 4]  # three levels -> 4x spatial downsampling
autoencoder_num_res_blocks = 2
autoencoder_attn_resolutions = []  # no attention layers in the autoencoder
autoencoder_dropout = 0.0
autoencoder_padding_mode = "zeros"
# ============================================================================
# Degradation Parameters (used by realesrgan.py)
# ============================================================================
# Presumably a two-stage Real-ESRGAN-style synthetic degradation pipeline
# (blur / resize / noise / JPEG, applied twice, plus a final sinc filter) --
# verify the exact order against realesrgan.py.
# Blur kernel settings (used for both first and second degradation)
blur_kernel_size = 21
kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]  # sampling probability per kernel type; aligns 1:1 with kernel_list
# First degradation stage
resize_prob = [0.2, 0.7, 0.1] # up, down, keep
resize_range = [0.15, 1.5]
gaussian_noise_prob = 0.5  # otherwise Poisson noise is used
noise_range = [1, 30]  # Gaussian sigma range
poisson_scale_range = [0.05, 3.0]
gray_noise_prob = 0.4
jpeg_range = [30, 95]  # JPEG quality range
data_train_blur_sigma = [0.2, 3.0]
data_train_betag_range = [0.5, 4.0]  # generalized-Gaussian kernel beta range
data_train_betap_range = [1, 2.0]  # plateau kernel beta range
data_train_sinc_prob = 0.1
# Second degradation stage
second_order_prob = 0.5  # probability of applying the second stage at all
second_blur_prob = 0.8
resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
resize_range2 = [0.3, 1.2]
gaussian_noise_prob2 = 0.5
noise_range2 = [1, 25]
poisson_scale_range2 = [0.05, 2.5]
gray_noise_prob2 = 0.4
jpeg_range2 = [30, 95]
data_train_blur_kernel_size2 = 15
data_train_blur_sigma2 = [0.2, 1.5]
data_train_betag_range2 = [0.5, 4.0]
data_train_betap_range2 = [1, 2.0]
data_train_sinc_prob2 = 0.1
# Final sinc filter
data_train_final_sinc_prob = 0.8
final_sinc_prob = data_train_final_sinc_prob # Alias for backward compatibility
# Other degradation settings
gt_size = 256  # ground-truth (HR) crop size in pixels
resize_back = False
use_sharp = False
# ============================================================================
# Data Parameters
# ============================================================================
# Data paths - using defaults based on project structure
dir_HR = str(_project_root / "data" / "DIV2K_train_HR")  # training HR images
dir_LR = str(_project_root / "data" / "DIV2K_train_LR_bicubic" / "X4")  # training LR (bicubic x4)
dir_valid_HR = str(_project_root / "data" / "DIV2K_valid_HR")  # validation HR images
dir_valid_LR = str(_project_root / "data" / "DIV2K_valid_LR_bicubic" / "X4")  # validation LR (bicubic x4)
# Patch size (used by dataset)
patch_size = gt_size # 256
# Scale factor (from degradation.sf)
scale = sf # 4