# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------

import logging
from typing import Dict, Optional, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

import matplotlib.pyplot as plt

from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity


class GenPerceptOutput(BaseOutput):
    """
    Output class for GenPercept general perception pipeline.

    Args:
        pred_np (`np.ndarray`):
            Predicted result, with values in the range of [0, 1].
        pred_colored (`PIL.Image.Image`):
            Colorized (or raw, if no colormap was requested) prediction as an RGB `PIL.Image.Image`.
    """

    pred_np: np.ndarray
    pred_colored: Union[None, Image.Image]


class GenPerceptPipeline(DiffusionPipeline):
    """
    Pipeline for general perception using GenPercept: https://github.com/aim-uofa/GenPercept.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the perception latent, conditioned on the image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) model to encode and decode images and results
            to and from latent representations.
        scheduler (`DDIMScheduler` or `LCMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text encoder, used to produce the (empty) prompt embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
        rgb_blending (`bool`, *optional*, defaults to `False`):
            If `True`, initialize the prediction latent with the RGB latent instead of random noise. Forced to
            `True` when `genpercept_pipeline` is enabled.
        customized_head (*optional*, defaults to `None`):
            Optional task-specific head (e.g. a DPT head) applied to U-Net features instead of decoding the
            prediction latent with the VAE.
        genpercept_pipeline (`bool`, *optional*, defaults to `True`):
            If `True`, run the deterministic one-step GenPercept inference; otherwise fall back to multi-step
            diffusion denoising.
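
    Example:
        A minimal usage sketch. The checkpoint path below is a placeholder, and loading with
        `from_pretrained` assumes the checkpoint is packaged in the diffusers pipeline format::

            from PIL import Image

            pipe = GenPerceptPipeline.from_pretrained("path/to/genpercept-depth-checkpoint").to("cuda")
            out = pipe(Image.open("example.jpg"), mode="depth", color_map="Spectral")
            depth = out.pred_np                                 # np.ndarray with values in [0, 1]
            out.pred_colored.save("example_depth_colored.png")  # colorized visualization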
| """ | |
    # 0.18215 is the latent scaling factor of the Stable Diffusion VAE (i.e. `vae.config.scaling_factor`).
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        default_denoising_steps: Optional[int] = 10,
        default_processing_resolution: Optional[int] = 768,
        rgb_blending=False,
        customized_head=None,
        genpercept_pipeline=True,
    ):
        super().__init__()

        self.genpercept_pipeline = genpercept_pipeline

        if self.genpercept_pipeline:
            default_denoising_steps = 1
            rgb_blending = True

        self.register_modules(
            unet=unet,
            customized_head=customized_head,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        self.register_to_config(
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
            rgb_blending=rgb_blending,
        )

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution
        self.rgb_blending = rgb_blending

        self.text_embed = None
        self.customized_head = customized_head

        if self.customized_head:
            assert self.rgb_blending and self.scheduler.beta_start == 1 and self.scheduler.beta_end == 1
            assert self.genpercept_pipeline

    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        color_map: Union[str, None] = None,
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
        mode=None,
        fix_timesteps=None,
        prompt="",
    ) -> GenPerceptOutput:
| """ | |
| Function invoked when calling the pipeline. | |
| Args: | |
| input_image (`Image`): | |
| Input RGB (or gray-scale) image. | |
| denoising_steps (`int`, *optional*, defaults to `None`): | |
| Number of denoising diffusion steps during inference. The default value `None` results in automatic | |
| selection. | |
| ensemble_size (`int`, *optional*, defaults to `10`): | |
| Number of predictions to be ensembled. | |
| processing_res (`int`, *optional*, defaults to `None`): | |
| Effective processing resolution. When set to `0`, processes at the original image resolution. This | |
| produces crisper predictions, but may also lead to the overall loss of global context. The default | |
| value `None` resolves to the optimal value from the model config. | |
| match_input_res (`bool`, *optional*, defaults to `True`): | |
| Resize perception result to match input resolution. | |
| Only valid if `processing_res` > 0. | |
| resample_method: (`str`, *optional*, defaults to `bilinear`): | |
| Resampling method used to resize images and perception results. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. | |
| batch_size (`int`, *optional*, defaults to `0`): | |
| Inference batch size, no bigger than `num_ensemble`. | |
| If set to 0, the script will automatically decide the proper batch size. | |
| generator (`torch.Generator`, *optional*, defaults to `None`) | |
| Random generator for initial noise generation. | |
| show_progress_bar (`bool`, *optional*, defaults to `True`): | |
| Display a progress bar of diffusion denoising. | |
| color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized result generation): | |
| Colormap used to colorize the result. | |
| ensemble_kwargs (`dict`, *optional*, defaults to `None`): | |
| Arguments for detailed ensembling settings. | |
| Returns: | |
| `GenPerceptOutput`: Output class for GenPercept general perception pipeline, including: | |
| - **pred_np** (`np.ndarray`) Predicted result, with values in the range of [0, 1] | |
| - **pred_colored** (`PIL.Image.Image`) Colorized result, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None` | |
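
        Example:
            A call sketch (assumes `pipe` is a loaded `GenPerceptPipeline` instance, as in the class-level
            example; the image path is illustrative)::

                out = pipe(
                    Image.open("example.jpg"),
                    mode="depth",
                    processing_res=768,
                    match_input_res=True,
                    color_map="Spectral",
                )
                depth = out.pred_np                     # np.ndarray in [0, 1], matching the input resolution
                out.pred_colored.save("depth_vis.png")  # colorized visualization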
| """ | |
        assert mode is not None, "mode of GenPerceptPipeline can be chosen from ['depth', 'normal', 'seg', 'matting', 'dis']."
        self.mode = mode

        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        if self.genpercept_pipeline:
            assert ensemble_size == 1
            assert denoising_steps == 1
        else:
            # Check if denoising step is reasonable
            self._check_inference_step(denoising_steps)

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Perception Inference -----------------
        # Batch repeated input image
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict results (batched)
        pipe_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            pipe_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
                fix_timesteps=fix_timesteps,
                prompt=prompt,
            )
            pipe_pred_ls.append(pipe_pred_raw.detach())
        pipe_preds = torch.concat(pipe_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            pipe_pred, _ = ensemble_depth(
                pipe_preds,
                scale_invariant=True,
                shift_invariant=True,
                max_res=50,
                **(ensemble_kwargs or {}),
            )
        else:
            pipe_pred = pipe_preds

        # Resize back to original resolution
        if match_input_res:
            pipe_pred = resize(
                pipe_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        pipe_pred = pipe_pred.squeeze()
        pipe_pred = pipe_pred.cpu().numpy()

        # Clip output range
        pipe_pred = pipe_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            assert self.mode in ['depth', 'disparity']
            pred_colored = colorize_depth_maps(
                pipe_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            pred_colored = (pred_colored * 255).astype(np.uint8)
            pred_colored_hwc = chw2hwc(pred_colored)
            pred_colored_img = Image.fromarray(pred_colored_hwc)
        else:
            pred_colored_img = (pipe_pred * 255.0).astype(np.uint8)
            if len(pred_colored_img.shape) == 3 and pred_colored_img.shape[0] == 3:
                pred_colored_img = np.transpose(pred_colored_img, (1, 2, 0))
            pred_colored_img = Image.fromarray(pred_colored_img)

        if len(pipe_pred.shape) == 3 and pipe_pred.shape[0] == 3:
            pipe_pred = np.transpose(pipe_pred, (1, 2, 0))

        return GenPerceptOutput(
            pred_np=pipe_pred,
            pred_colored=pred_colored_img,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable.

        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if n_step < 10:
                logging.warning(
                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            if not 1 <= n_step <= 4:
                logging.warning(
                    f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
                )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_text(self, prompt):
        """
        Encode the text embedding for the given (possibly empty) prompt and cache it in `self.text_embed`.
        """
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
        fix_timesteps=None,
        prompt="",
    ) -> torch.Tensor:
        """
        Perform an individual perception prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`):
                Random generator for initial noise generation.
            fix_timesteps (`int`, *optional*):
                If set, use this fixed timestep for every denoising step.
            prompt (`str`, *optional*):
                Text prompt used to compute the text embedding.

        Returns:
            `torch.Tensor`: Predicted result.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        if fix_timesteps:
            timesteps = torch.tensor([fix_timesteps]).long().repeat(self.scheduler.timesteps.shape[0]).to(device)
        else:
            timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        if not (self.rgb_blending or self.genpercept_pipeline):
            # Initial result (noise)
            pred_latent = torch.randn(
                rgb_latent.shape,
                device=device,
                dtype=self.dtype,
                generator=generator,
            )  # [B, 4, h, w]
        else:
            pred_latent = rgb_latent

        # Batched empty text embedding
        if self.text_embed is None:
            self.encode_text(prompt)
        batch_text_embed = self.text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        if not self.customized_head:
            for i, t in iterable:
                if self.genpercept_pipeline and i > 0:
                    raise ValueError("GenPercept only forwards through the U-Net once.")

                if not (self.rgb_blending or self.genpercept_pipeline):
                    unet_input = torch.cat(
                        [rgb_latent, pred_latent], dim=1
                    )  # this order is important
                else:
                    unet_input = pred_latent

                # predict the noise residual
                noise_pred = self.unet(
                    unet_input, t, encoder_hidden_states=batch_text_embed
                ).sample  # [B, 4, h, w]

                # compute the previous noisy sample x_t -> x_t-1
                step_output = self.scheduler.step(
                    noise_pred, t, pred_latent, generator=generator
                )

                pred_latent = step_output.prev_sample
                pred_latent = step_output.pred_original_sample  # NOTE: for GenPercept, it is equivalent to "pred_latent = - noise_pred"

            pred = self.decode_pred(pred_latent)

            # clip prediction
            pred = torch.clip(pred, -1.0, 1.0)
            # shift to [0, 1]
            pred = (pred + 1.0) / 2.0
        elif isinstance(self.customized_head, DPTNeckHeadForUnetAfterUpsampleIdentity):
            unet_input = pred_latent
            model_pred_output = self.unet(
                unet_input, timesteps, encoder_hidden_states=batch_text_embed, return_feature=True
            )  # [B, 4, h, w]
            unet_features = model_pred_output.multi_level_feats[::-1]
            pred = self.customized_head(hidden_states=unet_features).prediction[:, None]
            # shift to [0, 1]
            pred = (pred - pred.min()) / (pred.max() - pred.min())
        else:
            raise ValueError

        return pred

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        # use the deterministic mean of the VAE posterior (no sampling)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode the prediction latent into a result.

        Args:
            pred_latent (`torch.Tensor`):
                Prediction latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded result.
        """
        # unscale latent
        pred_latent = pred_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(pred_latent)
        stacked = self.vae.decoder(z)
        if self.mode in ['depth', 'matting', 'dis', 'disparity']:
            # single-channel tasks: take the mean of the decoded output channels
            stacked = stacked.mean(dim=1, keepdim=True)
        return stacked