# DiffusionSR/src/data.py
import torch
import torch.nn as nn
import os
import random
import math
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from config import (
patch_size, scale, dir_HR, dir_LR, dir_valid_HR, dir_valid_LR,
_project_root, device, gt_size
)
from realesrgan import RealESRGANDegrader
from autoencoder import get_vqgan
# Initialize degradation pipeline and VQGAN (lazy loading)
_degrader = None
_vqgan = None
def get_degrader():
"""Get or create degradation pipeline."""
global _degrader
if _degrader is None:
_degrader = RealESRGANDegrader(scale=scale)
return _degrader
def get_vqgan_model():
"""Get or create VQGAN model."""
global _vqgan
if _vqgan is None:
_vqgan = get_vqgan(device=device)
return _vqgan
class SRDatasetOnTheFly(torch.utils.data.Dataset):
"""
PyTorch Dataset for on-the-fly degradation and VQGAN encoding.
This dataset:
1. Loads full HR images
2. Crops 256x256 patches on-the-fly
3. Applies RealESRGAN degradation to generate LR
    4. Upsamples the LR patch back to 256x256 with bicubic interpolation
5. Encodes both HR and LR through VQGAN to get 64x64 latents
Args:
dir_HR (str): Directory path containing high-resolution images.
scale (int, optional): Super-resolution scale factor. Defaults to config.scale (4).
patch_size (int, optional): Size of patches. Defaults to config.patch_size (256).
max_samples (int, optional): Maximum number of images to load. If None, loads all.
    Returns (per __getitem__ call):
        tuple: (hr_latent, lr_latent), each a torch.Tensor of shape (C, 64, 64)
            holding a VQGAN-encoded latent. See the example DataLoader helper
            defined after this class for typical usage.
    """
def __init__(self, dir_HR, scale=scale, patch_size=patch_size, max_samples=None):
super().__init__()
self.dir_HR = dir_HR
self.scale = scale
self.patch_size = patch_size
# Get all image files
self.filenames = sorted([
f for f in os.listdir(self.dir_HR)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
])
# Limit to max_samples if specified
if max_samples is not None:
self.filenames = self.filenames[:max_samples]
# Initialize degradation and VQGAN (will be loaded on first use)
self.degrader = None
self.vqgan = None
def __len__(self):
return len(self.filenames)
def _load_image(self, img_path):
"""Load and validate image."""
img = Image.open(img_path).convert("RGB")
img_tensor = TF.to_tensor(img) # (C, H, W) in range [0, 1]
return img_tensor
def _crop_patch(self, img_tensor, patch_size):
"""
Crop a random patch from image.
Args:
img_tensor: (C, H, W) tensor
patch_size: Size of patch to crop
Returns:
patch: (C, patch_size, patch_size) tensor
"""
C, H, W = img_tensor.shape
        # Pad if the image is smaller than patch_size. Reflection padding requires
        # the pad to be smaller than the corresponding dimension, so fall back to
        # replicate padding for very small images instead of crashing.
        if H < patch_size or W < patch_size:
            pad_h = max(0, patch_size - H)
            pad_w = max(0, patch_size - W)
            pad_mode = 'reflect' if (pad_h < H and pad_w < W) else 'replicate'
            img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode=pad_mode)
            H, W = img_tensor.shape[1], img_tensor.shape[2]
# Random crop
top = random.randint(0, max(0, H - patch_size))
left = random.randint(0, max(0, W - patch_size))
patch = img_tensor[:, top:top+patch_size, left:left+patch_size]
return patch
def _apply_augmentations(self, hr, lr):
"""
Apply synchronized augmentations to HR and LR.
Args:
hr: (C, H, W) HR tensor
lr: (C, H, W) LR tensor
Returns:
hr_aug, lr_aug: Augmented tensors
"""
# Horizontal flip
if random.random() < 0.5:
hr = torch.flip(hr, dims=[2])
lr = torch.flip(lr, dims=[2])
# Vertical flip
if random.random() < 0.5:
hr = torch.flip(hr, dims=[1])
lr = torch.flip(lr, dims=[1])
# 180° rotation
if random.random() < 0.5:
hr = torch.rot90(hr, k=2, dims=[1, 2])
lr = torch.rot90(lr, k=2, dims=[1, 2])
return hr, lr
def __getitem__(self, idx):
# Load HR image
hr_path = os.path.join(self.dir_HR, self.filenames[idx])
hr_full = self._load_image(hr_path) # (C, H, W) in [0, 1]
# Crop 256x256 patch from HR
hr_patch = self._crop_patch(hr_full, self.patch_size) # (C, 256, 256)
# Initialize degrader and VQGAN on first use
if self.degrader is None:
self.degrader = get_degrader()
if self.vqgan is None:
self.vqgan = get_vqgan_model()
# Apply degradation on-the-fly to generate LR
# Degrader expects (C, H, W) and returns (C, H//scale, W//scale)
hr_patch_gpu = hr_patch.to(device) # (C, 256, 256)
with torch.no_grad():
lr_patch = self.degrader.degrade(hr_patch_gpu) # (C, 64, 64) in pixel space
# Upsample LR to 256x256 using bicubic interpolation
lr_patch_upsampled = F.interpolate(
lr_patch.unsqueeze(0), # (1, C, 64, 64)
size=(self.patch_size, self.patch_size),
mode='bicubic',
align_corners=False
).squeeze(0) # (C, 256, 256)
# Apply augmentations (synchronized)
hr_patch, lr_patch_upsampled = self._apply_augmentations(
hr_patch.cpu(),
lr_patch_upsampled.cpu()
)
# Encode through VQGAN to get latents (64x64)
# Move to device for encoding
hr_patch_gpu = hr_patch.to(device).unsqueeze(0) # (1, C, 256, 256)
lr_patch_gpu = lr_patch_upsampled.to(device).unsqueeze(0) # (1, C, 256, 256)
with torch.no_grad():
# Encode HR: 256x256 -> 64x64 latent
hr_latent = self.vqgan.encode(hr_patch_gpu) # (1, C, 64, 64)
# Encode LR: 256x256 -> 64x64 latent
lr_latent = self.vqgan.encode(lr_patch_gpu) # (1, C, 64, 64)
# Remove batch dimension and move to CPU
hr_latent = hr_latent.squeeze(0).cpu() # (C, 64, 64)
lr_latent = lr_latent.squeeze(0).cpu() # (C, 64, 64)
return hr_latent, lr_latent
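# Illustrative helper (the name and defaults are assumptions, not part of the
# training code): one way to wrap these datasets in a DataLoader. num_workers is
# kept at 0 because __getitem__ runs CUDA work (degradation and VQGAN encoding),
# which does not fork safely into DataLoader worker processes.
def make_loader(dataset, batch_size=4, shuffle=True):
    """Wrap a dataset so it yields batches of (hr_latent, lr_latent) latents."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,    # GPU work happens inside __getitem__; keep loading in the main process
        pin_memory=False  # latents are already moved back to CPU before being returned
    )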
# Create datasets using on-the-fly processing
train_dataset = SRDatasetOnTheFly(
dir_HR=dir_HR,
scale=scale,
patch_size=patch_size
)
valid_dataset = SRDatasetOnTheFly(
dir_HR=dir_valid_HR,
scale=scale,
patch_size=patch_size
)
# Mini dataset with 8 images for testing
mini_dataset = SRDatasetOnTheFly(
dir_HR=dir_HR,
scale=scale,
patch_size=patch_size,
max_samples=8
)
print(f"\nFull training dataset size: {len(train_dataset)}")
print(f"Full validation dataset size: {len(valid_dataset)}")
print(f"Mini dataset size: {len(mini_dataset)}")
print(f"Using on-the-fly degradation and VQGAN encoding")
print(f"Output: 64x64 latents (from 256x256 patches)")