"""
VQGAN Autoencoder module for encoding/decoding images to/from latent space.
"""
import sys
from pathlib import Path

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

# Handle import of ldm from the latent-diffusion repository:
# if the ldm directory exists locally (from a cloned repo), put its parent on sys.path.
_ldm_path = Path(__file__).parent.parent / "ldm"
if _ldm_path.exists() and str(_ldm_path.parent) not in sys.path:
    sys.path.insert(0, str(_ldm_path.parent))

try:
    from ldm.models.autoencoder import VQModelTorch
except ImportError:
    # Fallback: try importing from site-packages if latent-diffusion is installed.
    try:
        import importlib.util

        spec = importlib.util.find_spec("ldm.models.autoencoder")
        if spec is None:
            raise ImportError("Could not find ldm.models.autoencoder")
        from ldm.models.autoencoder import VQModelTorch
    except ImportError as e:
        raise ImportError(
            "Could not import VQModelTorch from ldm.models.autoencoder. "
            "Ensure the latent-diffusion repository is cloned so the ldm directory "
            "exists, or install the latent-diffusion package."
        ) from e

from config import (
    autoencoder_ckpt_path,
    autoencoder_use_fp16,
    autoencoder_embed_dim,
    autoencoder_n_embed,
    autoencoder_double_z,
    autoencoder_z_channels,
    autoencoder_resolution,
    autoencoder_in_channels,
    autoencoder_out_ch,
    autoencoder_ch,
    autoencoder_ch_mult,
    autoencoder_num_res_blocks,
    autoencoder_attn_resolutions,
    autoencoder_dropout,
    autoencoder_padding_mode,
    _project_root,
    device,
)

# Hugging Face repo ID for pretrained weights.
HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"


def load_vqgan(ckpt_path=None, device=device):
    """
    Load the VQGAN autoencoder from a checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file. If None, uses the config path.
        device: Device to load the model on.

    Returns:
        VQGAN model in eval mode.
    """
    if ckpt_path is None:
        ckpt_path = autoencoder_ckpt_path

    # Resolve the path relative to the project root.
    if not Path(ckpt_path).is_absolute():
        ckpt_path = _project_root / ckpt_path

    # Download from Hugging Face if the checkpoint is not found locally.
    if not Path(ckpt_path).exists():
        print("VQGAN checkpoint not found locally. Downloading from Hugging Face...")
        try:
            # Weight files sit at the root of the weights repo; download them
            # into the local pretrained_weights directory.
            local_weights_dir = _project_root / "pretrained_weights"
            local_weights_dir.mkdir(parents=True, exist_ok=True)
            downloaded_path = hf_hub_download(
                repo_id=HF_WEIGHTS_REPO_ID,
                filename="autoencoder_vq_f4.pth",
                local_dir=str(local_weights_dir),
                local_dir_use_symlinks=False,
            )
            ckpt_path = Path(downloaded_path)
            print(f"✓ Downloaded VQGAN checkpoint: {ckpt_path}")
        except Exception as e:
            raise FileNotFoundError(
                f"VQGAN checkpoint not found at: {ckpt_path}\n"
                f"Could not download from Hugging Face: {e}\n"
                f"Please ensure the file exists in the repository."
            ) from e

    print(f"Loading VQGAN from: {ckpt_path}")

    # Load the checkpoint and extract its state_dict.
    checkpoint = torch.load(ckpt_path, map_location=device)
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
    else:
        raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}")

    # Build the model architecture from the config values.
    ddconfig = {
        'double_z': autoencoder_double_z,
        'z_channels': autoencoder_z_channels,
        'resolution': autoencoder_resolution,
        'in_channels': autoencoder_in_channels,
        'out_ch': autoencoder_out_ch,
        'ch': autoencoder_ch,
        'ch_mult': autoencoder_ch_mult,
        'num_res_blocks': autoencoder_num_res_blocks,
        'attn_resolutions': autoencoder_attn_resolutions,
        'dropout': autoencoder_dropout,
        'padding_mode': autoencoder_padding_mode,
    }
    model = VQModelTorch(
        ddconfig=ddconfig,
        n_embed=autoencoder_n_embed,
        embed_dim=autoencoder_embed_dim,
    )

    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)
    if autoencoder_use_fp16:
        model = model.half()

    print(f"VQGAN loaded successfully on {device}")
    return model


class VQGANWrapper(nn.Module):
    """
    Simple wrapper around the VQGAN autoencoder that handles normalization,
    dtype, and device placement for encode/decode.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def encode(self, x):
        """
        Encode an image into latent space.

        Args:
            x: (B, 3, H, W) image tensor in range [0, 1].

        Returns:
            z: (B, 3, H//4, W//4) latent tensor.
        """
        # Ensure the model is in eval mode.
        self.model.eval()
        with torch.no_grad():
            # Normalize [0, 1] -> [-1, 1] if the input looks unnormalized.
            if x.max() <= 1.0:
                x = x * 2.0 - 1.0

            # Match the model's dtype and device (handles fp16 models).
            param = next(self.model.parameters())
            if x.dtype != param.dtype:
                x = x.to(param.dtype)
            if x.device != param.device:
                x = x.to(param.device)

            z = self.model.encode(x)

            # Extract the latent if encode() returned a tuple/list or dict.
            if isinstance(z, (tuple, list)):
                z = z[0]
            elif isinstance(z, dict):
                z = z.get('z', z.get('latent', z))

            # Convert back to float32 for consistency downstream.
            if z.dtype != torch.float32:
                z = z.float()
            return z

    def decode(self, z):
        """
        Decode a latent back into image space.

        Args:
            z: (B, 3, H, W) latent tensor.

        Returns:
            x: (B, 3, H*4, W*4) image tensor in range [0, 1].
        """
        with torch.no_grad():
            # Match the model's dtype (handles fp16 models).
            model_dtype = next(self.model.parameters()).dtype
            if z.dtype != model_dtype:
                z = z.to(model_dtype)

            x = self.model.decode(z)

            # Convert back to float32, map [-1, 1] -> [0, 1], and clamp.
            if x.dtype != torch.float32:
                x = x.float()
            if x.min() < 0:
                x = (x + 1.0) / 2.0
            x = torch.clamp(x, 0, 1)
            return x


# Convenience function.
def get_vqgan(ckpt_path=None, device=device):
    """Get a ready-to-use VQGAN model instance wrapped in VQGANWrapper."""
    model = load_vqgan(ckpt_path=ckpt_path, device=device)
    return VQGANWrapper(model)
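

# Minimal usage sketch (an illustrative addition, not part of the original
# module): a smoke test of the full get_vqgan/encode/decode round trip. It
# assumes config.py and the ldm package are importable and that the checkpoint
# exists locally or can be downloaded; the random 1x3x64x64 input is purely a
# stand-in for a real image.
if __name__ == "__main__":
    vqgan = get_vqgan()  # loads weights and wraps the model
    x = torch.rand(1, 3, 64, 64, device=device)  # fake image in [0, 1]
    z = vqgan.encode(x)      # expected latent shape: (1, 3, 16, 16) for vq-f4
    x_rec = vqgan.decode(z)  # expected image shape: (1, 3, 64, 64)
    print(f"latent: {tuple(z.shape)}, reconstruction: {tuple(x_rec.shape)}")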