from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from madgrad import MADGRAD
from torchvision import transforms

def get_optimizer(cfg, params):
    if cfg["optimizer"] == "adam":
        optimizer = torch.optim.Adam(params, lr=cfg["lr"])
    elif cfg["optimizer"] == "radam":
        optimizer = torch.optim.RAdam(params, lr=cfg["lr"])
    elif cfg["optimizer"] == "madgrad":
        optimizer = MADGRAD(params, lr=cfg["lr"], weight_decay=0.01, momentum=0.9)
    elif cfg["optimizer"] == "rmsprop":
        optimizer = torch.optim.RMSprop(params, lr=cfg["lr"], weight_decay=0.01)
    elif cfg["optimizer"] == "sgd":
        optimizer = torch.optim.SGD(params, lr=cfg["lr"])
    else:
        raise NotImplementedError("optimizer [%s] is not implemented" % cfg["optimizer"])
    return optimizer

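# Example usage (illustrative sketch; the cfg values and the Linear model below are
# hypothetical, not part of this module):
#
#   cfg = {"optimizer": "madgrad", "lr": 1e-3}
#   model = torch.nn.Linear(8, 8)
#   optimizer = get_optimizer(cfg, model.parameters())
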
def get_text_criterion(cfg):
    if cfg["text_criterion"] == "spherical":
        text_criterion = spherical_dist_loss
    elif cfg["text_criterion"] == "cosine":
        text_criterion = cosine_loss
    else:
        raise NotImplementedError("text criterion [%s] is not implemented" % cfg["text_criterion"])
    return text_criterion

def spherical_dist_loss(x, y):
    # Spherical distance between L2-normalized embeddings, averaged over the batch.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return ((x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)).mean()


def cosine_loss(x, y, scaling=1.2):
    # Scaled cosine distance: scaling * (1 - mean cosine similarity).
    return scaling * (1 - F.cosine_similarity(x, y).mean())

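# Example usage (illustrative sketch; the random tensors stand in for batched
# embeddings of shape (N, D), e.g. CLIP image/text features):
#
#   x = torch.randn(4, 512)
#   y = torch.randn(4, 512)
#   loss_sph = spherical_dist_loss(x, y)  # scalar spherical-distance loss
#   loss_cos = cosine_loss(x, y)          # scalar, 1.2 * (1 - mean cosine similarity)
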
def tensor2im(input_image, imtype=np.uint8):
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].clamp(0.0, 1.0).cpu().float().numpy()  # convert it into a numpy array
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

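# Example usage (illustrative sketch; the random tensor and the output filename are
# placeholders): converts a (1, 3, H, W) tensor in [0, 1] to an (H, W, 3) uint8 array.
#
#   rendered = torch.rand(1, 3, 64, 64)
#   Image.fromarray(tensor2im(rendered)).save("preview.png")
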
def get_screen_template():
    return [
        "{} over a green screen.",
        "{} in front of a green screen.",
    ]


def get_augmentations_template():
    templates = [
        "photo of {}.",
        "high quality photo of {}.",
        "a photo of {}.",
        "the photo of {}.",
        "image of {}.",
        "an image of {}.",
        "high quality image of {}.",
        "a high quality image of {}.",
        "the {}.",
        "a {}.",
        "{}.",
        "{}",
        "{}!",
        "{}...",
    ]
    return templates


def compose_text_with_templates(text: str, templates) -> list:
    return [template.format(text) for template in templates]

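# Example usage (illustrative; the prompt string is arbitrary):
#
#   compose_text_with_templates("a red car", get_screen_template())
#   # -> ["a red car over a green screen.", "a red car in front of a green screen."]
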
def get_mask_boundary(img, mask):
    mask = mask.squeeze()  # mask.shape -> (H, W)
    if torch.sum(mask) > 0:
        y, x = torch.where(mask)
        y0, x0 = y.min(), x.min()
        y1, x1 = y.max(), x.max()
        return img[:, :, y0 : y1 + 1, x0 : x1 + 1]  # crop to the mask's bounding box (inclusive of the max row/column)
    else:
        return img

def load_video(folder: str, resize=(432, 768), num_frames=70):
    # Load up to num_frames .jpg/.png frames from a folder as a (T, 3, resy, resx) tensor in [0, 1].
    resy, resx = resize
    folder = Path(folder)
    input_files = sorted(list(folder.glob("*.jpg")) + list(folder.glob("*.png")))[:num_frames]
    video = torch.zeros((len(input_files), 3, resy, resx))
    for i, file in enumerate(input_files):
        video[i] = transforms.ToTensor()(Image.open(str(file)).resize((resx, resy), Image.LANCZOS))
    return video

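# Example usage (illustrative sketch; "data/frames" is a hypothetical folder of
# .jpg/.png frames): returns a (T, 3, resy, resx) float tensor in [0, 1].
#
#   frames = load_video("data/frames", resize=(432, 768), num_frames=70)
#   frames.shape  # e.g. torch.Size([70, 3, 432, 768])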