import torch
from torch.nn import functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

from CLIP import clip
from util.util import compose_text_with_templates

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ClipExtractor(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model = clip.load(cfg["clip_model_name"], device=device)[0]
        self.model = model.eval().requires_grad_(False)
        self.clip_input_size = 224
        self.clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.basic_transform = T.Compose(
            [
                # we added interpolation to the CLIP positional embedding, which allows working with arbitrary resolutions
                T.Resize(self.clip_input_size, max_size=380),
                self.clip_normalize,
            ]
        )

        # list of augmentations we apply before calculating the CLIP losses
        self.augs = T.Compose(
            [
                T.RandomHorizontalFlip(p=0.5),
                T.RandomApply(
                    [
                        T.RandomAffine(
                            degrees=15,
                            translate=(0.1, 0.1),
                            fill=cfg["clip_affine_transform_fill"],
                            interpolation=InterpolationMode.BILINEAR,
                        )
                    ],
                    p=0.8,
                ),
                T.RandomPerspective(
                    distortion_scale=0.4,
                    p=0.5,
                    interpolation=InterpolationMode.BILINEAR,
                    fill=cfg["clip_affine_transform_fill"],
                ),
                T.RandomApply([T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)], p=0.7),
                T.RandomGrayscale(p=0.15),
            ]
        )
        self.n_aug = cfg["n_aug"]
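
    # Build a batch of n_aug augmented views of the input: the first view is the full
    # frame resized to CLIP resolution, the remaining views are random crops covering
    # 60-100% of the frame, and every view is passed through the augmentations above.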
    def augment_input(self, input, n_aug=None, clip_input_size=None):
        if n_aug is None:
            n_aug = self.n_aug
        if clip_input_size is None:
            clip_input_size = self.clip_input_size

        cutouts = []
        cutout = T.Resize(clip_input_size, max_size=320)(input)
        cutout_h, cutout_w = cutout.shape[-2:]
        cutout = self.augs(cutout)
        cutouts.append(cutout)
        sideY, sideX = input.shape[2:4]
        for _ in range(n_aug - 1):
            s = torch.zeros(1).uniform_(0.6, 1).item()
            h = int(sideY * s)
            w = int(sideX * s)
            cutout = T.RandomCrop(size=(h, w))(input)
            cutout = T.Resize((cutout_h, cutout_w))(cutout)
            cutout = self.augs(cutout)
            cutouts.append(cutout)
        cutouts = torch.cat(cutouts)
        return cutouts
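
    # Encode an image tensor with CLIP, either from a batch of augmented views
    # (aug=True) or from a single resized view produced by basic_transform (aug=False).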
    def get_image_embedding(self, x, aug=True):
        if aug:
            views = self.augment_input(x)
        else:
            # basic_transform already includes clip_normalize, so return directly
            # instead of normalizing a second time below
            return self.encode_image(self.basic_transform(x))
        if isinstance(views, list):
            image_embeds = torch.cat([self.encode_image(self.clip_normalize(view)) for view in views])
        else:
            image_embeds = self.encode_image(self.clip_normalize(views))
        return image_embeds

    def encode_image(self, x):
        return self.model.encode_image(x)
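
    # Encode one or more text prompts with CLIP; each prompt is first composed with the
    # given templates (compose_text_with_templates is assumed to fill the prompt into each
    # template), and the resulting embeddings can optionally be averaged into one vector.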
    def get_text_embedding(self, text, template, average_embeddings=False):
        if isinstance(text, str):
            text = [text]
        embeddings = []
        for prompt in text:
            with torch.no_grad():
                embedding = self.model.encode_text(
                    clip.tokenize(compose_text_with_templates(prompt, template)).to(device)
                )
            embeddings.append(embedding)
        embeddings = torch.cat(embeddings)
        if average_embeddings:
            embeddings = embeddings.mean(dim=0, keepdim=True)
        return embeddings
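
    # Self-similarity of CLIP features; calculate_self_sim is not part of stock OpenAI
    # CLIP, so it is presumably provided by the bundled CLIP module imported above.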
    def get_self_sim(self, x):
        x = self.basic_transform(x)
        return self.model.calculate_self_sim(x)
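

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The cfg keys are the
# ones __init__ reads; the concrete values ("ViT-B/32", fill=0, n_aug=8), the
# dummy image, and the template string are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = {
        "clip_model_name": "ViT-B/32",    # any model name accepted by clip.load
        "clip_affine_transform_fill": 0,  # padding fill for the affine/perspective augmentations
        "n_aug": 8,                       # number of augmented views per image
    }
    extractor = ClipExtractor(cfg)

    image = torch.rand(1, 3, 256, 256, device=device)    # dummy RGB image batch in [0, 1]
    image_embeds = extractor.get_image_embedding(image)  # one CLIP embedding per augmented view
    text_embeds = extractor.get_text_embedding(
        "a photo of a dog",
        template=["a photo of a {}."],  # assumes "{}"-style templates filled by compose_text_with_templates
        average_embeddings=True,
    )

    # cosine similarity of each augmented view against the averaged text embedding
    sims = F.cosine_similarity(image_embeds.float(), text_embeds.float().expand_as(image_embeds), dim=-1)
    print(sims.mean().item())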