import logging
import tempfile
from typing import List, Optional, Tuple, Union

import matplotlib
import mediapy
import numpy as np
import PIL.Image
import torch
from decord import VideoReader, cpu

logger = logging.getLogger(__name__)

dataset_res_dict = {
    "sintel": [448, 1024],
    "scannet": [640, 832],
    "KITTI": [384, 1280],
    "bonn": [512, 640],
    "NYUv2": [448, 640],
}
def read_video_frames(
    video_path: str,
    process_length: int,
    target_fps: int,
    max_res: int,
    dataset: str = "open",
) -> Tuple[np.ndarray, int]:
    """
    Read video frames from a file, then resize and temporally downsample them.

    Args:
        video_path (str): Path to the video file.
        process_length (int): Maximum number of frames to process; -1 keeps all frames.
        target_fps (int): Target FPS for the output; -1 keeps the original FPS.
        max_res (int): Maximum resolution (height or width).
        dataset (str): Dataset name used to look up a fixed resolution; "open" infers the resolution from the video itself.

    Returns:
        Tuple[np.ndarray, int]: The frames as a float32 array of shape (T, H, W, 3) with values in [0, 1], and the actual FPS.
    """
    if dataset == "open":
        logger.info(f"Processing video: {video_path}")
        vid = VideoReader(video_path, ctx=cpu(0))
        logger.info(
            f"Original video shape: {(len(vid), *vid.get_batch([0]).shape[1:])}"
        )
        original_height, original_width = vid.get_batch([0]).shape[1:3]
        # Snap both dimensions to the nearest multiple of 64, rescaling first
        # if the video exceeds max_res.
        height = round(original_height / 64) * 64
        width = round(original_width / 64) * 64
        if max(height, width) > max_res:
            scale = max_res / max(original_height, original_width)
            height = round(original_height * scale / 64) * 64
            width = round(original_width * scale / 64) * 64
    else:
        height = dataset_res_dict[dataset][0]
        width = dataset_res_dict[dataset][1]

    # Re-open the reader with the target size so decord resizes on the fly.
    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)

    fps = vid.get_avg_fps() if target_fps == -1 else target_fps
    stride = max(round(vid.get_avg_fps() / fps), 1)
    frames_idx = list(range(0, len(vid), stride))
    logger.info(
        f"Downsampled shape: {(len(frames_idx), *vid.get_batch([0]).shape[1:])}, with stride: {stride}"
    )
    if process_length != -1 and process_length < len(frames_idx):
        frames_idx = frames_idx[:process_length]
    logger.info(
        f"Final processing shape: {(len(frames_idx), *vid.get_batch([0]).shape[1:])}"
    )
    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
    return frames, fps
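
# A minimal usage sketch for read_video_frames ("input.mp4" is a hypothetical
# path; -1 keeps the original FPS and the full frame count):
#
#     frames, fps = read_video_frames(
#         "input.mp4", process_length=-1, target_fps=-1, max_res=1024
#     )
#     # frames.shape == (T, H, W, 3), dtype float32, values in [0, 1]
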
def save_video(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image], np.ndarray],
    output_video_path: Optional[str] = None,
    fps: int = 10,
    crf: int = 18,
) -> str:
    """
    Save video frames to an MP4 file.

    Args:
        video_frames (Union[List[np.ndarray], List[PIL.Image.Image], np.ndarray]): List of frames or numpy array.
        output_video_path (Optional[str]): Path to save the video. If None, a temporary file is created.
        fps (int): Frames per second.
        crf (int): Constant Rate Factor for encoding quality (lower means higher quality).

    Returns:
        str: Path to the saved video.
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
    if isinstance(video_frames, np.ndarray):
        # A numpy array is assumed to hold float frames in [0, 1] unless it is
        # already uint8.
        if video_frames.dtype != np.uint8:
            video_frames = (video_frames * 255).astype(np.uint8)
    elif isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
    return output_video_path
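
# A minimal usage sketch for save_video, assuming float frames in [0, 1]
# ("out.mp4" is a hypothetical path; mediapy needs ffmpeg available on PATH):
#
#     path = save_video(frames, "out.mp4", fps=int(fps), crf=18)
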
class ColorMapper:
    """
    A color mapper that maps scalar (e.g. depth) values to a matplotlib colormap.
    """

    def __init__(self, colormap: str = "inferno"):
        """
        Initialize the ColorMapper.

        Args:
            colormap (str): Name of the matplotlib colormap to use.
        """
        self.colormap = torch.tensor(matplotlib.colormaps[colormap].colors)

    def apply(
        self,
        image: torch.Tensor,
        v_min: Optional[float] = None,
        v_max: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Apply the colormap to an image.

        Args:
            image (torch.Tensor): Input image tensor.
            v_min (Optional[float]): Minimum value for normalization; defaults to image.min().
            v_max (Optional[float]): Maximum value for normalization; defaults to image.max().

        Returns:
            torch.Tensor: Color-mapped image with a trailing RGB channel dimension.
        """
        if v_min is None:
            v_min = image.min()
        if v_max is None:
            v_max = image.max()
        # Normalize to [0, 1], quantize to 256 levels, and clamp so the result
        # is a valid index into the 256-entry colormap table.
        image = (image - v_min) / (v_max - v_min)
        image = (image * 255).long()
        image = torch.clamp(image, 0, 255)
        return self.colormap[image]
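
# A minimal usage sketch for ColorMapper (random data stands in for a real
# depth map):
#
#     depth = torch.rand(480, 640)
#     rgb = ColorMapper("inferno").apply(depth)  # shape (480, 640, 3), values in [0, 1]
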
def vis_sequence_depth(
    depths: np.ndarray, v_min: Optional[float] = None, v_max: Optional[float] = None
) -> np.ndarray:
    """
    Visualize a sequence of depth maps as colorized RGB frames.

    Args:
        depths (np.ndarray): Input depth maps of shape (T, H, W).
        v_min (Optional[float]): Minimum value for normalization; defaults to depths.min().
        v_max (Optional[float]): Maximum value for normalization; defaults to depths.max().

    Returns:
        np.ndarray: Visualized depth maps of shape (T, H, W, 3).
    """
    visualizer = ColorMapper()
    if v_min is None:
        v_min = depths.min()
    if v_max is None:
        v_max = depths.max()
    res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
    return res
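

if __name__ == "__main__":
    # End-to-end sketch tying the helpers together: read a video, colorize its
    # frames as if they were depth maps, and write the result. "input.mp4" and
    # "depth_vis.mp4" are hypothetical paths; a real pipeline would run a depth
    # estimator between reading and visualization.
    frames, fps = read_video_frames(
        "input.mp4", process_length=-1, target_fps=-1, max_res=1024
    )
    # Collapse RGB to one channel so vis_sequence_depth receives (T, H, W).
    fake_depth = frames.mean(axis=-1)
    vis = vis_sequence_depth(fake_depth)
    print(save_video(vis, "depth_vis.mp4", fps=int(fps)))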