|
|
""" |
|
|
Salia Ultralytics Detector Provider (ComfyUI custom node) |
|
|
|
|
|
Goal: |
|
|
- Provide the same outputs as Impact-Subpack's `UltralyticsDetectorProvider`: |
|
|
- BBOX_DETECTOR |
|
|
- SEGM_DETECTOR |
|
|
- But packaged so you can drop it into your own custom node folder (your Salia_* environment) |
|
|
without requiring ComfyUI-Impact-Subpack. |
|
|
|
|
|
Notes: |
|
|
- This file intentionally keeps dependencies minimal and self-contained. |
|
|
- It uses `ultralytics.YOLO` to run `.pt` models directly (no TensorRT build step). |
|
|
- For PyTorch >= 2.6, `torch.load` defaults to `weights_only=True` which can break |
|
|
legacy `.pt` checkpoints. This file adds an OPTIONAL whitelist-based fallback |
|
|
to `weights_only=False` (unsafe) for specifically trusted model filenames. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import logging |
|
|
import pickle |
|
|
from datetime import datetime |
|
|
from contextlib import contextmanager |
|
|
from collections import namedtuple |
|
|
|
|
|
import folder_paths |
|
|
|
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
try: |
|
|
import cv2 |
|
|
except Exception: |
|
|
cv2 = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# File extensions accepted in the ultralytics model folders registered below.
# Uses ComfyUI's `folder_paths.supported_pt_extensions` when that attribute exists,
# otherwise falls back to this explicit list.
_SUPPORTED_PT_EXTS = getattr(folder_paths, "supported_pt_extensions", [".pt", ".pth", ".ckpt", ".safetensors"])
|
|
|
|
|
|
|
|
def _add_folder_path_and_extensions(
    folder_name: str,
    paths: list[str],
    extensions: list[str] | tuple[str, ...] | set[str],
):
    """Add/merge a ``folder_paths`` entry without depending on Impact-Pack helpers.

    Args:
        folder_name: key in ``folder_paths.folder_names_and_paths``
            (e.g. ``"ultralytics_bbox"``).
        paths: directories to register. They are appended after any directories
            already registered for this key, skipping duplicates, so earlier
            registrations keep their lookup priority.
        extensions: file extensions to accept; merged with any existing ones.

    The entry is stored as ``(list_of_paths, set_of_extensions)``. Extensions are
    kept as a ``set`` rather than a tuple so that other code merging into the
    same entry with ``existing_exts | new_exts`` keeps working
    (``tuple | set`` raises ``TypeError``).
    """
    if folder_name in folder_paths.folder_names_and_paths:
        existing_paths, existing_exts = folder_paths.folder_names_and_paths[folder_name]
    else:
        existing_paths, existing_exts = [], ()

    # Preserve registration order (first match wins in path lookups) while deduping.
    merged_paths = list(existing_paths)
    for p in paths:
        if p not in merged_paths:
            merged_paths.append(p)

    merged_exts = set(existing_exts)
    merged_exts.update(extensions)

    folder_paths.folder_names_and_paths[folder_name] = (merged_paths, merged_exts)
|
|
|
|
|
|
|
|
def _update_model_paths(base_path: str):
    """Register standard Impact-Subpack ultralytics model locations under *base_path*.

    Registers, in this order:
      * ``<base>/ultralytics/bbox`` as ``ultralytics_bbox``
      * ``<base>/ultralytics/segm`` as ``ultralytics_segm``
      * ``<base>/ultralytics``      as ``ultralytics``
    each accepting ``_SUPPORTED_PT_EXTS``.
    """
    ultra_root = os.path.join(base_path, "ultralytics")
    registrations = (
        ("ultralytics_bbox", os.path.join(ultra_root, "bbox")),
        ("ultralytics_segm", os.path.join(ultra_root, "segm")),
        ("ultralytics", ultra_root),
    )
    for key, directory in registrations:
        _add_folder_path_and_extensions(key, [directory], _SUPPORTED_PT_EXTS)
|
|
|
|
|
|
|
|
|
|
|
# Register the standard ultralytics folders under ComfyUI's models directory.
_update_model_paths(folder_paths.models_dir)

# Also register them under the optional "download_model_base" root, if configured.
if "download_model_base" in folder_paths.folder_names_and_paths:
    try:
        _update_model_paths(folder_paths.get_folder_paths("download_model_base")[0])
    except Exception as e:
        # Best effort: a malformed or empty entry must not stop this node from
        # importing, but report it instead of failing silently.
        logging.warning(
            "[Salia Ultralytics] Could not register ultralytics folders under download_model_base: %s",
            e,
        )

_THIS_DIR = os.path.dirname(os.path.abspath(__file__))

# Let model files shipped next to this node (./nodes, ./models, or the node
# folder itself) show up under every ultralytics folder key.
for local_dir in [
    os.path.join(_THIS_DIR, "nodes"),
    os.path.join(_THIS_DIR, "models"),
    _THIS_DIR,
]:
    if os.path.isdir(local_dir):
        for _folder_key in ("ultralytics_bbox", "ultralytics_segm", "ultralytics"):
            _add_folder_path_and_extensions(_folder_key, [local_dir], _SUPPORTED_PT_EXTS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Whatever `torch.load` is at the moment this module is imported. `_torch_load_wrapper`
# calls this reference for the actual load; calling `torch.load` there instead would
# recurse while `_patched_torch_load_for_ultralytics` has swapped torch.load out.
_ORIG_TORCH_LOAD = torch.load
|
|
|
|
|
|
|
|
def _get_whitelist_file() -> str | None:
    """Return the trusted-model whitelist path, creating the file on first use.

    The file is ``<user_dir>/default/ComfyUI-Salia-Ultralytics/model-whitelist.txt``,
    where ``user_dir`` comes from ``folder_paths.get_user_directory()``. A new file is
    seeded with explanatory comment lines.

    Returns:
        The path, or ``None`` if the user directory is unavailable or the
        folder/file cannot be created.
    """
    try:
        user_dir = folder_paths.get_user_directory()
    except Exception:
        return None
    if not user_dir or not os.path.isdir(user_dir):
        return None

    header_lines = (
        "# Add base filenames of trusted legacy models here (one per line).\n",
        "# Example: eyes.pt\n",
        "# These will be allowed to load with weights_only=False if safe loading fails.\n",
        "# WARNING: Only add models you trust.\n",
    )
    target_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
    target_file = os.path.join(target_dir, "model-whitelist.txt")

    try:
        os.makedirs(target_dir, exist_ok=True)
        if not os.path.exists(target_file):
            with open(target_file, "w", encoding="utf-8") as handle:
                handle.writelines(header_lines)
    except Exception:
        return None

    return target_file
|
|
|
|
|
|
|
|
# Resolved once at import time; None when the ComfyUI user directory is unavailable
# or the whitelist file could not be created.
_WHITELIST_PATH: str | None = _get_whitelist_file()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_model_load_log_file() -> str:
    """Return the path of ``model-load-log.txt``, which records resolved model files.

    Resolution order:
      1. the directory holding the whitelist file, when ``_WHITELIST_PATH`` is set;
      2. ``<user_dir>/default/ComfyUI-Salia-Ultralytics`` (created best-effort);
      3. this node's own directory (``_THIS_DIR``).
    """
    log_name = "model-load-log.txt"

    if _WHITELIST_PATH:
        return os.path.join(os.path.dirname(_WHITELIST_PATH), log_name)

    try:
        user_dir = folder_paths.get_user_directory()
    except Exception:
        user_dir = None

    if not (user_dir and os.path.isdir(user_dir)):
        return os.path.join(_THIS_DIR, log_name)

    log_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
    try:
        os.makedirs(log_dir, exist_ok=True)
    except Exception:
        # Best effort: the writer retries makedirs and tolerates failure.
        pass
    return os.path.join(log_dir, log_name)
|
|
|
|
|
|
|
|
# Path of the append-only log written by `_log_selected_model`; resolved once at import.
_MODEL_LOAD_LOG_PATH: str = _get_model_load_log_file()
|
|
|
|
|
|
|
|
def _find_all_model_paths(model_name: str) -> list[str]:
    """Return every on-disk file that *model_name* could resolve to.

    Candidates are checked in the same order the provider resolves them:
    first every ``ultralytics`` root joined with the full name (including the
    ``bbox/``/``segm/`` prefix), then the matching type-specific roots joined
    with the prefix stripped. Only regular files count, matching
    ``folder_paths.get_full_path`` (a directory with the model's name is not a
    loadable match). Results are absolute and de-duplicated, first occurrence
    first, so element 0 is the file that lookup would pick.
    """

    def _roots(folder_name: str) -> list[str]:
        # Missing/unregistered folder keys simply contribute no roots.
        try:
            return folder_paths.get_folder_paths(folder_name)
        except Exception:
            return []

    candidates = [(root, model_name) for root in _roots("ultralytics")]

    for prefix, folder_name in (("bbox/", "ultralytics_bbox"), ("segm/", "ultralytics_segm")):
        if model_name.startswith(prefix):
            rel = model_name[len(prefix):]
            candidates.extend((root, rel) for root in _roots(folder_name))
            break

    found = []
    for root, rel in candidates:
        path = os.path.join(root, rel)
        if os.path.isfile(path):
            found.append(os.path.abspath(path))

    # dict preserves insertion order: dedupe while keeping first-match priority.
    return list(dict.fromkeys(found))
|
|
|
|
|
|
|
|
def _log_selected_model(model_name: str, model_path: str, matches: list[str] | None = None):
    """Report the resolved model file on stdout, via ``logging``, and in the load log.

    Args:
        model_name: the choice selected in the node UI.
        model_path: the absolute path that will be loaded.
        matches: all candidate paths; listed in every channel when there is more
            than one of them.

    Failures while writing the log file are reported as a warning, never raised.
    """
    several = bool(matches) and len(matches) > 1

    # Console output.
    print(f"[Salia Ultralytics] Selected model_name: {model_name}")
    print(f"[Salia Ultralytics] Resolved model_path: {model_path}")
    if several:
        print("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        for candidate in matches:
            print(f" - {candidate}")
    print(f"[Salia Ultralytics] Model load log file: {_MODEL_LOAD_LOG_PATH}")

    # Python logging.
    logging.info("[Salia Ultralytics] Selected model_name: %s", model_name)
    logging.info("[Salia Ultralytics] Resolved model_path: %s", model_path)
    if several:
        logging.warning("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        for candidate in matches:
            logging.warning(" - %s", candidate)
    logging.info("[Salia Ultralytics] Model load log file: %s", _MODEL_LOAD_LOG_PATH)

    # Append-only, tab-separated log file (best effort).
    try:
        stamp = datetime.now().isoformat(timespec="seconds")
        is_file = os.path.isfile(model_path)
        size = os.path.getsize(model_path) if is_file else -1

        log_dir = os.path.dirname(_MODEL_LOAD_LOG_PATH)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        rows = [f"{stamp}\t{model_name}\t{model_path}\texists={is_file}\tsize={size}\n"]
        if several:
            rows.extend(f"{stamp}\tmatch\t{candidate}\n" for candidate in matches)

        with open(_MODEL_LOAD_LOG_PATH, "a", encoding="utf-8") as handle:
            handle.writelines(rows)
    except Exception as e:
        logging.warning("[Salia Ultralytics] Failed to write model-load log to %s: %s", _MODEL_LOAD_LOG_PATH, e)
|
|
|
|
|
|
|
|
def _load_whitelist(filepath: str | None) -> set[str]: |
|
|
if not filepath: |
|
|
return set() |
|
|
try: |
|
|
approved: set[str] = set() |
|
|
with open(filepath, "r", encoding="utf-8") as f: |
|
|
for line in f: |
|
|
line = line.strip() |
|
|
if line and not line.startswith("#"): |
|
|
approved.add(os.path.basename(line)) |
|
|
return approved |
|
|
except Exception: |
|
|
return set() |
|
|
|
|
|
|
|
|
# Snapshot of the whitelist taken at import. `_torch_load_wrapper` re-reads the file on
# every weights-only load failure, so edits take effect without restarting.
_MODEL_WHITELIST: set[str] = _load_whitelist(_WHITELIST_PATH)
|
|
|
|
|
|
|
|
def _torch_load_wrapper(*args, **kwargs):
    """Try safe load first; if it fails due to weights-only restrictions, allow fallback if whitelisted.

    Drop-in stand-in for ``torch.load`` that is installed temporarily by
    ``_patched_torch_load_for_ultralytics``. The load is first attempted with
    the caller's arguments. If it raises ``pickle.UnpicklingError`` with a
    message that looks like a weights-only restriction, the whitelist file is
    re-read. When the checkpoint's base filename is listed, the load is retried
    with ``weights_only=False`` (UNSAFE: executes pickled code; only for trusted
    files).

    The filename is taken from the first positional argument or ``f=``. Both
    ``str`` and ``os.PathLike`` (e.g. ``pathlib.Path``) sources are recognised.
    File-like objects have no usable name and are never retried.

    Raises:
        pickle.UnpicklingError: re-raised unchanged when the error is not
            weights-only related, or the file is not whitelisted.
    """
    global _MODEL_WHITELIST

    source = args[0] if args else kwargs.get("f")
    filename = None
    if isinstance(source, (str, os.PathLike)):
        filename = os.path.basename(os.fspath(source))

    try:
        return _ORIG_TORCH_LOAD(*args, **kwargs)
    except pickle.UnpicklingError as e:
        msg = str(e)

        # Heuristic: substrings seen in weights-only rejection messages.
        weights_only_markers = (
            "Weights only load failed",
            "Unsupported global",
            "disallowed",
            "not allowed",
            "getattr",
        )
        if not any(marker in msg for marker in weights_only_markers):
            raise

        # Re-read so whitelist edits apply without restarting ComfyUI.
        _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)

        if filename and filename in _MODEL_WHITELIST:
            logging.warning(
                "[Salia Ultralytics] Safe torch.load failed for '%s'. Retrying with weights_only=False because it's whitelisted (%s).",
                filename,
                _WHITELIST_PATH,
            )
            retry_kwargs = dict(kwargs)
            retry_kwargs["weights_only"] = False
            return _ORIG_TORCH_LOAD(*args, **retry_kwargs)

        logging.error(
            "[Salia Ultralytics] Blocked unsafe model load for '%s'.\n"
            "Safe loading failed and the file is not whitelisted.\n"
            "If you TRUST this model, add its base name to: %s",
            filename or "[unknown]",
            _WHITELIST_PATH or "[whitelist path unavailable]",
        )
        raise
|
|
|
|
|
|
|
|
@contextmanager
def _patched_torch_load_for_ultralytics():
    """Patch torch.load only while ultralytics loads a checkpoint.

    Within the ``with`` body, ``torch.load`` is ``_torch_load_wrapper``; the
    previous value is restored on exit, including on exceptions. When
    ``torch.serialization`` has no ``safe_globals`` attribute (presumably a
    torch build predating the weights-only safe-globals API), nothing is patched.

    NOTE(review): this swaps a module-global attribute, so loads in other threads
    during the body would also go through the wrapper.
    """
    if hasattr(torch.serialization, "safe_globals"):
        saved_load = torch.load
        torch.load = _torch_load_wrapper
        try:
            yield
        finally:
            torch.load = saved_load
    else:
        yield
|
|
|
|
|
|
|
|
def _load_yolo(model_path: str):
    """Load an Ultralytics YOLO model (with optional safe-load fallback).

    ``ultralytics`` is imported lazily so the node module can be imported
    without it. The model is constructed inside
    ``_patched_torch_load_for_ultralytics`` so the whitelist fallback applies.

    Raises:
        ImportError: if ``ultralytics`` cannot be imported.
    """
    try:
        from ultralytics import YOLO
    except Exception as import_error:
        raise ImportError(
            "[Salia Ultralytics] ultralytics is not installed. Install it in your ComfyUI env, e.g.:\n"
            "pip install ultralytics"
        ) from import_error

    with _patched_torch_load_for_ultralytics():
        yolo_model = YOLO(model_path)
    return yolo_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor2np_rgb(image: torch.Tensor) -> np.ndarray: |
|
|
"""Convert a ComfyUI IMAGE tensor to a uint8 RGB numpy image.""" |
|
|
|
|
|
if not isinstance(image, torch.Tensor): |
|
|
raise TypeError(f"Expected torch.Tensor, got {type(image)}") |
|
|
|
|
|
if image.dim() == 4: |
|
|
img = image[0] |
|
|
else: |
|
|
img = image |
|
|
|
|
|
img = img.detach() |
|
|
if img.is_cuda: |
|
|
img = img.cpu() |
|
|
|
|
|
img = img.clamp(0, 1).numpy() |
|
|
if img.shape[-1] == 1: |
|
|
img = np.repeat(img, 3, axis=-1) |
|
|
|
|
|
img_u8 = (img * 255.0).round().astype(np.uint8) |
|
|
return img_u8 |
|
|
|
|
|
|
|
|
def tensor2pil(image: torch.Tensor) -> Image.Image: |
|
|
return Image.fromarray(_tensor2np_rgb(image)) |
|
|
|
|
|
|
|
|
def _fit_crop_span(center: float, size: float, limit: int) -> tuple[int, int]:
    """Place a window of *size* centred on *center* inside ``[0, limit)``.

    When the window sticks out on one side it is shifted back inside, keeping
    its size, instead of being clipped. It is clipped only when larger than
    ``limit`` itself. Always returns a non-empty ``(lo, hi)`` with ``hi > lo``.
    """
    lo = int(round(center - size / 2.0))
    hi = int(round(center + size / 2.0))

    if lo < 0:
        hi -= lo
        lo = 0
    elif hi > limit:
        lo -= hi - limit
        hi = limit

    lo = max(0, min(limit - 1, lo))
    hi = max(lo + 1, min(limit, hi))
    return lo, hi


def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size: int | None = None):
    """Compute the crop window ``(x1, y1, x2, y2)`` around a detection box.

    The window is the box scaled by *crop_factor* (each side at least
    *crop_min_size* when given), centred on the box. Near the image border the
    window is shifted inward to keep its size (as Impact-Pack does) rather than
    being truncated. It is only clipped when it is larger than the image.

    Args:
        w, h: image width and height in pixels.
        bbox_xyxy: detection box as ``(x1, y1, x2, y2)``; degenerate boxes are
            treated as at least 1x1.
        crop_factor: multiplier applied to the box width/height.
        crop_min_size: optional minimum crop width/height.

    Returns:
        Integer tuple ``(x1, y1, x2, y2)`` with ``0 <= x1 < x2 <= w`` and
        ``0 <= y1 < y2 <= h`` (for a non-empty image).
    """
    x1, y1, x2, y2 = [float(v) for v in bbox_xyxy]
    bbox_w = max(1.0, x2 - x1)
    bbox_h = max(1.0, y2 - y1)

    crop_w = bbox_w * float(crop_factor)
    crop_h = bbox_h * float(crop_factor)
    if crop_min_size is not None:
        crop_w = max(crop_w, float(crop_min_size))
        crop_h = max(crop_h, float(crop_min_size))

    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0

    rx1, rx2 = _fit_crop_span(cx, crop_w, w)
    ry1, ry2 = _fit_crop_span(cy, crop_h, h)
    return (rx1, ry1, rx2, ry2)
|
|
|
|
|
|
|
|
def crop_image(image: torch.Tensor, crop_region):
    """Slice a channels-last image tensor to ``crop_region = (x1, y1, x2, y2)``.

    Accepts ``(B, H, W, C)`` (all frames are cropped) or ``(H, W, C)``.
    Coordinates are converted with ``int()``.

    Raises:
        ValueError: for tensors of any other rank.
    """
    left, top, right, bottom = crop_region
    cols = slice(int(left), int(right))
    rows = slice(int(top), int(bottom))

    rank = image.dim()
    if rank == 4:
        return image[:, rows, cols, :]
    if rank == 3:
        return image[rows, cols, :]
    raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}")
|
|
|
|
|
|
|
|
def crop_ndarray2(arr: np.ndarray, crop_region):
    """Return the 2-D slice of *arr* covered by ``crop_region = (x1, y1, x2, y2)``.

    Rows are indexed by y and columns by x; coordinates go through ``int()``.
    """
    left, top, right, bottom = crop_region
    return arr[int(top):int(bottom), int(left):int(right)]
|
|
|
|
|
|
|
|
def dilate_masks(segmasks, dilation: int):
    """Grow every mask in *segmasks* by *dilation* pixels.

    Each mask is thresholded at 0.5 and dilated once with an elliptical kernel
    of size ``2 * dilation + 1``. It is returned as a float32 0/1 array in a new
    ``(bbox, mask, conf)`` triple.

    Args:
        segmasks: list of ``(bbox, mask, conf)`` triples.
        dilation: pixel radius; values ``<= 0`` disable dilation.

    Returns:
        *segmasks* itself when dilation is disabled or there is nothing to
        dilate, otherwise a new list.

    Raises:
        ImportError: if OpenCV is unavailable and there is at least one mask to
            dilate.
    """
    # Returning early for an empty list means "no detections" never fails just
    # because OpenCV is missing.
    if dilation <= 0 or not segmasks:
        return segmasks

    if cv2 is None:
        raise ImportError(
            "[Salia Ultralytics] opencv-python is required for mask dilation but cv2 could not be imported.\n"
            "Install: pip install opencv-python-headless"
        )

    k = int(dilation)
    ksize = k * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))

    out = []
    for bbox, mask, conf in segmasks:
        m = (mask > 0.5).astype(np.uint8) * 255
        m = cv2.dilate(m, kernel, iterations=1)
        out.append((bbox, (m > 0).astype(np.float32), conf))
    return out
|
|
|
|
|
|
|
|
def combine_masks(segmasks, out_shape_hw: tuple[int, int] | None = None) -> torch.Tensor:
    """Merge all masks into one with a per-pixel maximum.

    Args:
        segmasks: list of ``(bbox, mask, conf)`` triples; masks are assumed to
            share one shape (the first mask's shape is used for the result).
        out_shape_hw: ``(H, W)`` used only for the all-zero result when
            *segmasks* is empty; ``None`` yields a ``1x1`` zero mask.

    Returns:
        float32 tensor of shape ``(1, H, W)``.

    NOTE(review): the result carries a leading batch dim; confirm that callers
    expecting an ``(H, W)`` mask do not add another one.
    """
    if segmasks:
        union = np.zeros_like(segmasks[0][1], dtype=np.float32)
        for _bbox, layer, _conf in segmasks:
            union = np.maximum(union, layer.astype(np.float32))
        return torch.from_numpy(union).unsqueeze(0)

    height, width = (1, 1) if out_shape_hw is None else out_shape_hw
    return torch.zeros((1, height, width), dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One detected segment, in the field layout used by Impact-Pack's SEGS.
# Only the last field (control_net_wrapper) is optional and defaults to None.
SEG = namedtuple(
    "SEG",
    "cropped_image cropped_mask confidence crop_region bbox label control_net_wrapper",
    defaults=(None,),
)
|
|
|
|
|
|
|
|
class NO_BBOX_DETECTOR:
    """Sentinel type meaning "no bbox detector available" (not instantiated in this module)."""
|
|
|
|
|
|
|
|
class NO_SEGM_DETECTOR:
    """Sentinel returned in the SEGM_DETECTOR slot when a ``bbox/`` model is selected."""
|
|
|
|
|
|
|
|
def _create_segmasks(results): |
|
|
|
|
|
bboxes = results[1] |
|
|
segms = results[2] |
|
|
confs = results[3] |
|
|
|
|
|
out = [] |
|
|
for i in range(len(segms)): |
|
|
out.append((bboxes[i], segms[i].astype(np.float32), confs[i])) |
|
|
return out |
|
|
|
|
|
|
|
|
def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
    """Run a YOLO model on *image_pil* and build one rectangular mask per box.

    Args:
        model: callable Ultralytics model (called as ``model(img, conf=, device=)``).
        image_pil: input image.
        confidence: detection confidence threshold.
        device: device string forwarded to the model ("" lets Ultralytics choose).

    Returns:
        ``[labels, bboxes, masks, confs]``. ``masks`` are boolean ``(H, W)``
        arrays that are True inside each box; all four lists are empty when
        nothing is detected.
    """
    pred = model(image_pil, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if bboxes.shape[0] == 0:
        return [[], [], [], []]

    # PIL reports (width, height); no need to copy the whole image into numpy.
    w, h = image_pil.size

    segms = []
    for x0, y0, x1, y1 in bboxes:
        m = np.zeros((h, w), dtype=bool)
        # Fill with an inclusive (x1, y1) corner: the same pixels a filled
        # cv2.rectangle covers. The mask therefore no longer depends on
        # whether OpenCV is installed. Clamping keeps the slice inside the
        # image and stops negative indices from wrapping around.
        col_lo, col_hi = max(0, int(x0)), min(w, int(x1) + 1)
        row_lo, row_hi = max(0, int(y0)), min(h, int(y1) + 1)
        if col_hi > col_lo and row_hi > row_lo:
            m[row_lo:row_hi, col_lo:col_hi] = True
        segms.append(m)

    labels = []
    confs = []
    for i in range(len(bboxes)):
        box = result.boxes[i]
        labels.append(result.names[int(box.cls.item())])
        confs.append(box.conf.detach().cpu().numpy())

    return [labels, list(bboxes), segms, confs]
|
|
|
|
|
|
|
|
def _letterbox_content_window(mask_h: int, mask_w: int, orig_h: int, orig_w: int):
    """Return ``(top, bottom, left, right)`` of the real image area inside a mask.

    Ultralytics returns ``masks.data`` at the network input size: the original
    image scaled by ``gain`` and padded (centred) up to the input shape. This
    mirrors ``ultralytics.utils.ops.scale_image`` so that padding can be cut off
    before resizing back to the original resolution. If the mask already has the
    original size, or the computed window is degenerate, the full mask is used.
    """
    full = (0, mask_h, 0, mask_w)
    if (mask_h, mask_w) == (orig_h, orig_w) or orig_h <= 0 or orig_w <= 0:
        return full

    gain = min(mask_h / orig_h, mask_w / orig_w)
    pad_h = (mask_h - orig_h * gain) / 2.0
    pad_w = (mask_w - orig_w * gain) / 2.0

    # The +/-0.1 matches how the letterbox splits odd padding between sides.
    top = int(round(pad_h - 0.1))
    left = int(round(pad_w - 0.1))
    bottom = mask_h - int(round(pad_h + 0.1))
    right = mask_w - int(round(pad_w + 0.1))

    if top < 0 or left < 0 or bottom <= top or right <= left:
        return full
    return top, bottom, left, right


def _inference_segm(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
    """Run a YOLO segmentation model and return masks at the original image size.

    Args:
        model: callable Ultralytics model (called as ``model(img, conf=, device=)``).
        image_pil: input image.
        confidence: detection confidence threshold.
        device: device string forwarded to the model.

    Returns:
        ``[labels, bboxes, masks, confs]`` with float ``(H, W)`` masks. When the
        model yields no masks (a detection-only model), this falls back to
        ``_inference_bbox``, which re-runs inference and builds rectangles.
    """
    pred = model(image_pil, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if bboxes.shape[0] == 0:
        return [[], [], [], []]

    if result.masks is None or result.masks.data is None:
        return _inference_bbox(model, image_pil, confidence=confidence, device=device)

    segms = result.masks.data.detach().cpu().numpy()  # (N, H_in, W_in)

    w_orig, h_orig = image_pil.size
    # Strip letterbox padding first; otherwise stretching the padded mask to the
    # original size shifts it (e.g. vertically for 16:9 images).
    top, bottom, left, right = _letterbox_content_window(segms.shape[1], segms.shape[2], h_orig, w_orig)

    results = [[], [], [], []]
    for i in range(len(bboxes)):
        box = result.boxes[i]
        results[0].append(result.names[int(box.cls.item())])
        results[1].append(bboxes[i])

        mask = torch.from_numpy(segms[i]).float()[top:bottom, left:right]
        mask = F.interpolate(
            mask.unsqueeze(0).unsqueeze(0),
            size=(h_orig, w_orig),
            mode="bilinear",
            align_corners=False,
        )
        results[2].append(mask.squeeze(0).squeeze(0).numpy())
        results[3].append(box.conf.detach().cpu().numpy())

    return results
|
|
|
|
|
|
|
|
class SaliaUltraBBoxDetector:
    """BBOX_DETECTOR backed by an Ultralytics model; each detection becomes a rectangle mask."""

    def __init__(self, model):
        self.model = model

    def _segmasks(self, image, threshold, dilation):
        """Detect on the first frame; return ``(labels, [(bbox, mask, conf), ...])``."""
        detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        return detected[0], segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Return SEGS ``((H, W), [SEG, ...])`` for detections larger than *drop_size*.

        Each kept detection is cropped with ``make_crop_region``. The optional
        *detailer_hook* may adjust crop regions (``post_crop_region``) and the
        final SEGS (``post_detection``).
        """
        drop_size = max(int(drop_size), 1)
        labels, segmasks = self._segmasks(image, threshold, dilation)

        height, width = image.shape[1], image.shape[2]
        items = []
        for (bbox, mask, conf), label in zip(segmasks, labels):
            x1, y1, x2, y2 = bbox
            if not ((x2 - x1) > drop_size and (y2 - y1) > drop_size):
                continue

            region = make_crop_region(width, height, bbox, float(crop_factor))
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                region = detailer_hook.post_crop_region(width, height, bbox, region)

            items.append(
                SEG(crop_image(image, region), crop_ndarray2(mask, region), conf, region, bbox, label, None)
            )

        segs = ((image.shape[1], image.shape[2]), items)
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all detection masks as a ``(1, H, W)`` tensor."""
        _labels, segmasks = self._segmasks(image, threshold, dilation)
        return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))

    def setAux(self, x):
        """Accept and ignore *x* so callers that invoke ``setAux`` do not fail."""
|
|
|
|
|
|
|
|
class SaliaUltraSegmDetector:
    """SEGM_DETECTOR backed by an Ultralytics model.

    Uses the model's segmentation masks. For models without masks,
    ``_inference_segm`` falls back to rectangle masks.
    """

    def __init__(self, model):
        self.model = model

    def _segmasks(self, image, threshold, dilation):
        """Segment the first frame; return ``(labels, [(bbox, mask, conf), ...])``."""
        detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        return detected[0], segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Return SEGS ``((H, W), [SEG, ...])`` for detections larger than *drop_size*.

        Each kept detection is cropped with ``make_crop_region``. The optional
        *detailer_hook* may adjust crop regions (``post_crop_region``) and the
        final SEGS (``post_detection``).
        """
        drop_size = max(int(drop_size), 1)
        labels, segmasks = self._segmasks(image, threshold, dilation)

        height, width = image.shape[1], image.shape[2]
        items = []
        for (bbox, mask, conf), label in zip(segmasks, labels):
            x1, y1, x2, y2 = bbox
            if not ((x2 - x1) > drop_size and (y2 - y1) > drop_size):
                continue

            region = make_crop_region(width, height, bbox, float(crop_factor))
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                region = detailer_hook.post_crop_region(width, height, bbox, region)

            items.append(
                SEG(crop_image(image, region), crop_ndarray2(mask, region), conf, region, bbox, label, None)
            )

        segs = ((image.shape[1], image.shape[2]), items)
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all segmentation masks as a ``(1, H, W)`` tensor."""
        _labels, segmasks = self._segmasks(image, threshold, dilation)
        return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))

    def setAux(self, x):
        """Accept and ignore *x* so callers that invoke ``setAux`` do not fail."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SaliaUltralyticsDetectorProvider2:
    """Load an Ultralytics `.pt` model and provide Impact-compatible detectors."""

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "doit"
    CATEGORY = "Salia/Detectors"

    @classmethod
    def INPUT_TYPES(cls):
        """Offer every bbox model as ``bbox/<name>`` and every segm model as ``segm/<name>``."""
        choices = [f"bbox/{name}" for name in folder_paths.get_filename_list("ultralytics_bbox")]
        choices += [f"segm/{name}" for name in folder_paths.get_filename_list("ultralytics_segm")]
        return {"required": {"model_name": (choices,)}}

    @staticmethod
    def _typed_folder(model_name):
        """Map ``bbox/...``/``segm/...`` to ``(folder key, relative name)``, else ``(None, None)``."""
        if model_name.startswith("bbox/"):
            return "ultralytics_bbox", model_name[5:]
        if model_name.startswith("segm/"):
            return "ultralytics_segm", model_name[5:]
        return None, None

    def doit(self, model_name: str):
        """Resolve, log and load *model_name*, returning ``(bbox_detector, segm_detector)``.

        The path is looked up in the generic ``ultralytics`` folders first, then in
        the type-specific folder with the prefix stripped. ``bbox/`` models yield
        ``NO_SEGM_DETECTOR`` in the second slot.

        Raises:
            ValueError: if the model file cannot be found; lists searched folders.
        """
        folder_key, rel_name = self._typed_folder(model_name)

        model_path = folder_paths.get_full_path("ultralytics", model_name)
        if model_path is None and folder_key is not None:
            model_path = folder_paths.get_full_path(folder_key, rel_name)

        if model_path is None:
            searched = []
            try:
                searched.extend(folder_paths.get_folder_paths("ultralytics"))
                if folder_key is not None:
                    searched.extend(folder_paths.get_folder_paths(folder_key))
            except Exception:
                pass
            formatted = "\n\t".join(searched)
            raise ValueError(
                f"[Salia Ultralytics] model file '{model_name}' was not found.\n"
                f"Searched these folders:\n\t{formatted}\n"
                "Tip: put bbox models in 'models/ultralytics/bbox' or segm models in 'models/ultralytics/segm'."
            )

        # Record which file is actually used (helps when names collide across folders).
        candidates = _find_all_model_paths(model_name)
        _log_selected_model(model_name, os.path.abspath(model_path), candidates)

        model = _load_yolo(model_path)
        bbox_detector = SaliaUltraBBoxDetector(model)
        if model_name.startswith("bbox/"):
            return bbox_detector, NO_SEGM_DETECTOR()
        return bbox_detector, SaliaUltraSegmDetector(model)
|
|
|
|
|
|
|
|
# ComfyUI custom-node registration: node type id -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    SaliaUltralyticsDetectorProvider2=SaliaUltralyticsDetectorProvider2,
)

# Human-readable title for each registered node type.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    SaliaUltralyticsDetectorProvider2="Salia Ultralytics Detector 2 (Salia)",
)
|
|
|