Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import os | |
| import random | |
| import math | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
| import torch.nn.functional as F | |
| from config import ( | |
| patch_size, scale, dir_HR, dir_LR, dir_valid_HR, dir_valid_LR, | |
| _project_root, device, gt_size | |
| ) | |
| from realesrgan import RealESRGANDegrader | |
| from autoencoder import get_vqgan | |
# Process-wide singletons for the degradation pipeline and the VQGAN.
# Both start out empty and are filled in by get_degrader() /
# get_vqgan_model() the first time they are requested (lazy loading).
_degrader = _vqgan = None
def get_degrader():
    """Return the shared degradation pipeline, building it on first call.

    The instance is cached in the module-level ``_degrader`` and is
    constructed as ``RealESRGANDegrader(scale=scale)`` using the ``scale``
    imported from ``config``.
    """
    global _degrader
    if _degrader is not None:
        return _degrader
    _degrader = RealESRGANDegrader(scale=scale)
    return _degrader
def get_vqgan_model():
    """Return the shared VQGAN model, building it on first call.

    The instance is cached in the module-level ``_vqgan`` and is obtained
    from ``get_vqgan(device=device)`` using the ``device`` imported from
    ``config``.
    """
    global _vqgan
    if _vqgan is not None:
        return _vqgan
    _vqgan = get_vqgan(device=device)
    return _vqgan
class SRDatasetOnTheFly(torch.utils.data.Dataset):
    """
    PyTorch Dataset for on-the-fly degradation and VQGAN encoding.

    For every requested item the dataset:
    1. Loads the full HR image from ``dir_HR``.
    2. Crops a random ``patch_size`` x ``patch_size`` patch (padding images
       that are smaller than the patch).
    3. Runs the shared degrader on the patch to synthesize an LR patch.
    4. Upsamples the LR patch back to ``patch_size`` with bicubic
       interpolation.
    5. Applies the same random flips / rotation to both patches.
    6. Encodes both patches with the shared VQGAN.

    Args:
        dir_HR (str): Directory path containing high-resolution images.
            Only ``.png``, ``.jpg`` and ``.jpeg`` files (case-insensitive)
            are used.
        scale (int, optional): Super-resolution scale factor. Defaults to
            ``config.scale``. NOTE(review): the value is stored on the
            instance, but the shared degrader returned by
            ``get_degrader()`` is built from the module-level ``scale``,
            not from this argument.
        patch_size (int, optional): Side length of the square patches.
            Defaults to ``config.patch_size``.
        max_samples (int, optional): Maximum number of images to use
            (taken from the start of the sorted file list). If None, all
            images are used.

    Returns:
        tuple: ``(hr_latent, lr_latent)`` CPU tensors with the batch
        dimension removed. The original author documents both as
        ``(C, 64, 64)`` for 256x256 patches; the actual latent shape is
        determined by the VQGAN encoder -- TODO confirm.
    """

    def __init__(self, dir_HR, scale=scale, patch_size=patch_size, max_samples=None):
        super().__init__()
        self.dir_HR = dir_HR
        self.scale = scale
        self.patch_size = patch_size

        # Sorted so that indices (and the max_samples prefix) are stable
        # from run to run regardless of directory listing order.
        self.filenames = sorted(
            f for f in os.listdir(self.dir_HR)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        # Limit to max_samples if specified
        if max_samples is not None:
            self.filenames = self.filenames[:max_samples]

        # Degrader and VQGAN are fetched on first __getitem__ call, so
        # constructing the dataset does not build either of them.
        self.degrader = None
        self.vqgan = None

    def __len__(self):
        """Return the number of image files backing this dataset."""
        return len(self.filenames)

    def _load_image(self, img_path):
        """Load an image as an RGB float tensor.

        Args:
            img_path: Path of the image file to read.

        Returns:
            (C, H, W) tensor in range [0, 1].
        """
        # The context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            return TF.to_tensor(img.convert("RGB"))

    def _crop_patch(self, img_tensor, patch_size):
        """
        Crop a random square patch from an image, padding it first if needed.

        Args:
            img_tensor: (C, H, W) tensor
            patch_size: Size of patch to crop

        Returns:
            patch: (C, patch_size, patch_size) tensor
        """
        _, H, W = img_tensor.shape

        # Grow the image on the bottom/right until it covers the patch.
        # Reflect padding may add at most (dim - 1) pixels per call, so very
        # small images are padded in several rounds; a side of length 1 has
        # nothing to reflect and is replicated instead.
        while H < patch_size or W < patch_size:
            pad_h = max(0, patch_size - H)
            pad_w = max(0, patch_size - W)
            if H > 1 and W > 1:
                pad_h = min(pad_h, H - 1)
                pad_w = min(pad_w, W - 1)
                mode = 'reflect'
            else:
                mode = 'replicate'
            # Pad on a 4-D view: 2-D reflect/replicate padding of a batched
            # tensor is supported across PyTorch versions.
            img_tensor = F.pad(
                img_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode=mode
            ).squeeze(0)
            H, W = img_tensor.shape[1], img_tensor.shape[2]

        # Random crop (both bounds are inclusive for random.randint)
        top = random.randint(0, H - patch_size)
        left = random.randint(0, W - patch_size)
        return img_tensor[:, top:top + patch_size, left:left + patch_size]

    def _apply_augmentations(self, hr, lr):
        """
        Apply synchronized augmentations to HR and LR.

        Each of the three transforms is applied independently with
        probability 0.5, and always to both tensors together so they stay
        spatially aligned.

        Args:
            hr: (C, H, W) HR tensor
            lr: (C, H, W) LR tensor

        Returns:
            hr_aug, lr_aug: Augmented tensors
        """
        # Horizontal flip
        if random.random() < 0.5:
            hr = torch.flip(hr, dims=[2])
            lr = torch.flip(lr, dims=[2])

        # Vertical flip
        if random.random() < 0.5:
            hr = torch.flip(hr, dims=[1])
            lr = torch.flip(lr, dims=[1])

        # 180-degree rotation (equivalent to a horizontal plus a vertical flip)
        if random.random() < 0.5:
            hr = torch.rot90(hr, k=2, dims=[1, 2])
            lr = torch.rot90(lr, k=2, dims=[1, 2])

        return hr, lr

    def __getitem__(self, idx):
        """Build one (hr_latent, lr_latent) training pair for image ``idx``.

        Args:
            idx: Index into the sorted list of image filenames.

        Returns:
            tuple: ``(hr_latent, lr_latent)`` on the CPU, batch dimension
            removed.
        """
        # Load HR image and crop a random patch from it
        hr_path = os.path.join(self.dir_HR, self.filenames[idx])
        hr_full = self._load_image(hr_path)  # (C, H, W) in [0, 1]
        hr_patch = self._crop_patch(hr_full, self.patch_size)  # (C, P, P)

        # Initialize degrader and VQGAN on first use
        if self.degrader is None:
            self.degrader = get_degrader()
        if self.vqgan is None:
            self.vqgan = get_vqgan_model()

        # Apply degradation on-the-fly to generate LR.
        # Per the original author's notes the degrader takes (C, H, W) and
        # returns (C, H // scale, W // scale) -- TODO confirm against
        # RealESRGANDegrader.
        hr_patch = hr_patch.to(device)
        with torch.no_grad():
            lr_patch = self.degrader.degrade(hr_patch)

        # Upsample LR back to the HR patch size using bicubic interpolation
        lr_patch_upsampled = F.interpolate(
            lr_patch.unsqueeze(0),
            size=(self.patch_size, self.patch_size),
            mode='bicubic',
            align_corners=False
        ).squeeze(0)  # (C, P, P)

        # Apply augmentations (synchronized). Flips and rotations give the
        # same result on any device, so the tensors stay where they are
        # instead of taking a round trip through the CPU.
        hr_patch, lr_patch_upsampled = self._apply_augmentations(
            hr_patch,
            lr_patch_upsampled
        )

        # Add a batch dimension for the encoder; .to(device) is a no-op when
        # the tensor already lives there and guards against a degrader that
        # returns its output on another device.
        hr_batch = hr_patch.to(device).unsqueeze(0)  # (1, C, P, P)
        lr_batch = lr_patch_upsampled.to(device).unsqueeze(0)  # (1, C, P, P)

        with torch.no_grad():
            hr_latent = self.vqgan.encode(hr_batch)
            lr_latent = self.vqgan.encode(lr_batch)

        # Remove batch dimension and move to CPU
        hr_latent = hr_latent.squeeze(0).cpu()
        lr_latent = lr_latent.squeeze(0).cpu()

        return hr_latent, lr_latent
# Module-level datasets; every item is degraded and VQGAN-encoded on the fly.
train_dataset = SRDatasetOnTheFly(dir_HR=dir_HR, scale=scale, patch_size=patch_size)
valid_dataset = SRDatasetOnTheFly(dir_HR=dir_valid_HR, scale=scale, patch_size=patch_size)

# Mini dataset for testing: the first 8 (sorted) images of the training directory.
mini_dataset = SRDatasetOnTheFly(
    dir_HR=dir_HR, scale=scale, patch_size=patch_size, max_samples=8
)

# Report dataset sizes once at import time.
print(f"\nFull training dataset size: {len(train_dataset)}")
print(f"Full validation dataset size: {len(valid_dataset)}")
print(f"Mini dataset size: {len(mini_dataset)}")
print("Using on-the-fly degradation and VQGAN encoding")
print("Output: 64x64 latents (from 256x256 patches)")