File size: 1,348 Bytes
9e81046 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
---
license: apache-2.0
---
```python
# You can use the following code to call our trained style encoder. Hope it helps.
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (AutoModel, AutoProcessor, AutoTokenizer, AutoConfig,
                          CLIPImageProcessor, CLIPVisionModelWithProjection)
class SEStyleEmbedding:
    """Extract L2-normalized style embeddings from images with a CLIP vision encoder.

    The encoder is loaded with ``CLIPVisionModelWithProjection.from_pretrained``,
    moved to the requested device/dtype and switched to eval mode. Images are
    preprocessed by a ``CLIPImageProcessor`` constructed with its default
    arguments (it is not loaded from ``pretrained_path``).
    """

    def __init__(
        self,
        pretrained_path: str = "xingpng/OneIG-StyleEncoder",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Load the encoder and build the image preprocessor.

        Args:
            pretrained_path: Identifier or path handed to ``from_pretrained``.
            device: Device string passed to ``torch.device``.
            dtype: Dtype the encoder weights and the pixel inputs are cast to.
        """
        self.device = torch.device(device)
        self.dtype = dtype

        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
        self.image_encoder.to(self.device, dtype=self.dtype)
        self.image_encoder.eval()

        self.processor = CLIPImageProcessor()

    def _l2_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit L2 norm along its last dimension."""
        return torch.nn.functional.normalize(x, p=2, dim=-1)

    def get_style_embedding(self, image_path: str) -> torch.Tensor:
        """Compute the style embedding of the image stored at ``image_path``.

        Args:
            image_path: Path of the image file to open.

        Returns:
            The encoder's ``image_embeds`` output, L2-normalized along the last
            dimension. It stays on ``self.device`` in ``self.dtype``.
        """
        # Close the file handle as soon as the pixels are decoded; ``convert``
        # returns an independent in-memory image that outlives the ``with``.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')

        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        inputs = pixel_values.to(self.device, dtype=self.dtype)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.image_encoder(inputs)

        return self._l2_normalize(outputs.image_embeds)