#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Inference script for ResShift diffusion model.

Performs super-resolution on LR images using full diffusion sampling.
Consistent with original ResShift inference interface.
"""

import os
import sys
import argparse
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms.functional as TF
import numpy as np
from tqdm import tqdm

from model import FullUNET
from autoencoder import get_vqgan
from noiseControl import resshift_schedule
from config import (
    device,
    T,
    k,
    normalize_input,
    latent_flag,
    autoencoder_ckpt_path,
    _project_root,
    image_size,  # Latent space size (64)
    gt_size,  # Pixel space size (256)
    sf,  # Scale factor (4)
)

# Key prefix that torch.compile adds to the parameters of the wrapped module.
_COMPILED_PREFIX = '_orig_mod.'

# File extensions treated as images when --in_path is a directory.
_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}

# Default patch overlap per supported chop size; default stride = size - overlap.
_CHOP_OVERLAP = {512: 64, 256: 32, 64: 16}


def get_parser(**parser_kwargs):
    """Parse command-line arguments.

    Note: despite its name, this returns the parsed ``argparse.Namespace``,
    not the parser.

    Args:
        **parser_kwargs: Forwarded to ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-i", "--in_path", type=str, required=True,
        help="Input path (image file or directory).",
    )
    parser.add_argument(
        "-o", "--out_path", type=str, default="./results",
        help="Output path (image file or directory).",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to model checkpoint (e.g., checkpoints/ckpts/model_1500.pth).",
    )
    parser.add_argument(
        "--ema_checkpoint", type=str, default=None,
        help="Path to EMA checkpoint (optional, e.g., checkpoints/ckpts/ema_model_1500.pth).",
    )
    parser.add_argument(
        "--use_ema", action="store_true",
        help="Use EMA model for inference (requires --ema_checkpoint).",
    )
    parser.add_argument(
        "--scale", type=int, default=4,
        help="Scale factor for SR (default: 4).",
    )
    parser.add_argument(
        "--seed", type=int, default=12345,
        help="Random seed for reproducibility.",
    )
    # NOTE(review): --bs is parsed but not used below; images are processed
    # one at a time.
    parser.add_argument(
        "--bs", type=int, default=1,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--chop_size", type=int, default=512, choices=[512, 256, 64],
        help="Chopping size for large images (default: 512).",
    )
    parser.add_argument(
        "--chop_stride", type=int, default=-1,
        help="Chopping stride (default: auto-calculated).",
    )
    parser.add_argument(
        "--chop_bs", type=int, default=1,
        help="Batch size for chopping (default: 1).",
    )
    parser.add_argument(
        "--use_amp", action="store_true", default=True,
        help="Use automatic mixed precision (default: True).",
    )
    # `--use_amp` defaults to True, so as a store_true flag it could never be
    # switched off. `--no_amp` writes to the same destination and provides
    # the missing opt-out while keeping `--use_amp` working as before.
    parser.add_argument(
        "--no_amp", dest="use_amp", action="store_false",
        help="Disable automatic mixed precision.",
    )
    return parser.parse_args()


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_image(image_path):
    """
    Load and preprocess image for inference.

    The image is converted to RGB and resized (bicubic) to
    ``gt_size x gt_size`` -- the pixel-space resolution the pipeline works
    in -- regardless of its original size.

    Args:
        image_path: Path to input image

    Returns:
        Preprocessed image tensor (1, 3, gt_size, gt_size) in [0, 1] range
        Original image size as (W, H) (PIL's ``Image.size`` order)
    """
    # Open via a context manager so the file handle is released promptly when
    # many images are processed; convert() returns a new, fully loaded image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    orig_size = img.size  # (W, H)

    # For 4x SR the output is produced in gt_size pixel space (256x256), so
    # the input is brought to that size first.
    target_size = gt_size
    img = img.resize((target_size, target_size), Image.BICUBIC)

    # Convert to tensor in [0, 1] and add a batch dimension.
    img_tensor = TF.to_tensor(img).unsqueeze(0)  # (1, 3, H, W)

    return img_tensor, orig_size


def save_image(tensor, save_path, orig_size=None):
    """
    Save tensor image to file.

    Args:
        tensor: Image tensor (1, 3, H, W) in [0, 1]
        save_path: Path to save image; parent directories are created.
        orig_size: Original image size (W, H). When given, the output is
            resized (Lanczos) to ``orig_size * sf``.
    """
    img = TF.to_pil_image(tensor.squeeze(0).cpu())

    if orig_size is not None:
        target_size = (orig_size[0] * sf, orig_size[1] * sf)
        img = img.resize(target_size, Image.LANCZOS)

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(save_path)
    print(f"✓ Saved SR image to: {save_path}")


def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
    """
    Scale the noisy input by its timestep-dependent standard deviation.

    Only applied when both ``normalize_input`` and ``latent_flag`` are set;
    otherwise ``x_t`` is returned unchanged.

    Args:
        x_t: Noisy input tensor (B, C, H, W)
        t: Timestep tensor (B,)
        eta_schedule: Noise schedule (T, 1, 1, 1)
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space

    Returns:
        Scaled input tensor, ``x_t / sqrt(eta_t * k**2 + 1)`` when enabled.
    """
    if normalize_input and latent_flag:
        eta_t = eta_schedule[t]  # (B, 1, 1, 1)
        std = torch.sqrt(eta_t * k**2 + 1)
        x_t_scaled = x_t / std
    else:
        x_t_scaled = x_t
    return x_t_scaled


def inference_single_image(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Perform inference on a single LR image using full diffusion sampling.

    Starts from ``x_T = y + k * sqrt(eta_{T-1}) * eps`` (``y`` = encoded LR
    latent) and runs the reverse chain T-1, ..., 0. At each step the model
    predicts x0, and the next state is sampled from the posterior of
    equation (7) of the ResShift paper. At step 0 the prediction itself is
    the result.

    Args:
        model: Trained ResShift model, called as ``model(x_t, t, lq=y)``
        autoencoder: VQGAN autoencoder for encoding/decoding
        lr_image_tensor: LR image tensor (1, 3, 256, 256) in [0, 1]
        eta_schedule: Noise schedule (T, 1, 1, 1)
        device: Device to run inference on
        T: Number of diffusion timesteps
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space
        use_amp: Whether to use automatic mixed precision for the model call
            (only takes effect when the input is on a CUDA device)

    Returns:
        Decoded SR image tensor clamped to [0, 1] (presumably
        (1, 3, 256, 256) given the config's gt_size -- depends on the decoder).
    """
    model.eval()

    lr_image_tensor = lr_image_tensor.to(device)

    # Gate autocast on where the data actually lives. Checking
    # torch.cuda.is_available() alone would request CUDA autocast even when
    # inference was directed to the CPU.
    amp_enabled = use_amp and lr_image_tensor.device.type == "cuda"

    with torch.no_grad():
        # Encode LR image to latent space (expected (1, 3, 64, 64) per config).
        lr_latent = autoencoder.encode(lr_image_tensor)

        # Prior sample at the maximum timestep: the LR latent plus the
        # largest amount of noise in the schedule.
        epsilon_init = torch.randn_like(lr_latent)
        eta_max = eta_schedule[T - 1]
        x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init

        batch_size = lr_latent.shape[0]
        for t_step in range(T - 1, -1, -1):  # T-1, T-2, ..., 1, 0
            t = torch.full((batch_size,), t_step, device=device, dtype=torch.long)

            x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)

            autocast_context = torch.amp.autocast('cuda') if amp_enabled else nullcontext()
            with autocast_context:
                x0_pred = model(x_t_scaled, t, lq=lr_latent)
            # Under autocast the prediction comes back in reduced precision.
            # Cast it back to the latent dtype so the update below and, at the
            # last step, the decoder input keep the pipeline's precision.
            x0_pred = x0_pred.to(x_t.dtype)

            if t_step > 0:
                # Equation (7) from ResShift paper:
                #   mu    = (eta_{t-1}/eta_t) * x_t + (alpha_t/eta_t) * f(x_t, y, t)
                #   Sigma = k^2 * (eta_{t-1}/eta_t) * alpha_t
                #   x_{t-1} = mu + sqrt(Sigma) * eps
                eta_t = eta_schedule[t_step]
                eta_t_minus_1 = eta_schedule[t_step - 1]
                alpha_t = eta_t - eta_t_minus_1

                mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
                variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t

                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(variance) * noise
            else:
                # Final step: the predicted x0 is the result.
                x_t = x0_pred

        sr_latent = x_t

        # Decode back to pixel space and clamp to the valid image range.
        sr_image = autoencoder.decode(sr_latent)
        sr_image = sr_image.clamp(0, 1)

    return sr_image


def inference_with_chopping(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    chop_size=512,
    chop_stride=448,
    chop_bs=1,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Perform inference with chopping for large images.

    Placeholder: patch-based chopping is not implemented yet. ``chop_size``,
    ``chop_stride`` and ``chop_bs`` are accepted for interface compatibility
    but ignored; the whole image is processed by ``inference_single_image``.

    Args:
        model: Trained ResShift model
        autoencoder: VQGAN autoencoder
        lr_image_tensor: LR image tensor (1, 3, H, W)
        eta_schedule: Noise schedule
        device: Device to run inference on
        chop_size: Size of each patch (currently unused)
        chop_stride: Stride between patches (currently unused)
        chop_bs: Batch size for chopping (currently unused)
        T: Number of diffusion timesteps
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space
        use_amp: Whether to use AMP

    Returns:
        Whatever ``inference_single_image`` returns for the full image.
    """
    return inference_single_image(
        model,
        autoencoder,
        lr_image_tensor,
        eta_schedule,
        device,
        T=T,
        k=k,
        normalize_input=normalize_input,
        latent_flag=latent_flag,
        use_amp=use_amp,
    )


def _strip_compiled_prefix(state_dict, message):
    """Return ``state_dict`` with torch.compile's '_orig_mod.' key prefix removed.

    If no key carries the prefix, the input mapping is returned unchanged;
    otherwise ``message`` is printed and a new dict is returned.
    """
    if not any(key.startswith(_COMPILED_PREFIX) for key in state_dict.keys()):
        return state_dict
    print(message)
    prefix_len = len(_COMPILED_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(_COMPILED_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_model(checkpoint_path, ema_checkpoint_path=None, use_ema=False, device=device):
    """
    Load model from checkpoint.

    Accepts either a raw state dict or a dict with a 'state_dict' entry, and
    checkpoints saved from a torch.compile-wrapped model.

    Args:
        checkpoint_path: Path to model checkpoint
        ema_checkpoint_path: Path to EMA checkpoint (optional)
        use_ema: Whether to apply EMA weights (needs ema_checkpoint_path)
        device: Device to load model on

    Returns:
        Loaded model
    """
    print(f"Loading model from: {checkpoint_path}")
    model = FullUNET()
    model = model.to(device)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    state_dict = _strip_compiled_prefix(
        state_dict,
        " Detected compiled model checkpoint, stripping _orig_mod. prefix...",
    )
    model.load_state_dict(state_dict)
    print("✓ Model loaded")

    if use_ema and not ema_checkpoint_path:
        # Previously this combination was silently ignored.
        print("⚠ use_ema requested without an EMA checkpoint; using raw model weights.")

    if use_ema and ema_checkpoint_path:
        print(f"Loading EMA model from: {ema_checkpoint_path}")
        from ema import EMA

        ema = EMA(model, ema_rate=0.999, device=device)
        ema_ckpt = torch.load(ema_checkpoint_path, map_location=device)
        ema_ckpt = _strip_compiled_prefix(
            ema_ckpt,
            " Detected compiled model in EMA checkpoint, stripping _orig_mod. prefix...",
        )
        ema.load_state_dict(ema_ckpt)
        ema.apply_to_model(model)
        print("✓ EMA model loaded and applied")

    return model


def main():
    """CLI entry point: parse arguments, load models, super-resolve each input."""
    args = get_parser()

    print("=" * 80)
    print("ResShift Inference")
    print("=" * 80)

    set_seed(args.seed)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if args.scale != 4:
        raise ValueError("We only support 4x super-resolution now!")
    if args.use_ema and not args.ema_checkpoint:
        raise ValueError("--use_ema requires --ema_checkpoint.")

    # Calculate chopping stride if not provided.
    scale_ratio = 4 // args.scale
    if args.chop_stride < 0:
        if args.chop_size not in _CHOP_OVERLAP:
            raise ValueError("Chop size must be in [512, 256, 64]")
        chop_stride = (args.chop_size - _CHOP_OVERLAP[args.chop_size]) * scale_ratio
    else:
        chop_stride = args.chop_stride * scale_ratio
    chop_size = args.chop_size * scale_ratio
    print(f"Chopping size/stride: {chop_size}/{chop_stride}")

    model = load_model(
        args.checkpoint,
        args.ema_checkpoint,
        args.use_ema,
        device,
    )

    print("\nLoading VQGAN autoencoder...")
    autoencoder = get_vqgan()
    print("✓ VQGAN autoencoder loaded")

    print("\nInitializing noise schedule...")
    eta = resshift_schedule().to(device)
    eta = eta[:, None, None, None]  # (T, 1, 1, 1)
    print("✓ Noise schedule initialized")

    in_path = Path(args.in_path)
    out_path = Path(args.out_path)

    if in_path.is_file():
        input_files = [in_path]
        if out_path.suffix:
            # Output is a file
            output_files = [out_path]
        else:
            # Output is a directory
            output_files = [out_path / in_path.name]
    elif in_path.is_dir():
        # Sorted so the processing order -- and therefore which draws from the
        # seeded RNG each image receives -- is reproducible across filesystems.
        input_files = sorted(
            f for f in in_path.iterdir()
            if f.is_file() and f.suffix.lower() in _IMAGE_EXTENSIONS
        )
        output_files = [out_path / f.name for f in input_files]
        out_path.mkdir(parents=True, exist_ok=True)
    else:
        raise ValueError(f"Input path does not exist: {in_path}")

    if not input_files:
        raise ValueError(f"No image files found in: {in_path}")

    print(f"\nFound {len(input_files)} image(s) to process")

    print("\n" + "=" * 80)
    print("Running Inference")
    print("=" * 80)

    for idx, (input_file, output_file) in enumerate(zip(input_files, output_files), 1):
        print(f"\n[{idx}/{len(input_files)}] Processing: {input_file.name}")

        lr_image, orig_size = load_image(input_file)

        if args.chop_size < 512:
            # Use chopping for large images (currently a full-image placeholder).
            sr_image = inference_with_chopping(
                model=model,
                autoencoder=autoencoder,
                lr_image_tensor=lr_image,
                eta_schedule=eta,
                device=device,
                chop_size=chop_size,
                chop_stride=chop_stride,
                chop_bs=args.chop_bs,
                T=T,
                k=k,
                normalize_input=normalize_input,
                latent_flag=latent_flag,
                use_amp=args.use_amp,
            )
        else:
            sr_image = inference_single_image(
                model=model,
                autoencoder=autoencoder,
                lr_image_tensor=lr_image,
                eta_schedule=eta,
                device=device,
                T=T,
                k=k,
                normalize_input=normalize_input,
                latent_flag=latent_flag,
                use_amp=args.use_amp,
            )

        save_image(sr_image, output_file, orig_size=orig_size)

    print("\n" + "=" * 80)
    print("Inference Complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()