---
license: apache-2.0
---
You can use the following code to call our trained style encoder. Hope it helps.

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection


class SEStyleEmbedding:
    def __init__(self, pretrained_path: str = "xingpng/OneIG-StyleEncoder",
                 device: str = "cuda", dtype=torch.bfloat16):
        self.device = torch.device(device)
        self.dtype = dtype
        # Load the CLIP vision tower with a projection head as the style encoder.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
        self.image_encoder.to(self.device, dtype=self.dtype)
        self.image_encoder.eval()
        self.processor = CLIPImageProcessor()

    def _l2_normalize(self, x):
        return torch.nn.functional.normalize(x, p=2, dim=-1)

    def get_style_embedding(self, image_path: str):
        # Preprocess the image and run it through the encoder.
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)

        with torch.no_grad():
            outputs = self.image_encoder(inputs)
            image_embeds = outputs.image_embeds
            # L2-normalize so embeddings can be compared via cosine similarity.
            image_embeds_norm = self._l2_normalize(image_embeds)
        return image_embeds_norm
```
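
A minimal usage sketch is below. The image paths `style_a.png` and `style_b.png` are hypothetical placeholders; since `get_style_embedding` returns L2-normalized vectors, the cosine similarity between two styles reduces to a dot product.

```python
# Minimal usage sketch; the image paths below are hypothetical.
embedder = SEStyleEmbedding()  # fetches xingpng/OneIG-StyleEncoder on first use

emb_a = embedder.get_style_embedding("style_a.png")
emb_b = embedder.get_style_embedding("style_b.png")

# Embeddings are already L2-normalized, so cosine similarity is a dot product.
style_similarity = (emb_a * emb_b).sum(dim=-1)
print(f"style similarity: {style_similarity.item():.4f}")
```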