---
title: ResShift Super-Resolution
emoji: 🖼️
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
license: mit
---
# DiffusionSR

A from-scratch implementation of the ResShift paper: an efficient diffusion-based super-resolution model that uses a U-Net architecture with Swin Transformer blocks to enhance low-resolution images. It combines diffusion models with transformer-based attention for high-quality image super-resolution.
## Overview
This project is a complete from-scratch implementation of ResShift, a diffusion model for single image super-resolution (SISR) that efficiently reduces the number of diffusion steps required by shifting the residual between high-resolution and low-resolution images. The model architecture consists of:
- Encoder: 4-stage encoder with residual blocks and time embeddings
- Bottleneck: Swin Transformer blocks for global feature modeling
- Decoder: 4-stage decoder with skip connections from the encoder
- Noise Schedule: ResShift schedule (15 timesteps) for the diffusion process
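
Concretely, instead of diffusing an image all the way to pure Gaussian noise, ResShift's forward process shifts the residual between the (upsampled) LR image `y` and the HR image `x0`: `x_t = x0 + eta_t * (y - x0) + kappa * sqrt(eta_t) * noise`, with `eta_t` growing from near 0 to near 1. Below is a minimal sketch of sampling that transition; `kappa = 2.0` follows the paper's setting, and the names and signatures are illustrative rather than this repo's exact API:

```python
import torch

def q_sample(x0: torch.Tensor, y: torch.Tensor, eta_t: float,
             kappa: float = 2.0) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x0, y) for ResShift's forward process.

    x0: high-resolution image; y: low-resolution image upsampled to HR size.
    The residual e0 = y - x0 is shifted into x0 as eta_t grows from ~0 to ~1.
    """
    e0 = y - x0
    noise = torch.randn_like(x0)
    return x0 + eta_t * e0 + kappa * (eta_t ** 0.5) * noise

x0 = torch.rand(1, 3, 256, 256)   # HR image (placeholder)
y = torch.rand(1, 3, 256, 256)    # upsampled LR image (placeholder)
x_t = q_sample(x0, y, eta_t=0.5)  # halfway through the shift
```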
## Features
- ResShift Implementation: Complete from-scratch implementation of the ResShift paper
- Efficient Diffusion: Residual shifting mechanism reduces required diffusion steps
- U-Net Architecture: Encoder-decoder structure with skip connections
- Swin Transformer: Window-based attention mechanism in the bottleneck
- Time Conditioning: Sinusoidal time embeddings for diffusion timesteps
- DIV2K Dataset: Trained on DIV2K high-quality image dataset
- Comprehensive Evaluation: Metrics include PSNR, SSIM, and LPIPS
## Requirements
- Python >= 3.11
- PyTorch >= 2.9.1
- uv (Python package manager)
## Installation

### 1. Clone the Repository

```bash
git clone <repository-url>
cd DiffusionSR
```

### 2. Install uv (if not already installed)

```bash
# On macOS and Linux
curl -LsSf https://astral.sh/uv/install.sh | sh

# Or using pip
pip install uv
```

### 3. Create Virtual Environment and Install Dependencies

```bash
# Create a virtual environment
uv venv

# Activate the virtual environment
# On macOS/Linux:
source .venv/bin/activate
# On Windows:
# .venv\Scripts\activate

# Install project dependencies
uv pip install -e .
```

Alternatively, you can use uv's sync command:

```bash
uv sync
```
## Dataset Setup

The model expects the DIV2K dataset in the following structure:

```
data/
├── DIV2K_train_HR/            # High-resolution training images
└── DIV2K_train_LR_bicubic/
    └── X4/                    # Low-resolution images (4x downsampled)
```

### Download DIV2K Dataset

1. Download the DIV2K dataset from the official website
2. Extract the files to the `data/` directory
3. Ensure the directory structure matches the above
**Note**: Update the paths in `src/data.py` (lines 75-76) to match your dataset location:

```python
train_dataset = SRDataset(
    dir_HR='path/to/DIV2K_train_HR',
    dir_LR='path/to/DIV2K_train_LR_bicubic/X4',
    scale=4,
    patch_size=256
)
```
## Usage

### Training

To train the model, run:

```bash
python src/train.py
```

The training script will:

- Load the dataset using the `SRDataset` class
- Initialize the `FullUNET` model
- Train using the ResShift noise schedule
- Save training progress and loss values
### Training Configuration

Current training parameters (in `src/train.py`):
- Batch size: 4
- Learning rate: 1e-4
- Optimizer: Adam (betas: 0.9, 0.999)
- Loss function: MSE Loss
- Gradient clipping: 1.0
- Training steps: 150
- Scale factor: 4x
- Patch size: 256x256
You can modify these parameters directly in `src/train.py` to suit your needs.
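
For orientation, here is a minimal sketch of a training step under these settings. The convolution stands in for `FullUNET` (the real model also takes the diffusion timestep and conditioning as inputs), and the random tensors stand in for `SRDataset` batches:

```python
import torch
from torch import nn

# Placeholder standing in for FullUNET (src/model.py)
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
criterion = nn.MSELoss()

for step in range(150):                     # training steps: 150
    hr = torch.rand(4, 3, 256, 256)         # batch size 4, 256x256 patches
    lr_up = torch.rand(4, 3, 256, 256)      # LR input, upsampled to HR size
    pred = model(lr_up)
    loss = criterion(pred, hr)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clipping: 1.0
    optimizer.step()
```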
### Evaluation

The model performance is evaluated using the following metrics:

- **PSNR (Peak Signal-to-Noise Ratio)**: Measures the ratio between the maximum possible power of a signal and the power of corrupting noise. Higher PSNR values indicate better reconstruction quality.
- **SSIM (Structural Similarity Index Measure)**: Assesses the similarity between two images based on luminance, contrast, and structure. SSIM values range from -1 to 1, with values closer to 1 indicating greater similarity to the ground truth.
- **LPIPS (Learned Perceptual Image Patch Similarity)**: Evaluates perceptual similarity between images using deep network features. Lower LPIPS values indicate images that are more perceptually similar to the reference.
To run evaluation (once implemented), use:

```bash
python src/test.py
```
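
Until `src/test.py` is implemented, here is a minimal sketch of how these metrics could be computed with `torchmetrics` (an assumption; this repo may compute them differently, and the LPIPS metric additionally requires `torchvision`):

```python
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Images are assumed to be float tensors in [0, 1].
psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True)

sr = torch.rand(1, 3, 256, 256)   # super-resolved output (placeholder)
hr = torch.rand(1, 3, 256, 256)   # ground-truth HR image (placeholder)

print(f"PSNR:  {psnr(sr, hr):.2f} dB  (higher is better)")
print(f"SSIM:  {ssim(sr, hr):.4f}     (higher is better)")
print(f"LPIPS: {lpips(sr, hr):.4f}    (lower is better)")
```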
## Project Structure

```
DiffusionSR/
├── data/                    # Dataset directory (not tracked in git)
│   ├── DIV2K_train_HR/
│   └── DIV2K_train_LR_bicubic/
├── src/
│   ├── config.py            # Configuration file
│   ├── data.py              # Dataset class and data loading
│   ├── model.py             # U-Net model architecture
│   ├── noiseControl.py      # ResShift noise schedule
│   ├── train.py             # Training script
│   └── test.py              # Testing script (to be implemented)
├── pyproject.toml           # Project dependencies and metadata
├── uv.lock                  # Locked dependency versions
└── README.md                # This file
```
## Model Architecture

### Encoder

- Initial Conv: 3 → 64 channels
- Stage 1: 64 → 128 channels, 256×256 → 128×128
- Stage 2: 128 → 256 channels, 128×128 → 64×64
- Stage 3: 256 → 512 channels, 64×64 → 32×32
- Stage 4: 512 channels (no downsampling)
### Bottleneck

- Residual blocks with Swin Transformer blocks
- Window size: 7×7
- Shifted window attention for global context
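
To illustrate the window mechanism, here is a minimal sketch of the standard Swin-style window partitioning (this repo's implementation in `src/model.py` may organize it differently):

```python
import torch

def window_partition(x: torch.Tensor, window_size: int = 7) -> torch.Tensor:
    """Split a feature map into non-overlapping windows.

    x: (B, H, W, C) with H and W divisible by window_size.
    Returns: (B * num_windows, window_size, window_size, C).
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

# Attention is computed within each window; shifting the windows by
# window_size // 2 on alternate blocks (via torch.roll) lets information
# propagate across window boundaries, giving global context cheaply.
windows = window_partition(torch.rand(1, 28, 28, 512))
print(windows.shape)  # torch.Size([16, 7, 7, 512])
```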
### Decoder

- Stage 1: 512 → 256 channels, 32×32 → 64×64
- Stage 2: 256 → 128 channels, 64×64 → 128×128
- Stage 3: 128 → 64 channels, 128×128 → 256×256
- Stage 4: 64 → 64 channels
- Final Conv: 64 → 3 channels (RGB output)
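
For intuition, here is a minimal sketch of the kind of time-conditioned residual block these stages are built from. This is the generic diffusion U-Net pattern, not this repo's exact layer definitions:

```python
import torch
from torch import nn

class ResBlock(nn.Module):
    """Residual block that injects a time embedding between two convolutions."""

    def __init__(self, in_ch: int, out_ch: int, t_dim: int):
        super().__init__()
        self.conv1 = nn.Sequential(nn.GroupNorm(32, in_ch), nn.SiLU(),
                                   nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.t_proj = nn.Linear(t_dim, out_ch)  # maps time embedding to channels
        self.conv2 = nn.Sequential(nn.GroupNorm(32, out_ch), nn.SiLU(),
                                   nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(x)
        h = h + self.t_proj(t_emb)[:, :, None, None]  # broadcast over H, W
        return self.conv2(h) + self.skip(x)

block = ResBlock(64, 128, t_dim=256)
out = block(torch.rand(2, 64, 256, 256), torch.rand(2, 256))
print(out.shape)  # torch.Size([2, 128, 256, 256])
```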
## Key Components

### ResShift Noise Schedule

The model implements the ResShift noise schedule as described in the original paper, defined in `src/noiseControl.py`:

- 15 timesteps (0-14)
- Parameters: `eta1=0.001`, `etaT=0.999`, `p=0.8`
- Efficiently shifts the residual between HR and LR images during the diffusion process
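
The paper defines the schedule as a geometric interpolation between `eta1` and `etaT`, with the exponent `p` controlling how the steps are spaced. A sketch of that construction (this repo's `src/noiseControl.py` may differ in details):

```python
import torch

def resshift_schedule(T: int = 15, eta1: float = 0.001, etaT: float = 0.999,
                      p: float = 0.8) -> torch.Tensor:
    """Monotone schedule eta_1..eta_T following the ResShift paper:
    eta_t = eta1 * (etaT / eta1) ** (((t - 1) / (T - 1)) ** p)."""
    t = torch.arange(1, T + 1, dtype=torch.float64)
    gamma = ((t - 1) / (T - 1)) ** p
    return eta1 * (etaT / eta1) ** gamma

etas = resshift_schedule()
print(etas[0].item(), etas[-1].item())  # 0.001 ... 0.999
```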
### Time Embeddings
Sinusoidal embeddings are used to condition the model on diffusion timesteps, similar to positional encodings in transformers.
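
A minimal sketch of the standard sinusoidal embedding (the embedding width of 256 is an illustrative choice, not necessarily this repo's):

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Sin/cos features at geometrically spaced frequencies, one row per timestep."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = timestep_embedding(torch.arange(15))  # one embedding per ResShift timestep
print(emb.shape)  # torch.Size([15, 256])
```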
### Data Augmentation

The dataset includes:

- Random cropping (aligned between HR and LR)
- Random horizontal/vertical flips
- Random 180° rotation
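
A minimal sketch of what keeping the HR/LR pair aligned looks like, for illustration only; `src/data.py` implements the actual pipeline:

```python
import random
import torch

def paired_augment(hr: torch.Tensor, lr: torch.Tensor,
                   patch: int = 256, scale: int = 4):
    """Random aligned crop plus identical flips/rotation on both images."""
    # Crop: pick an LR window, then take the corresponding HR window (x scale).
    _, h, w = lr.shape
    lp = patch // scale
    x, y = random.randint(0, w - lp), random.randint(0, h - lp)
    lr = lr[:, y:y + lp, x:x + lp]
    hr = hr[:, y * scale:y * scale + patch, x * scale:x * scale + patch]
    if random.random() < 0.5:  # horizontal flip
        hr, lr = torch.flip(hr, [-1]), torch.flip(lr, [-1])
    if random.random() < 0.5:  # vertical flip
        hr, lr = torch.flip(hr, [-2]), torch.flip(lr, [-2])
    if random.random() < 0.5:  # 180° rotation
        hr, lr = torch.rot90(hr, 2, [-2, -1]), torch.rot90(lr, 2, [-2, -1])
    return hr, lr

hr, lr = paired_augment(torch.rand(3, 1024, 1024), torch.rand(3, 256, 256))
print(hr.shape, lr.shape)  # torch.Size([3, 256, 256]) torch.Size([3, 64, 64])
```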
## Development

### Adding New Features

- Model modifications: edit `src/model.py`
- Training changes: modify `src/train.py`
- Data pipeline: update `src/data.py`
- Configuration: add settings to `src/config.py`
## License

This project is licensed under the MIT License, as declared in the Space metadata above.
## Citation

If you use this code in your research, please cite the original ResShift paper:

```bibtex
@article{yue2023resshift,
  title={ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting},
  author={Yue, Zongsheng and Wang, Jianyi and Loy, Chen Change},
  journal={arXiv preprint arXiv:2307.12348},
  year={2023}
}
```
## Acknowledgments
- ResShift Authors: Zongsheng Yue, Jianyi Wang, and Chen Change Loy for their foundational work on efficient diffusion-based super-resolution
- DIV2K dataset providers
- PyTorch community
- Swin Transformer architecture inspiration