import gc
import logging
import os
from typing import List, Optional

import numpy as np
import torch
from diffusers.training_utils import set_seed

from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
from depthcrafter.utils import read_video_frames, save_video, vis_sequence_depth

logger = logging.getLogger(__name__)


class DepthCrafterInference:
    """Inference wrapper for the DepthCrafter video depth estimation pipeline."""

    def __init__(
        self,
        unet_path: str,
        pre_train_path: str,
        cpu_offload: Optional[str] = "model",
        device: str = "cuda",
    ):
        """
        Initialize the DepthCrafter inference pipeline.

        Args:
            unet_path (str): Path to the UNet model.
            pre_train_path (str): Path to the pre-trained base pipeline.
            cpu_offload (Optional[str]): CPU offload strategy ("model", "sequential", or None).
            device (str): Device to run the model on ("cuda" or "cpu").
        """
        # fp16 halves memory on GPU; CPU inference requires full precision.
        if device == "cpu":
            dtype = torch.float32
            variant = None
        else:
            dtype = torch.float16
            variant = "fp16"

        logger.info(f"Loading UNet from {unet_path}")
        unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
            unet_path,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
        )

        logger.info(f"Loading pipeline from {pre_train_path}")
        pipeline_kwargs = {
            "unet": unet,
            "torch_dtype": dtype,
        }
        if variant is not None:
            pipeline_kwargs["variant"] = variant
        self.pipe = DepthCrafterPipeline.from_pretrained(
            pre_train_path, **pipeline_kwargs
        )

        # "sequential" offload minimizes VRAM usage at the cost of speed;
        # "model" offload is a faster middle ground.
        if cpu_offload is not None:
            if cpu_offload == "sequential":
                self.pipe.enable_sequential_cpu_offload()
            elif cpu_offload == "model":
                self.pipe.enable_model_cpu_offload()
            else:
                raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
        else:
            self.pipe.to(device)

        # ModuleNotFoundError is a subclass of ImportError, so catching
        # ImportError covers both.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, AttributeError) as e:
            logger.warning(f"Xformers is not enabled: {e}")
        # Attention slicing trades a little speed for lower peak memory.
        self.pipe.enable_attention_slicing()

    def infer(
        self,
        video_path: str,
        num_denoising_steps: int,
        guidance_scale: float,
        save_folder: str = "./demo_output",
        window_size: int = 110,
        process_length: int = 195,
        overlap: int = 25,
        max_res: int = 1024,
        dataset: str = "open",
        target_fps: int = 15,
        seed: int = 42,
        track_time: bool = True,
        save_npz: bool = False,
        save_exr: bool = False,
    ) -> List[str]:
        """
        Run inference on a video.

        Args:
            video_path (str): Path to the input video.
            num_denoising_steps (int): Number of denoising steps.
            guidance_scale (float): Classifier-free guidance scale.
            save_folder (str): Folder to save outputs in.
            window_size (int): Window size for sliding-window inference.
            process_length (int): Maximum number of frames to process.
            overlap (int): Number of overlapping frames between windows.
            max_res (int): Maximum resolution of the longer video side.
            dataset (str): Dataset name used to select resolution settings.
            target_fps (int): Target FPS for the output video.
            seed (int): Random seed.
            track_time (bool): Whether to track execution time.
            save_npz (bool): Whether to also save the depth maps as .npz.
            save_exr (bool): Whether to also save the depth maps as .exr.

        Returns:
            List[str]: Paths to the saved input, visualization, and depth videos.
""" set_seed(seed) frames, target_fps = read_video_frames( video_path, process_length, target_fps, max_res, dataset, ) with torch.inference_mode(): res = self.pipe( frames, height=frames.shape[1], width=frames.shape[2], output_type="np", guidance_scale=guidance_scale, num_inference_steps=num_denoising_steps, window_size=window_size, overlap=overlap, track_time=track_time, ).frames[0] res = res.sum(-1) / res.shape[-1] res = (res - res.min()) / (res.max() - res.min()) vis = vis_sequence_depth(res) save_path = os.path.join( save_folder, os.path.splitext(os.path.basename(video_path))[0] ) os.makedirs(os.path.dirname(save_path), exist_ok=True) save_video(res, save_path + "_depth.mp4", fps=target_fps) save_video(vis, save_path + "_vis.mp4", fps=target_fps) save_video(frames, save_path + "_input.mp4", fps=target_fps) if save_npz: np.savez_compressed(save_path + ".npz", depth=res) if save_exr: self._save_exr(res, save_path) return [ save_path + "_input.mp4", save_path + "_vis.mp4", save_path + "_depth.mp4", ] def _save_exr(self, res: np.ndarray, save_path: str): """ Save results as EXR files. """ try: import OpenEXR import Imath except ImportError: logger.error("OpenEXR or Imath not installed. Skipping EXR saving.") return os.makedirs(save_path, exist_ok=True) logger.info(f"Saving EXR results to {save_path}") for i, frame in enumerate(res): output_exr = f"{save_path}/frame_{i:04d}.exr" header = OpenEXR.Header(frame.shape[1], frame.shape[0]) header["channels"] = { "Z": Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)) } exr_file = OpenEXR.OutputFile(output_exr, header) exr_file.writePixels({"Z": frame.tobytes()}) exr_file.close() def clear_cache(self): """Clear CUDA cache.""" gc.collect() torch.cuda.empty_cache()