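"""Video editing model operating on crops of a 2D texture atlas.

UV grids map video-frame pixels into the atlas (a layered-neural-atlas
style video representation), so an edit produced once on an atlas crop
can be warped back onto each frame with `F.grid_sample`.
"""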
import random

import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from models.image_model import Model


class VideoModel(Model):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # No preprocessing: an empty Compose is the identity transform.
        self.net_preprocess = transforms.Compose([])

    @staticmethod
    def resize_crops(crops, resize_factor):
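        """Downscale a batch of crops by an integer factor (bilinear, antialiased)."""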
        return torchvision.transforms.functional.resize(
            crops,
            [
                crops.shape[-2] // resize_factor,
                crops.shape[-1] // resize_factor,
            ],
            InterpolationMode.BILINEAR,
            antialias=True,
        )

    def process_crops(self, uv_values, crops, original_crops, alpha=None):
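        """Edit an atlas crop and warp the edit back onto three video frames.

        Args (shapes inferred from the `grid_sample` calls below):
            uv_values: list of three (N, H, W, 2) UV grids, one per frame.
            crops: atlas crops; only ``crops[0]`` is used here.
            original_crops: list of three original frame crops.
            alpha: optional per-frame opacity maps for the foreground layer.
        """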
        resized_crops = []
        cnn_output_crops = []
        render_dict = {"edit": [], "alpha": [], "edit_on_greenscreen": [], "composite": []}
        atlas_crop = crops[0]
        # Warp the atlas crop into each of the three frames' coordinates;
        # these warped crops are reported as `cnn_inputs` downstream.
        for i in range(3):
            grid_sampled_atlas_crop = F.grid_sample(
                atlas_crop,
                uv_values[i],
                mode="bilinear",
                align_corners=self.config["align_corners"],
            ).clamp(min=0.0, max=1.0)
            resized_crops.append(grid_sampled_atlas_crop)
        # Run the generator once on the atlas crop and composite the edit.
        cnn_output = self.netG(atlas_crop)
        cnn_output_crops.append(cnn_output[:, :3])
        rendered_atlas_crops = self.render(cnn_output, bg_image=atlas_crop)
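        # Warp each rendered atlas output ("edit", "alpha", "composite", ...)
        # back onto the three frames through the same UV grids.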
        for key, value in rendered_atlas_crops.items():
            for i in range(3):
                sampled_frame_crop = F.grid_sample(
                    value,
                    uv_values[i],
                    mode="bilinear",
                    align_corners=self.config["align_corners"],
                ).clamp(min=0.0, max=1.0)
                if alpha is not None:
                    sampled_frame_crop = sampled_frame_crop * alpha[i]
                    if key == "edit_on_greenscreen":
                        # Fill the transparent region with the green-screen
                        # color (RGB 0, 177, 64, scaled to [0, 1]).
                        greenscreen = torch.zeros_like(sampled_frame_crop).to(sampled_frame_crop.device)
                        greenscreen[:, 1, :, :] = 177 / 255
                        greenscreen[:, 2, :, :] = 64 / 255
                        sampled_frame_crop += (1 - alpha[i]) * greenscreen
                render_dict[key].append(sampled_frame_crop.squeeze(0))
        # Also pass one randomly chosen original frame crop through the network.
        frame_index = random.randint(0, 2)  # randomly sample one of the three frames
        rec_crop = original_crops[frame_index]
        resized_crops.append(rec_crop)
        cnn_output = self.netG(rec_crop)
        if alpha is not None:
            alpha_crop = alpha[frame_index]
            cnn_output = cnn_output * alpha_crop
        cnn_output_crops.append(cnn_output[:, :3])
        rendered_frame_crop = self.render(cnn_output, bg_image=original_crops[frame_index])
        for key, value in rendered_frame_crop.items():
            render_dict[key].append(value.squeeze(0))
        return render_dict, resized_crops, cnn_output_crops

    def process_atlas(self, atlas):
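        """Edit a full atlas (or image) and composite the result over it."""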
        atlas_edit = self.netG(atlas)
        rendered_atlas = self.render(atlas_edit, bg_image=atlas)
        return rendered_atlas

    def forward(self, input_dict):
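        """Run the model on either the foreground or the background layer,
        selected by the `finetune_foreground` / `finetune_background` flags.
        """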
        inputs = input_dict["global_crops"]
        outputs = {"background": {}, "foreground": {}}

        if self.config["finetune_foreground"]:
            if self.config["multiply_foreground_alpha"]:
                alpha = inputs["foreground_alpha"]
            else:
                alpha = None
            foreground_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["foreground_uvs"],
                inputs["foreground_atlas_crops"],
                inputs["original_foreground_crops"],
                alpha=alpha,
            )
            outputs["foreground"]["output_crop"] = foreground_outputs
            outputs["foreground"]["cnn_inputs"] = resized_crops
            outputs["foreground"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["foreground"]["output_image"] = self.process_atlas(input_dict["input_image"])
        elif self.config["finetune_background"]:
            background_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["background_uvs"],
                inputs["background_atlas_crops"],
                inputs["original_background_crops"],
            )
            outputs["background"]["output_crop"] = background_outputs
            outputs["background"]["cnn_inputs"] = resized_crops
            outputs["background"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["background"]["output_image"] = self.process_atlas(input_dict["input_image"])
        return outputs
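

# Example usage: a minimal sketch. The config keys and input_dict layout are
# inferred from the code above; `Model` (models/image_model.py) is assumed to
# provide `self.netG` and `self.render`. Tensor sizes are illustrative only.
#
#   config = {
#       "align_corners": False,
#       "finetune_foreground": True,
#       "multiply_foreground_alpha": False,
#       "finetune_background": False,
#   }
#   model = VideoModel(config)
#   input_dict = {
#       "global_crops": {
#           # Three UV grids in [-1, 1], one per frame: (N, H_out, W_out, 2)
#           "foreground_uvs": [torch.rand(1, 128, 128, 2) * 2 - 1 for _ in range(3)],
#           # Atlas crops: only index 0 is used by process_crops
#           "foreground_atlas_crops": [torch.rand(1, 3, 256, 256)],
#           # Original frame crops, one per sampled frame
#           "original_foreground_crops": [torch.rand(1, 3, 128, 128) for _ in range(3)],
#       },
#   }
#   outputs = model(input_dict)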