# NOTE: scraped page header ("Spaces: Runtime error") removed — it was web-page
# chrome from the Hugging Face Space this file was copied from, not code.
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
import glob
import logging
import os

import cv2
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from tqdm.rich import tqdm

from animatediff.utils.util import prepare_wd14tagger

logger = logging.getLogger(__name__)
def make_square(img, target_size):
    """Pad *img* with white borders into a centered square.

    The output side length is ``max(height, width, target_size)``, so the
    image is never shrunk here — only padded (resizing happens later in
    ``smart_resize``).

    Args:
        img: numpy image array, ``(H, W)`` or ``(H, W, C)``, typically uint8 BGR.
        target_size: minimum side length of the returned square.

    Returns:
        A new array of shape ``(side, side[, C])`` with white (255) padding.
    """
    height, width = img.shape[:2]
    side = max(height, width, target_size)
    delta_h = side - height
    delta_w = side - width
    top = delta_h // 2
    left = delta_w // 2
    pad = [(top, delta_h - top), (left, delta_w - left)]
    # No padding along the channel axis (if one exists). Using np.pad with a
    # scalar 255 instead of cv2.copyMakeBorder(value=[255, 255, 255]) gives the
    # same white border for 3-channel uint8 input, but also works for
    # grayscale and RGBA images where a fixed 3-element color would not.
    pad += [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad, mode="constant", constant_values=255)
def smart_resize(img, size):
    """Resize a square image to ``size`` x ``size``.

    Assumes the image has already gone through ``make_square``. Downscaling
    uses INTER_AREA (better for shrinking), upscaling uses INTER_CUBIC, and
    an exact size match is returned untouched.
    """
    current = img.shape[0]
    if current == size:
        return img
    interp = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=interp)
class Tagger:
    """WD14 ONNX tagger: scores an image against the WD14 label set.

    Produces a comma-separated prompt string of character and general tags
    whose confidence exceeds the configured thresholds.
    """

    def __init__(self, general_threshold, character_threshold, ignore_tokens,
                 with_confidence, is_danbooru_format, is_cpu):
        """Load the ONNX model and the tag metadata CSV.

        Args:
            general_threshold: minimum confidence for a general tag to appear.
            character_threshold: minimum confidence for a character tag to appear.
            ignore_tokens: iterable of tag names to drop from the result.
            with_confidence: emit "(tag:0.87)" weighted syntax instead of bare tags.
            is_danbooru_format: keep underscores in tag names when True.
            is_cpu: run on CPUExecutionProvider instead of CUDA.
        """
        prepare_wd14tagger()
        provider = 'CPUExecutionProvider' if is_cpu else 'CUDAExecutionProvider'
        self.model = onnxruntime.InferenceSession(
            "data/models/WD14tagger/model.onnx", providers=[provider]
        )
        df = pd.read_csv("data/models/WD14tagger/selected_tags.csv")
        self.tag_names = df["name"].tolist()
        # WD14 selected_tags.csv category codes: 9 = rating, 0 = general,
        # 4 = character.
        self.rating_indexes = list(np.where(df["category"] == 9)[0])
        self.general_indexes = list(np.where(df["category"] == 0)[0])
        self.character_indexes = list(np.where(df["category"] == 4)[0])
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        self.ignore_tokens = ignore_tokens
        self.with_confidence = with_confidence
        self.is_danbooru_format = is_danbooru_format

    def __call__(
        self,
        image: Image.Image,
    ) -> str:
        """Tag a single PIL image and return the prompt string."""
        _, height, width, _ = self.model.get_inputs()[0].shape
        # Flatten any alpha channel onto a white background.
        image = image.convert("RGBA")
        new_image = Image.new("RGBA", image.size, "WHITE")
        new_image.paste(image, mask=image)
        image = new_image.convert("RGB")
        image = np.asarray(image)
        # PIL RGB to OpenCV BGR — presumably the model expects BGR input; the
        # upstream Space does the same flip.
        image = image[:, :, ::-1]
        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        probs = self.model.run([label_name], {input_name: image})[0]
        labels = list(zip(self.tag_names, probs[0].astype(float)))

        # Rating rows (computed for parity with the upstream app; not used in
        # the prompt).
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)
        # General tags above threshold.
        general_res = dict(
            x for x in (labels[i] for i in self.general_indexes)
            if x[1] > self.general_threshold
        )
        # Character tags above threshold.
        character_res = dict(
            x for x in (labels[i] for i in self.character_indexes)
            if x[1] > self.character_threshold
        )

        # Drop ignored tokens with order-preserving dict comprehensions.
        # The original used `dict.keys() - set(...)`, which yields a *set*,
        # so the prompt's tag order was nondeterministic between runs.
        ignore = set(self.ignore_tokens)
        general_res = {k: v for k, v in general_res.items() if k not in ignore}
        character_res = {k: v for k, v in character_res.items() if k not in ignore}

        if self.with_confidence:
            prompt = [f"({t}:{c:.2f})" for t, c in character_res.items()]
            prompt += [f"({t}:{c:.2f})" for t, c in general_res.items()]
        else:
            prompt = list(character_res.keys())
            prompt += list(general_res.keys())
        prompt = ",".join(prompt)
        if not self.is_danbooru_format:
            # Danbooru tags use underscores; plain prompts read better with spaces.
            prompt = prompt.replace("_", " ")
        return prompt
def get_labels(frame_dir, interval, general_threshold, character_threshold,
               ignore_tokens, with_confidence, is_danbooru_format, is_cpu=False):
    """Tag every ``interval``-th numbered PNG frame in *frame_dir*.

    Args:
        frame_dir: directory containing frames named ``<int>.png``.
        interval: step between tagged frame indices (0, interval, 2*interval, ...).
        general_threshold, character_threshold, ignore_tokens, with_confidence,
        is_danbooru_format, is_cpu: forwarded to ``Tagger``.

    Returns:
        dict mapping the frame index as a string to its prompt string.
        Empty when *frame_dir* is not a directory.
    """
    import torch

    result = {}
    if not os.path.isdir(frame_dir):
        # Nothing to tag — return an empty mapping rather than failing on
        # an undefined frame list.
        return result

    png_list = sorted(glob.glob(os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
    png_map = {}
    for png_path in png_list:
        stem = os.path.splitext(os.path.basename(png_path))[0]
        png_map[int(stem)] = png_path

    with torch.no_grad():
        tagger = Tagger(general_threshold, character_threshold, ignore_tokens,
                        with_confidence, is_danbooru_format, is_cpu)
        for i in tqdm(range(0, len(png_list), interval), desc="WD14tagger"):
            # Frame numbering may have gaps; skip missing indices instead of
            # raising KeyError as the original direct lookup did.
            path = png_map.get(i)
            if path is None:
                continue
            result[str(i)] = tagger(
                image=Image.open(path)
            )
        # Release the ONNX session before freeing the CUDA cache; guard the
        # cache call so CPU-only runs don't touch the CUDA runtime.
        tagger = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result