"""
Configuration file with all training, model, and data parameters.
"""
import os
import torch
from pathlib import Path

# ============================================================================
# Project Settings
# ============================================================================
_project_root = Path(__file__).parent.parent

# ============================================================================
# Device Settings
# ============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================================
# Training Parameters
# ============================================================================
# Learning rate
lr = 1e-5  # Original ResShift setting
lr_min = 1e-5
lr_schedule = None
learning_rate = lr  # Alias for backward compatibility
warmup_iterations = 100  # Linear warmup from 0 to the base lr over the first 100 iterations (~3% of the 3200 total)
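
# Illustrative sketch (hypothetical helper, not wired into the trainer): the
# warmup rule described above, as a function of the iteration counter.
def _example_warmup_lr(iteration: int) -> float:
    """Linear warmup from 0 to the base lr, then constant."""
    if iteration < warmup_iterations:
        return lr * (iteration + 1) / warmup_iterations
    return lr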

# Dataloader
batch = [64, 64]  # Batch size pair from the original ResShift config ([train, val]); adjust for GPU memory
batch_size = batch[0]  # Training batch size (first entry)
microbatch = 100  # Gradient-accumulation chunk size (ResShift-style); values >= batch_size mean the whole batch runs in one chunk
num_workers = 4
prefetch_factor = 2

# Optimization settings
weight_decay = 0
ema_rate = 0.999
iterations = 3200  # 256 epochs for DIV2K (800 images / 64 batch_size = 12.5 iterations per epoch)
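
# Illustrative sketch of the standard EMA update implied by ema_rate (the
# trainer's own implementation is authoritative):
#   for p_ema, p in zip(ema_model.parameters(), model.parameters()):
#       p_ema.data.mul_(ema_rate).add_(p.data, alpha=1 - ema_rate)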

# Save logging
save_freq = 200
log_freq = [50, 100]  # [training loss, training images]
local_logging = True
tf_logging = False

# Validation settings
use_ema_val = True
val_freq = 100  # Run validation every 100 iterations
val_y_channel = True
val_resolution = 64  # model.params.lq_size
val_padding_mode = "reflect"

# Training setting
use_amp = True  # Mixed precision training
seed = 123456
global_seeding = False

# Model compile
compile_flag = True
compile_mode = "reduce-overhead"
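
# Illustrative: how these flags are typically consumed (PyTorch 2.x torch.compile):
#   if compile_flag:
#       model = torch.compile(model, mode=compile_mode)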

# ============================================================================
# Diffusion/Noise Schedule Parameters
# ============================================================================
sf = 4
schedule_name = "exponential"
schedule_power = 0.3  # Original ResShift setting
etas_end = 0.99  # Original ResShift setting
T = 15  # Original ResShift: 15 timesteps
min_noise_level = 0.04  # Original ResShift setting
eta_1 = min_noise_level  # Alias for backward compatibility
eta_T = etas_end  # Alias for backward compatibility
p = schedule_power  # Alias for backward compatibility
kappa = 2.0
k = kappa  # Alias for backward compatibility
weighted_mse = False
predict_type = "xstart"  # Predict x0 (the clean latent) rather than noise, unlike epsilon-prediction DDPMs
timestep_respacing = None
scale_factor = 1.0
normalize_input = True
latent_flag = True  # Working in latent space
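
# Illustrative sketch (hypothetical helper, not used by the trainer): one common
# way to realize an "exponential" schedule from these parameters, a geometric
# interpolation from eta_1 to eta_T over a timestep grid warped by schedule_power.
# The project's diffusion module is authoritative. For context, ResShift's
# forward step is q(x_t | x_0, y) = N(x_0 + eta_t * (y - x_0), kappa^2 * eta_t * I).
def _example_eta_schedule(num_steps: int = T) -> torch.Tensor:
    t = torch.linspace(0.0, 1.0, num_steps) ** schedule_power  # warped timestep grid
    return eta_1 * (eta_T / eta_1) ** t  # eta_1 at t=0, rising to eta_T at t=1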

# ============================================================================
# Model Architecture Parameters
# ============================================================================
# ResShift model architecture based on model_channels and channel_mult.
# Spatial sizes below are in the 64x64 latent space (latent_flag = True):
# Initial Conv: 3 → 160
# Encoder Stage 1: 160 → 320 (downsample to 32x32)
# Encoder Stage 2: 320 → 320 (downsample to 16x16)
# Encoder Stage 3: 320 → 640 (downsample to 8x8)
# Encoder Stage 4: 640 (no downsampling, stays 8x8)
# Decoder Stage 1: 640 → 320 (upsample to 16x16)
# Decoder Stage 2: 320 → 320 (upsample to 32x32)
# Decoder Stage 3: 320 → 160 (upsample to 64x64)
# Decoder Stage 4: 160 → 3 (final output at 64x64)

# Model params from ResShift configuration
image_size = 64  # Latent space: 64×64 (not 256×256 pixel space)
in_channels = 3
model_channels = 160  # Original ResShift: base channels
out_channels = 3
attention_resolutions = [64, 32, 16, 8]  # Latent space resolutions
dropout = 0
channel_mult = [1, 2, 2, 4]  # Original ResShift: 160, 320, 320, 640 channels
num_res_blocks = [2, 2, 2, 2]
conv_resample = True
dims = 2
use_fp16 = False
num_head_channels = 32
use_scale_shift_norm = True
resblock_updown = False
swin_depth = 2
swin_embed_dim = 192  # Original ResShift setting
window_size = 8  # Original ResShift setting (8 rather than the common Swin default of 7)
mlp_ratio = 2.0  # Original ResShift setting (2.0 rather than the common 4.0)
cond_lq = True  # Enable LR conditioning
lq_size = 64  # LR latent size (same as image_size)

# U-Net architecture parameters based on ResShift configuration
# Initial conv: 3 β†’ model_channels * channel_mult[0] = 160
initial_conv_out_channels = model_channels * channel_mult[0]  # 160

# Encoder stage channels (based on channel_mult progression)
es1_in_channels = initial_conv_out_channels  # 160
es1_out_channels = model_channels * channel_mult[1]  # 320
es2_in_channels = es1_out_channels  # 320
es2_out_channels = model_channels * channel_mult[2]  # 320
es3_in_channels = es2_out_channels  # 320
es3_out_channels = model_channels * channel_mult[3]  # 640
es4_in_channels = es3_out_channels  # 640
es4_out_channels = es3_out_channels  # 640 (no downsampling)

# Decoder stage channels (reverse of encoder)
ds1_in_channels = es4_out_channels  # 640
ds1_out_channels = es2_out_channels  # 320
ds2_in_channels = ds1_out_channels  # 320
ds2_out_channels = es2_out_channels  # 320
ds3_in_channels = ds2_out_channels  # 320
ds3_out_channels = es1_out_channels  # 160
ds4_in_channels = ds3_out_channels  # 160
ds4_out_channels = initial_conv_out_channels  # 160

# Other model parameters
n_groupnorm_groups = 8  # Standard value
shift_size = window_size // 2  # Shift size for shifted-window attention
timestep_embed_dim = model_channels * 4  # Original ResShift: 160 * 4 = 640
num_heads = num_head_channels  # Caution: num_head_channels (32) is channels *per head*; ADM-style UNets derive the per-block head count as channels // num_head_channels
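
# Worked example of the derived-heads convention (illustrative):
#   160-channel blocks -> 160 // 32 = 5 heads
#   320-channel blocks -> 320 // 32 = 10 heads
#   640-channel blocks -> 640 // 32 = 20 heads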

# ============================================================================
# Autoencoder Parameters (from YAML, for reference)
# ============================================================================
autoencoder_ckpt_path = "pretrained_weights/autoencoder_vq_f4.pth"
autoencoder_use_fp16 = False  # Temporarily disabled for CPU testing (FP16 is slow and can hang on CPU)
autoencoder_embed_dim = 3
autoencoder_n_embed = 8192
autoencoder_double_z = False
autoencoder_z_channels = 3
autoencoder_resolution = 256
autoencoder_in_channels = 3
autoencoder_out_ch = 3
autoencoder_ch = 128
autoencoder_ch_mult = [1, 2, 4]
autoencoder_num_res_blocks = 2
autoencoder_attn_resolutions = []
autoencoder_dropout = 0.0
autoencoder_padding_mode = "zeros"
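
# Illustrative mapping (hypothetical name `autoencoder_ddconfig`): these values
# correspond to the ddconfig block used by latent-diffusion-style VQ autoencoders.
autoencoder_ddconfig = dict(
    double_z=autoencoder_double_z,
    z_channels=autoencoder_z_channels,
    resolution=autoencoder_resolution,
    in_channels=autoencoder_in_channels,
    out_ch=autoencoder_out_ch,
    ch=autoencoder_ch,
    ch_mult=autoencoder_ch_mult,
    num_res_blocks=autoencoder_num_res_blocks,
    attn_resolutions=autoencoder_attn_resolutions,
    dropout=autoencoder_dropout,
)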

# ============================================================================
# Degradation Parameters (used by realesrgan.py)
# ============================================================================
# Blur kernel settings (used for both first and second degradation)
blur_kernel_size = 21
kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]

# First degradation stage
resize_prob = [0.2, 0.7, 0.1]  # up, down, keep
resize_range = [0.15, 1.5]
gaussian_noise_prob = 0.5
noise_range = [1, 30]
poisson_scale_range = [0.05, 3.0]
gray_noise_prob = 0.4
jpeg_range = [30, 95]
data_train_blur_sigma = [0.2, 3.0]
data_train_betag_range = [0.5, 4.0]
data_train_betap_range = [1.0, 2.0]
data_train_sinc_prob = 0.1

# Second degradation stage
second_order_prob = 0.5
second_blur_prob = 0.8
resize_prob2 = [0.3, 0.4, 0.3]  # up, down, keep
resize_range2 = [0.3, 1.2]
gaussian_noise_prob2 = 0.5
noise_range2 = [1, 25]
poisson_scale_range2 = [0.05, 2.5]
gray_noise_prob2 = 0.4
jpeg_range2 = [30, 95]
data_train_blur_kernel_size2 = 15
data_train_blur_sigma2 = [0.2, 1.5]
data_train_betag_range2 = [0.5, 4.0]
data_train_betap_range2 = [1.0, 2.0]
data_train_sinc_prob2 = 0.1

# Final sinc filter
data_train_final_sinc_prob = 0.8
final_sinc_prob = data_train_final_sinc_prob  # Alias for backward compatibility

# Other degradation settings
gt_size = 256
resize_back = False
use_sharp = False
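
# Illustrative outline of the Real-ESRGAN-style pipeline these settings feed
# (realesrgan.py is authoritative):
#   Stage 1: blur (kernel_list/kernel_prob, sinc w.p. data_train_sinc_prob)
#            -> random resize (resize_prob, resize_range)
#            -> Gaussian/Poisson noise (gaussian_noise_prob, noise_range,
#               poisson_scale_range, gray_noise_prob)
#            -> JPEG compression (jpeg_range)
#   Stage 2 (applied w.p. second_order_prob): the same steps with the *2
#            parameters; blur applied w.p. second_blur_prob
#   Finish:  final sinc filter w.p. final_sinc_prob, crop to gt_size, and
#            downscale by sf to produce the LR input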

# ============================================================================
# Data Parameters
# ============================================================================
# Data paths - using defaults based on project structure
dir_HR = str(_project_root / "data" / "DIV2K_train_HR")
dir_LR = str(_project_root / "data" / "DIV2K_train_LR_bicubic" / "X4")
dir_valid_HR = str(_project_root / "data" / "DIV2K_valid_HR")
dir_valid_LR = str(_project_root / "data" / "DIV2K_valid_LR_bicubic" / "X4")
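
# Optional sanity check (illustrative): warn early if the DIV2K folders are missing.
for _d in (dir_HR, dir_LR, dir_valid_HR, dir_valid_LR):
    if not os.path.isdir(_d):
        print(f"[config] warning: dataset directory not found: {_d}")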

# Patch size (used by dataset)
patch_size = gt_size  # 256

# Scale factor (from degradation.sf)
scale = sf  # 4