"""
Gradio app for ResShift Super-Resolution
Hosted on Hugging Face Spaces
"""
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pathlib import Path
import sys
from huggingface_hub import hf_hub_download
# Add src to path
sys.path.insert(0, str(Path(__file__).parent / "src"))
from model import FullUNET
from autoencoder import get_vqgan
from noiseControl import resshift_schedule
from config import device, T, k, normalize_input, latent_flag, gt_size
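# NOTE: the meanings below are inferred from how these symbols are used in
# this file; see src/config.py for the authoritative definitions.
#   device          - torch device to run inference on
#   T               - number of diffusion timesteps
#   k               - kappa, ResShift's noise-scaling factor
#   normalize_input - variance-normalize x_t before feeding the UNet
#   latent_flag     - diffusion runs in VQGAN latent space
#   gt_size         - target image side length in pixels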
# Hugging Face repo ID for weights
HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"
# Global variables for loaded models
model = None
autoencoder = None
eta_schedule = None
def load_models():
"""Load models on startup."""
global model, autoencoder, eta_schedule
print("Loading models...")
    # Resolve the model checkpoint: prefer the default local file, fall back
    # to any other local checkpoint, and finally download from the Hub.
    checkpoint_path = "checkpoints/ckpts/model_3200.pth"
    checkpoint_file = Path(checkpoint_path)
    if not checkpoint_file.exists():
        ckpt_dir = Path("checkpoints/ckpts")
        # Sort by training step (model_<step>.pth); a plain glob is unordered
        # and a lexicographic sort would rank model_800.pth above model_3200.pth
        checkpoints = []
        if ckpt_dir.exists():
            checkpoints = sorted(
                ckpt_dir.glob("model_*.pth"),
                key=lambda p: int(p.stem.split("_")[-1]),
            )
        if checkpoints:
            checkpoint_path = str(checkpoints[-1])  # Use latest
            print(f"Using checkpoint: {checkpoint_path}")
        else:
            # Download from Hugging Face (weight files live at the root of the
            # weights repo). With local_dir set, hf_hub_download returns the
            # path of the file it placed inside that directory.
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            print("Checkpoint not found locally. Downloading from Hugging Face...")
            try:
                checkpoint_path = hf_hub_download(
                    repo_id=HF_WEIGHTS_REPO_ID,
                    filename="model_3200.pth",
                    local_dir=str(ckpt_dir),
                )
                print(f"✓ Downloaded checkpoint: {checkpoint_path}")
            except Exception as e:
                raise FileNotFoundError(
                    f"Could not download checkpoint from Hugging Face: {e}\n"
                    f"Please ensure the file exists in the repository."
                ) from e
model = FullUNET()
model = model.to(device)
ckpt = torch.load(checkpoint_path, map_location=device)
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
else:
state_dict = ckpt
    # Strip the '_orig_mod.' prefix that torch.compile() adds to parameter
    # names when a compiled model's state dict is saved
    if any(key.startswith('_orig_mod.') for key in state_dict):
        state_dict = {
            key.removeprefix('_orig_mod.'): val for key, val in state_dict.items()
        }
model.load_state_dict(state_dict)
model.eval()
print("✓ Model loaded")
# Load VQGAN autoencoder
autoencoder = get_vqgan()
print("✓ VQGAN autoencoder loaded")
# Initialize noise schedule
eta_schedule = resshift_schedule().to(device)
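    # Reshape the (T,) schedule to (T, 1, 1, 1) so that eta_t, once indexed
    # by timestep, broadcasts over (B, C, H, W) latent tensors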
eta_schedule = eta_schedule[:, None, None, None]
print("✓ Noise schedule initialized")
return "Models loaded successfully!"
def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
"""Scale input based on timestep."""
if normalize_input and latent_flag:
eta_t = eta_schedule[t]
std = torch.sqrt(eta_t * k**2 + 1)
x_t_scaled = x_t / std
else:
x_t_scaled = x_t
return x_t_scaled
def super_resolve(input_image):
"""
Perform super-resolution on input image.
Args:
input_image: PIL Image or numpy array
Returns:
PIL Image of super-resolved output
"""
if input_image is None:
return None
if model is None or autoencoder is None:
return None
try:
        # Convert to a 3-channel PIL Image (handles RGBA/grayscale uploads,
        # which would otherwise give the model the wrong number of channels)
        if isinstance(input_image, Image.Image):
            img = input_image
        else:
            img = Image.fromarray(input_image)
        img = img.convert("RGB")
        # Resize to the model's training resolution (gt_size, e.g. 256x256)
        img = img.resize((gt_size, gt_size), Image.BICUBIC)
        # Convert to a batched tensor in [0, 1]
        img_tensor = TF.to_tensor(img).unsqueeze(0).to(device)  # (1, 3, gt_size, gt_size)
# Run inference
with torch.no_grad():
# Encode to latent space
lr_latent = autoencoder.encode(img_tensor) # (1, 3, 64, 64)
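            # (the 64x64 latent shape assumes a VQGAN with 4x spatial downsampling)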
# Initialize x_t at maximum timestep
epsilon_init = torch.randn_like(lr_latent)
eta_max = eta_schedule[T - 1]
x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
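            # This matches the ResShift forward marginal at the final step:
            # q(x_T | x_0, y_0) = N(x_0 + eta_T * (y_0 - x_0), k^2 * eta_T * I),
            # which for eta_T close to 1 is approximately N(y_0, k^2 * eta_T * I),
            # so sampling starts from the noised LR latent instead of pure noise.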
# Full diffusion sampling loop
for t_step in range(T - 1, -1, -1):
t = torch.full((lr_latent.shape[0],), t_step, device=device, dtype=torch.long)
# Scale input
x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)
# Predict x0
x0_pred = model(x_t_scaled, t, lq=lr_latent)
# Compute x_{t-1} using equation (7)
if t_step > 0:
# Equation (7) from ResShift paper:
# μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
# Σ_θ = κ² * (η_{t-1}/η_t) * α_t
# x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
eta_t = eta_schedule[t_step]
eta_t_minus_1 = eta_schedule[t_step - 1]
# Compute alpha_t = η_t - η_{t-1}
alpha_t = eta_t - eta_t_minus_1
# Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
# Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
                    # Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε (noise is always
                    # added here, since this branch only runs for t_step > 0)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(variance) * noise
else:
x_t = x0_pred
# Decode back to pixel space
sr_latent = x_t
sr_image = autoencoder.decode(sr_latent) # (1, 3, 256, 256)
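            # NOTE: the clamp assumes decode() returns images roughly in [0, 1];
            # if the autoencoder wrapper outputs [-1, 1], rescale with
            # (x + 1) / 2 before clamping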
sr_image = sr_image.clamp(0, 1)
# Convert to PIL Image
sr_pil = TF.to_pil_image(sr_image.squeeze(0).cpu())
return sr_pil
except Exception as e:
print(f"Error during inference: {str(e)}")
import traceback
traceback.print_exc()
return None
# Create Gradio interface
with gr.Blocks(title="ResShift Super-Resolution") as demo:
gr.Markdown(
"""
# ResShift Super-Resolution
Upload a low-resolution image to get a super-resolved version using ResShift diffusion model.
    **Note**: The model refines rather than enlarges: the input is resized to 256x256, encoded into a 4x-downsampled latent space, denoised there, and decoded back to 256x256 with enhanced detail.
"""
)
with gr.Row():
with gr.Column():
input_image = gr.Image(
label="Input Image (Low Resolution)",
type="pil",
height=300
)
submit_btn = gr.Button("Super-Resolve", variant="primary")
with gr.Column():
output_image = gr.Image(
label="Super-Resolved Output",
type="pil",
height=300
)
status = gr.Textbox(label="Status", value="Loading models...", interactive=False)
# Load models on startup
demo.load(
fn=load_models,
outputs=status,
show_progress=True
)
# Process on button click
submit_btn.click(
fn=super_resolve,
inputs=input_image,
outputs=output_image,
show_progress=True
)
    # Also process on image upload (the change event also fires when the image
    # is cleared; super_resolve returns None in that case)
input_image.change(
fn=super_resolve,
inputs=input_image,
outputs=output_image,
show_progress=True
)
if __name__ == "__main__":
demo.launch(share=True)