# animatediff/utils/tagger.py (ref: 314c40f)
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
import glob
import logging
import os
import cv2
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from tqdm.rich import tqdm
from animatediff.utils.util import prepare_wd14tagger
logger = logging.getLogger(__name__)
def make_square(img, target_size):
    """Pad an image with white borders so it becomes a square.

    The output side length is the larger of the image's longest edge and
    *target_size*; the original pixels are centred in the padded canvas.

    Args:
        img: HxWxC OpenCV-style ndarray.
        target_size: minimum side length of the returned square image.

    Returns:
        A square ndarray of side max(h, w, target_size).
    """
    height, width = img.shape[:2]
    side = max(height, width, target_size)
    pad_h = side - height
    pad_w = side - width
    # Split the padding as evenly as possible; the extra pixel (odd delta)
    # goes on the bottom/right edge.
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left
    return cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[255, 255, 255]
    )
def smart_resize(img, size):
    """Resize an already-square image to size x size.

    Picks the interpolation best suited to the direction of scaling:
    INTER_AREA when shrinking, INTER_CUBIC when enlarging. Assumes the
    image has already gone through make_square, so only shape[0] is checked.
    """
    current = img.shape[0]
    if current == size:
        return img
    interpolation = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=interpolation)
class Tagger:
    """WD14 ONNX tagger: turns a PIL image into a comma-separated tag prompt.

    Tags above the configured confidence thresholds are collected from the
    model output, filtered against *ignore_tokens*, and joined into a prompt
    string (optionally with "(tag:conf)" confidence annotations).
    """

    def __init__(self, general_threshold, character_threshold, ignore_tokens,
                 with_confidence, is_danbooru_format, is_cpu):
        """Load the ONNX model and the tag metadata CSV.

        Args:
            general_threshold: minimum confidence for general tags.
            character_threshold: minimum confidence for character tags.
            ignore_tokens: iterable of tag names to exclude from the prompt.
            with_confidence: emit "(tag:conf)" instead of bare tag names.
            is_danbooru_format: keep underscores in tags when True.
            is_cpu: run inference on CPU instead of CUDA.
        """
        prepare_wd14tagger()
        if is_cpu:
            self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CPUExecutionProvider'])
        else:
            self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CUDAExecutionProvider'])
        df = pd.read_csv("data/models/WD14tagger/selected_tags.csv")
        self.tag_names = df["name"].tolist()
        # Category codes in selected_tags.csv: 9 = rating, 0 = general, 4 = character
        self.rating_indexes = list(np.where(df["category"] == 9)[0])
        self.general_indexes = list(np.where(df["category"] == 0)[0])
        self.character_indexes = list(np.where(df["category"] == 4)[0])
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        # Set for O(1) membership tests when filtering tags.
        self.ignore_tokens = set(ignore_tokens)
        self.with_confidence = with_confidence
        self.is_danbooru_format = is_danbooru_format

    def _select_tags(self, labels, indexes, threshold):
        """Pick (name, confidence) pairs above *threshold*, skipping ignored tags.

        Iterates *indexes* in order so the resulting dict has a deterministic,
        CSV-defined ordering (the previous set-difference approach made the
        prompt order depend on hash randomization).
        """
        return {
            name: conf
            for name, conf in (labels[i] for i in indexes)
            if conf > threshold and name not in self.ignore_tokens
        }

    def __call__(
        self,
        image: Image.Image,
    ) -> str:
        """Tag a single image and return the prompt string."""
        _, height, width, _ = self.model.get_inputs()[0].shape
        # Flatten any alpha channel onto a white background.
        image = image.convert("RGBA")
        new_image = Image.new("RGBA", image.size, "WHITE")
        new_image.paste(image, mask=image)
        image = new_image.convert("RGB")
        image = np.asarray(image)
        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]
        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        probs = self.model.run([label_name], {input_name: image})[0]
        labels = list(zip(self.tag_names, probs[0].astype(float)))
        # Rating labels (self.rating_indexes) are not used in the prompt.
        # General tags: keep any with confidence above the general threshold.
        general_res = self._select_tags(labels, self.general_indexes, self.general_threshold)
        # Character tags: keep any with confidence above the character threshold.
        character_res = self._select_tags(labels, self.character_indexes, self.character_threshold)
        if self.with_confidence:
            prompt = [f"({i}:{character_res[i]:.2f})" for i in character_res]
            prompt += [f"({i}:{general_res[i]:.2f})" for i in general_res]
        else:
            prompt = list(character_res) + list(general_res)
        prompt = ",".join(prompt)
        if not self.is_danbooru_format:
            # Danbooru tags use underscores; convert to natural-language spacing.
            prompt = prompt.replace("_", " ")
        return prompt
def get_labels(frame_dir, interval, general_threshold, character_threshold,
               ignore_tokens, with_confidence, is_danbooru_format, is_cpu=False):
    """Run the WD14 tagger over every *interval*-th numbered PNG in *frame_dir*.

    Args:
        frame_dir: directory containing frames named "<number>.png".
        interval: tag every N-th frame (frames in between are skipped).
        general_threshold / character_threshold / ignore_tokens /
        with_confidence / is_danbooru_format / is_cpu: forwarded to Tagger.

    Returns:
        Dict mapping the frame index (as a string) to its prompt string.
        Empty dict when *frame_dir* is not a directory.
    """
    # Imported lazily so merely importing this module stays cheap.
    import torch

    result = {}
    if os.path.isdir(frame_dir):
        png_list = sorted(glob.glob(os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
        # Map numeric frame index -> file path so frames are addressed by number.
        png_map = {
            int(os.path.splitext(os.path.basename(p))[0]): p
            for p in png_list
        }
        with torch.no_grad():
            tagger = Tagger(general_threshold, character_threshold, ignore_tokens,
                            with_confidence, is_danbooru_format, is_cpu)
            for i in tqdm(range(0, len(png_list), interval), desc="WD14tagger"):
                # NOTE(review): assumes frames are numbered 0..len-1 without
                # gaps — a missing index raises KeyError here; confirm callers
                # always produce a contiguous sequence.
                result[str(i)] = tagger(image=Image.open(png_map[i]))
        # Drop the ONNX session reference and release cached GPU memory.
        tagger = None
        torch.cuda.empty_cache()
    return result