# MyCustomNodes / Salia_UltralyticsDetectorProvider2.py
# Uploaded by saliacoel as "Upload Salia_UltralyticsDetectorProvider2.py" (commit 01cf9f1, verified).
"""
Salia Ultralytics Detector Provider (ComfyUI custom node)
Goal:
- Provide the same outputs as Impact-Subpack's `UltralyticsDetectorProvider`:
- BBOX_DETECTOR
- SEGM_DETECTOR
- But packaged so you can drop it into your own custom node folder (your Salia_* environment)
without requiring ComfyUI-Impact-Subpack.
Notes:
- This file intentionally keeps dependencies minimal and self-contained.
- It uses `ultralytics.YOLO` to run `.pt` models directly (no TensorRT build step).
- For PyTorch >= 2.6, `torch.load` defaults to `weights_only=True` which can break
legacy `.pt` checkpoints. This file adds an OPTIONAL whitelist-based fallback
to `weights_only=False` (unsafe) for specifically trusted model filenames.
"""
from __future__ import annotations
import os
import logging
import pickle
from datetime import datetime
from contextlib import contextmanager
from collections import namedtuple
import folder_paths
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
try:
import cv2 # opencv-python or opencv-python-headless
except Exception:
cv2 = None
# ---------------------------
# Model folders (same layout as Impact Subpack)
# ---------------------------
_SUPPORTED_PT_EXTS = getattr(folder_paths, "supported_pt_extensions", [".pt", ".pth", ".ckpt", ".safetensors"])


def _add_folder_path_and_extensions(
    folder_name: str,
    paths: list[str],
    extensions: list[str] | tuple[str, ...] | set[str],
):
    """Add/merge a ``folder_paths`` entry without depending on Impact-Pack helpers.

    New *paths* are appended after the already registered ones (duplicates skipped)
    and *extensions* are unioned with the existing ones.

    The extensions are stored as a ``set``, the same container ComfyUI uses for its own
    ``folder_names_and_paths`` entries. Other extensions that later merge with
    ``existing | new`` or call ``.add()`` on the entry therefore keep working. Rewriting
    the entry as a tuple (the previous behaviour) makes those calls raise.
    """
    registry = folder_paths.folder_names_and_paths
    if folder_name not in registry:
        registry[folder_name] = (list(paths), set(extensions))
        return

    existing_paths, existing_exts = registry[folder_name]
    merged_paths = list(existing_paths)
    for p in paths:
        if p not in merged_paths:
            merged_paths.append(p)
    registry[folder_name] = (merged_paths, set(existing_exts) | set(extensions))
def _update_model_paths(base_path: str):
    """Register the standard Impact-Subpack ultralytics model locations under *base_path*.

    Adds ``<base_path>/ultralytics/bbox`` to ``ultralytics_bbox``,
    ``<base_path>/ultralytics/segm`` to ``ultralytics_segm`` and
    ``<base_path>/ultralytics`` to ``ultralytics``. Each uses ``_SUPPORTED_PT_EXTS``.
    """
    ultra_root = os.path.join(base_path, "ultralytics")
    layout = (
        ("ultralytics_bbox", os.path.join(ultra_root, "bbox")),
        ("ultralytics_segm", os.path.join(ultra_root, "segm")),
        ("ultralytics", ultra_root),
    )
    for key, directory in layout:
        _add_folder_path_and_extensions(key, [directory], _SUPPORTED_PT_EXTS)
# Register common folders (models_dir + ComfyUI-Manager download_model_base)
_update_model_paths(folder_paths.models_dir)
if "download_model_base" in folder_paths.folder_names_and_paths:
    try:
        _update_model_paths(folder_paths.get_folder_paths("download_model_base")[0])
    except Exception:
        # Best-effort: an unusable "download_model_base" entry (e.g. an empty path list
        # making [0] fail) must not prevent this node module from importing.
        pass
# Also register local folder(s) inside THIS custom-node extension, so you can keep
# models next to your Salia_*.py files if you want.
# _THIS_DIR is also used later in this file (log-file fallback location).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
for local_dir in [
    os.path.join(_THIS_DIR, "nodes"),
    os.path.join(_THIS_DIR, "models"),
    _THIS_DIR,
]:
    # Only existing directories are registered. Each one is added to all three keys,
    # so a model file placed directly in it is listed by the node under both the
    # "bbox/" and the "segm/" prefixes.
    if os.path.isdir(local_dir):
        _add_folder_path_and_extensions("ultralytics_bbox", [local_dir], _SUPPORTED_PT_EXTS)
        _add_folder_path_and_extensions("ultralytics_segm", [local_dir], _SUPPORTED_PT_EXTS)
        _add_folder_path_and_extensions("ultralytics", [local_dir], _SUPPORTED_PT_EXTS)
# ---------------------------
# Optional safe-load fallback (PyTorch >= 2.6)
# ---------------------------
_ORIG_TORCH_LOAD = torch.load
def _get_whitelist_file() -> str | None:
"""Create/return the whitelist file path under ComfyUI's user directory."""
try:
user_dir = folder_paths.get_user_directory()
except Exception:
user_dir = None
if not user_dir or not os.path.isdir(user_dir):
return None
wl_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
wl_file = os.path.join(wl_dir, "model-whitelist.txt")
try:
os.makedirs(wl_dir, exist_ok=True)
if not os.path.exists(wl_file):
with open(wl_file, "w", encoding="utf-8") as f:
f.write("# Add base filenames of trusted legacy models here (one per line).\n")
f.write("# Example: eyes.pt\n")
f.write("# These will be allowed to load with weights_only=False if safe loading fails.\n")
f.write("# WARNING: Only add models you trust.\n")
except Exception:
return None
return wl_file
_WHITELIST_PATH = _get_whitelist_file()
# ---------------------------
# Model path logging (requested)
# ---------------------------
def _get_model_load_log_file() -> str:
    """Return the path of the text log recording which ultralytics model file was loaded.

    Preference order:
      1. next to the whitelist file, if one was set up;
      2. ``<user dir>/default/ComfyUI-Salia-Ultralytics`` (creation attempted, failures ignored);
      3. the directory containing this Python file.
    """
    log_name = "model-load-log.txt"

    if _WHITELIST_PATH:
        return os.path.join(os.path.dirname(_WHITELIST_PATH), log_name)

    try:
        user_dir = folder_paths.get_user_directory()
    except Exception:
        user_dir = None

    if not (user_dir and os.path.isdir(user_dir)):
        return os.path.join(_THIS_DIR, log_name)

    log_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
    try:
        os.makedirs(log_dir, exist_ok=True)
    except Exception:
        pass
    return os.path.join(log_dir, log_name)


_MODEL_LOAD_LOG_PATH = _get_model_load_log_file()
def _find_all_model_paths(model_name: str) -> list[str]:
"""
Find all possible on-disk matches across the registered ultralytics folders.
Useful if the same filename exists in multiple locations.
"""
matches: list[str] = []
try:
ultra_roots = folder_paths.get_folder_paths("ultralytics")
except Exception:
ultra_roots = []
try:
bbox_roots = folder_paths.get_folder_paths("ultralytics_bbox")
except Exception:
bbox_roots = []
try:
segm_roots = folder_paths.get_folder_paths("ultralytics_segm")
except Exception:
segm_roots = []
def add_if_exists(root: str, rel: str):
p = os.path.join(root, rel)
if os.path.exists(p):
matches.append(os.path.abspath(p))
# model_name might be "bbox/foo.pt" or "segm/foo.pt" (includes subfolder)
for r in ultra_roots:
add_if_exists(r, model_name)
# Also search the specialized bbox/segm roots with the prefix stripped
if model_name.startswith("bbox/"):
rel = model_name[5:]
for r in bbox_roots:
add_if_exists(r, rel)
elif model_name.startswith("segm/"):
rel = model_name[5:]
for r in segm_roots:
add_if_exists(r, rel)
# De-dupe preserving order
out: list[str] = []
seen = set()
for p in matches:
if p not in seen:
seen.add(p)
out.append(p)
return out
def _log_selected_model(model_name: str, model_path: str, matches: list[str] | None = None):
    """Report the resolved model file via print(), the logging module, and a log file.

    When *matches* holds more than one path, every candidate is listed in all three
    places. A failure while appending to the log file is reported as a logging
    warning and otherwise ignored.
    """
    has_duplicates = bool(matches) and len(matches) > 1

    # 1) Console output
    console_lines = [
        f"[Salia Ultralytics] Selected model_name: {model_name}",
        f"[Salia Ultralytics] Resolved model_path: {model_path}",
    ]
    if has_duplicates:
        console_lines.append("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        console_lines.extend(f" - {p}" for p in matches)
    console_lines.append(f"[Salia Ultralytics] Model load log file: {_MODEL_LOAD_LOG_PATH}")
    for line in console_lines:
        print(line)

    # Python logging (ComfyUI typically captures this)
    logging.info("[Salia Ultralytics] Selected model_name: %s", model_name)
    logging.info("[Salia Ultralytics] Resolved model_path: %s", model_path)
    if has_duplicates:
        logging.warning("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        for p in matches:
            logging.warning(" - %s", p)
    logging.info("[Salia Ultralytics] Model load log file: %s", _MODEL_LOAD_LOG_PATH)

    # 2) Append a tab-separated record (plus one line per duplicate match) to the log file
    try:
        stamp = datetime.now().isoformat(timespec="seconds")
        is_file = os.path.isfile(model_path)
        size = os.path.getsize(model_path) if is_file else -1
        log_dir = os.path.dirname(_MODEL_LOAD_LOG_PATH)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        records = [f"{stamp}\t{model_name}\t{model_path}\texists={is_file}\tsize={size}\n"]
        if has_duplicates:
            records.extend(f"{stamp}\tmatch\t{p}\n" for p in matches)
        with open(_MODEL_LOAD_LOG_PATH, "a", encoding="utf-8") as handle:
            handle.writelines(records)
    except Exception as e:
        logging.warning("[Salia Ultralytics] Failed to write model-load log to %s: %s", _MODEL_LOAD_LOG_PATH, e)
def _load_whitelist(filepath: str | None) -> set[str]:
if not filepath:
return set()
try:
approved: set[str] = set()
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
approved.add(os.path.basename(line))
return approved
except Exception:
return set()
_MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)


def _torch_load_wrapper(*args, **kwargs):
    """Drop-in ``torch.load`` replacement with a whitelist-gated unsafe fallback.

    The original ``torch.load`` is tried first. The fallback applies only when it raises
    ``pickle.UnpicklingError`` with a message that looks like a weights-only (safe-load)
    rejection. In that case the whitelist file is re-read, and if the checkpoint's base
    filename is listed, the load is retried with ``weights_only=False``. In every other
    case the original error is re-raised.

    The checkpoint may be passed positionally or as ``f=``. ``str``, ``bytes`` and
    ``os.PathLike`` values (e.g. ``pathlib.Path``) are all recognised. File-like objects
    have no name to match, so they never get the fallback.
    """
    global _MODEL_WHITELIST

    source = args[0] if args else kwargs.get("f")
    filename = None
    if isinstance(source, (str, bytes, os.PathLike)):
        # fsdecode normalises str / bytes / PathLike to str before taking the base name.
        filename = os.path.basename(os.fsdecode(source))

    try:
        return _ORIG_TORCH_LOAD(*args, **kwargs)
    except pickle.UnpicklingError as e:
        msg = str(e)
        # Heuristic: this is the common PyTorch >=2.6 safe-load failure mode.
        maybe_weights_only_error = any(
            marker in msg
            for marker in ("Weights only load failed", "Unsupported global", "disallowed", "not allowed", "getattr")
        )
        if not maybe_weights_only_error:
            raise

        # Refresh whitelist from disk so edits take effect without restarting ComfyUI.
        _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)

        if filename and filename in _MODEL_WHITELIST:
            logging.warning(
                "[Salia Ultralytics] Safe torch.load failed for '%s'. Retrying with weights_only=False because it's whitelisted (%s).",
                filename,
                _WHITELIST_PATH,
            )
            retry_kwargs = dict(kwargs)
            retry_kwargs["weights_only"] = False
            return _ORIG_TORCH_LOAD(*args, **retry_kwargs)

        logging.error(
            "[Salia Ultralytics] Blocked unsafe model load for '%s'.\n"
            "Safe loading failed and the file is not whitelisted.\n"
            "If you TRUST this model, add its base name to: %s",
            filename or "[unknown]",
            _WHITELIST_PATH or "[whitelist path unavailable]",
        )
        raise
@contextmanager
def _patched_torch_load_for_ultralytics():
    """Temporarily replace ``torch.load`` with ``_torch_load_wrapper``.

    Only the module attribute ``torch.load`` is swapped, and it is restored on exit
    even if the body raises. NOTE(review): code that captured its own reference to
    ``torch.load`` before this runs is not affected. On PyTorch builds without
    ``torch.serialization.safe_globals`` nothing is patched.
    """
    if hasattr(torch.serialization, "safe_globals"):
        saved_loader = torch.load
        torch.load = _torch_load_wrapper
        try:
            yield
        finally:
            torch.load = saved_loader
    else:
        yield
def _load_yolo(model_path: str):
    """Construct ``ultralytics.YOLO(model_path)`` with the whitelist-aware torch.load patch active.

    Raises:
        ImportError: if ``ultralytics`` cannot be imported (the original error is chained).
    """
    try:
        from ultralytics import YOLO  # imported lazily so the node module loads without it
    except Exception as err:
        raise ImportError(
            "[Salia Ultralytics] ultralytics is not installed. Install it in your ComfyUI env, e.g.:\n"
            "pip install ultralytics"
        ) from err

    with _patched_torch_load_for_ultralytics():
        model = YOLO(model_path)
    return model
# ---------------------------
# Minimal Impact-compatible utilities (self-contained)
# ---------------------------
def _tensor2np_rgb(image: torch.Tensor) -> np.ndarray:
"""Convert a ComfyUI IMAGE tensor to a uint8 RGB numpy image."""
# ComfyUI image is usually: (B,H,W,C) float in [0,1]
if not isinstance(image, torch.Tensor):
raise TypeError(f"Expected torch.Tensor, got {type(image)}")
if image.dim() == 4:
img = image[0]
else:
img = image
img = img.detach()
if img.is_cuda:
img = img.cpu()
img = img.clamp(0, 1).numpy()
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
img_u8 = (img * 255.0).round().astype(np.uint8)
return img_u8
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Return a ComfyUI IMAGE tensor (first frame if batched) as a PIL image via ``_tensor2np_rgb``."""
    pixels = _tensor2np_rgb(image)
    return Image.fromarray(pixels)
def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size: int | None = None):
    """Return an integer ``(x1, y1, x2, y2)`` crop box centred on *bbox_xyxy*.

    The box is the bbox size (each side at least 1 px) times *crop_factor*, optionally
    grown to at least *crop_min_size* per side. It is then clipped to a ``w`` x ``h``
    image: x1/y1 lie in ``[0, w-1]`` / ``[0, h-1]``, and x2/y2 are at most ``w`` / ``h``
    but always at least one pixel past x1/y1. Clipping shrinks the region; it does
    not shift it.
    """
    left, top, right, bottom = map(float, bbox_xyxy)
    factor = float(crop_factor)
    span_x = max(1.0, right - left) * factor
    span_y = max(1.0, bottom - top) * factor
    if crop_min_size is not None:
        floor = float(crop_min_size)
        span_x = max(span_x, floor)
        span_y = max(span_y, floor)

    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0

    def _clamp(value, lo, hi):
        return max(lo, min(hi, value))

    rx1 = _clamp(int(round(center_x - span_x / 2.0)), 0, w - 1)
    ry1 = _clamp(int(round(center_y - span_y / 2.0)), 0, h - 1)
    rx2 = _clamp(int(round(center_x + span_x / 2.0)), rx1 + 1, w)
    ry2 = _clamp(int(round(center_y + span_y / 2.0)), ry1 + 1, h)
    return (rx1, ry1, rx2, ry2)
def crop_image(image: torch.Tensor, crop_region):
    """Slice *crop_region* ``(x1, y1, x2, y2)`` out of a channels-last image tensor.

    The batch axis is kept for 4-D ``(B, H, W, C)`` input. A 3-D ``(H, W, C)`` input is
    sliced directly. Coordinates are truncated to ``int``.

    Raises:
        ValueError: for any other rank.
    """
    left, top, right, bottom = crop_region
    left, top, right, bottom = int(left), int(top), int(right), int(bottom)
    rank = image.dim()
    if rank == 4:
        return image[:, top:bottom, left:right, :]
    if rank == 3:
        return image[top:bottom, left:right, :]
    raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}")
def crop_ndarray2(arr: np.ndarray, crop_region):
    """Return ``arr[y1:y2, x1:x2]`` for *crop_region* ``(x1, y1, x2, y2)`` (coords truncated to int)."""
    left, top, right, bottom = crop_region
    rows = slice(int(top), int(bottom))
    cols = slice(int(left), int(right))
    return arr[rows, cols]
def dilate_masks(segmasks, dilation: int):
    """Grow each mask in *segmasks* with an elliptical ``(2*dilation+1)``-sized kernel.

    *segmasks* is a sequence of ``(bbox, mask, conf)`` triples. Each mask is binarised
    at 0.5, dilated once, and returned as a float32 0/1 array; bbox and conf pass
    through. ``dilation <= 0`` returns *segmasks* itself, untouched.

    Raises:
        ImportError: when dilation is requested but OpenCV (cv2) could not be imported.
    """
    if dilation <= 0:
        return segmasks
    if cv2 is None:
        raise ImportError(
            "[Salia Ultralytics] opencv-python is required for mask dilation but cv2 could not be imported.\n"
            "Install: pip install opencv-python-headless"
        )

    side = int(dilation) * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))

    def _grow(mask):
        binary = (mask > 0.5).astype(np.uint8) * 255
        grown = cv2.dilate(binary, kernel, iterations=1)
        return (grown > 0).astype(np.float32)

    return [(bbox, _grow(mask), conf) for bbox, mask, conf in segmasks]
def combine_masks(segmasks, out_shape_hw: tuple[int, int] | None = None) -> torch.Tensor:
    """Return the pixel-wise maximum of all masks as a float32 ``(1, H, W)`` tensor.

    The union starts from zeros shaped like the first mask, so negative mask values
    are floored at 0. With no masks, a zero tensor shaped ``(1, *out_shape_hw)`` is
    returned, or ``(1, 1, 1)`` when no shape is given.
    """
    if segmasks:
        union = np.zeros_like(segmasks[0][1], dtype=np.float32)
        for _bbox, mask, _conf in segmasks:
            union = np.maximum(union, mask.astype(np.float32))
        return torch.from_numpy(union).unsqueeze(0)

    if out_shape_hw is None:
        return torch.zeros((1, 1, 1), dtype=torch.float32)
    h, w = out_shape_hw
    return torch.zeros((1, h, w), dtype=torch.float32)
# ---------------------------
# Impact-compatible detector wrapper objects
# ---------------------------
# One detected segment: cropped image/mask, confidence, crop box, detection bbox,
# label and an optional control-net wrapper (defaults to None).
SEG = namedtuple(
    "SEG",
    "cropped_image cropped_mask confidence crop_region bbox label control_net_wrapper",
    defaults=(None,),
)


class NO_BBOX_DETECTOR:
    """Empty marker class for the 'no bbox detector' case (nothing in this file returns it)."""


class NO_SEGM_DETECTOR:
    """Empty marker placed in the SEGM_DETECTOR output slot for bbox-only models."""
def _create_segmasks(results):
# results = [labels, bboxes_xyxy, segms, confs]
bboxes = results[1]
segms = results[2]
confs = results[3]
out = []
for i in range(len(segms)):
out.append((bboxes[i], segms[i].astype(np.float32), confs[i]))
return out
def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
pred = model(image_pil, conf=confidence, device=device)
bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy
if bboxes.shape[0] == 0:
return [[], [], [], []]
# Make simple rectangle masks for each bbox
np_img = np.array(image_pil)
if np_img.ndim == 2:
h, w = np_img.shape
else:
h, w = np_img.shape[0], np_img.shape[1]
segms = []
for x0, y0, x1, y1 in bboxes:
m = np.zeros((h, w), dtype=np.uint8)
x0i, y0i, x1i, y1i = int(x0), int(y0), int(x1), int(y1)
x0i = max(0, min(w - 1, x0i))
x1i = max(0, min(w, x1i))
y0i = max(0, min(h - 1, y0i))
y1i = max(0, min(h, y1i))
if cv2 is not None:
cv2.rectangle(m, (x0i, y0i), (x1i, y1i), 255, -1)
else:
m[y0i:y1i, x0i:x1i] = 255
segms.append((m > 0))
labels = []
confs = []
for i in range(len(bboxes)):
labels.append(pred[0].names[int(pred[0].boxes[i].cls.item())])
confs.append(pred[0].boxes[i].conf.detach().cpu().numpy())
return [labels, list(bboxes), segms, confs]
def _inference_segm(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
pred = model(image_pil, conf=confidence, device=device)
bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy
if bboxes.shape[0] == 0:
return [[], [], [], []]
if pred[0].masks is None or pred[0].masks.data is None:
# fallback: no masks, treat like bbox
return _inference_bbox(model, image_pil, confidence=confidence, device=device)
segms = pred[0].masks.data.detach().cpu().numpy() # (n, h, w) in model-space
# Resize masks back to original image size
h_orig = image_pil.size[1]
w_orig = image_pil.size[0]
results = [[], [], [], []]
for i in range(len(bboxes)):
results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
results[1].append(bboxes[i])
mask = torch.from_numpy(segms[i]).float()
mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(h_orig, w_orig), mode="bilinear", align_corners=False)
mask = mask.squeeze(0).squeeze(0)
results[2].append(mask.numpy())
results[3].append(pred[0].boxes[i].conf.detach().cpu().numpy())
return results
class SaliaUltraBBoxDetector:
    """BBOX_DETECTOR object wrapping an Ultralytics YOLO model.

    Each detection becomes a filled-rectangle mask (via ``_inference_bbox``).
    """

    def __init__(self, model):
        self.model = model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects in the first frame of *image* and package them as SEGS.

        Returns ``((H, W), [SEG, ...])``. Detections whose box width or height does not
        exceed ``drop_size`` (at least 1) are skipped. If *detailer_hook* provides
        ``post_crop_region`` or ``post_detection``, those may rewrite each crop region
        and the final SEGS respectively.
        """
        drop_size = max(int(drop_size), 1)
        detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        items = []
        # NOTE(review): assumes image is a ComfyUI IMAGE laid out (B, H, W, C).
        h = image.shape[1]
        w = image.shape[2]
        for (bbox, mask, conf), label in zip(segmasks, detected[0]):
            x1, y1, x2, y2 = bbox
            if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
                crop_region = make_crop_region(w, h, bbox, float(crop_factor))
                if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                    crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(mask, crop_region)
                items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))
        segs = (image.shape[1], image.shape[2]), items
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all detection masks as a 2-D ``(H, W)`` float32 tensor.

        With no detections the result is an all-zero ``(H, W)`` tensor.
        """
        detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        combined = combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))
        # combine_masks returns (1, H, W). NOTE(review): Impact-Pack's combined-detector
        # nodes call `mask.unsqueeze(0)` on this value, so it must be 2-D; the previous
        # 3-D return became a (1, 1, H, W) MASK there.
        return combined[0]

    def setAux(self, x):
        """Intentionally a no-op (presumably kept for detector API compatibility)."""
        pass
class SaliaUltraSegmDetector:
    """SEGM_DETECTOR object wrapping an Ultralytics segmentation model.

    Per-detection masks come from the model's mask output (via ``_inference_segm``).
    """

    def __init__(self, model):
        self.model = model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Segment objects in the first frame of *image* and package them as SEGS.

        Returns ``((H, W), [SEG, ...])``. Detections whose box width or height does not
        exceed ``drop_size`` (at least 1) are skipped. If *detailer_hook* provides
        ``post_crop_region`` or ``post_detection``, those may rewrite each crop region
        and the final SEGS respectively.
        """
        drop_size = max(int(drop_size), 1)
        detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        items = []
        # NOTE(review): assumes image is a ComfyUI IMAGE laid out (B, H, W, C).
        h = image.shape[1]
        w = image.shape[2]
        for (bbox, mask, conf), label in zip(segmasks, detected[0]):
            x1, y1, x2, y2 = bbox
            if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
                crop_region = make_crop_region(w, h, bbox, float(crop_factor))
                if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                    crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(mask, crop_region)
                items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))
        segs = (image.shape[1], image.shape[2]), items
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all segmentation masks as a 2-D ``(H, W)`` float32 tensor.

        With no detections the result is an all-zero ``(H, W)`` tensor.
        """
        detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))
        combined = combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))
        # combine_masks returns (1, H, W). NOTE(review): Impact-Pack's combined-detector
        # nodes call `mask.unsqueeze(0)` on this value, so it must be 2-D; the previous
        # 3-D return became a (1, 1, H, W) MASK there.
        return combined[0]

    def setAux(self, x):
        """Intentionally a no-op (presumably kept for detector API compatibility)."""
        pass
# ---------------------------
# The actual ComfyUI Node
# ---------------------------
class SaliaUltralyticsDetectorProvider2:
    """ComfyUI node: load an Ultralytics `.pt` model and provide Impact-compatible detectors.

    Outputs ``(BBOX_DETECTOR, SEGM_DETECTOR)``. For ``bbox/...`` models the second
    output is a ``NO_SEGM_DETECTOR`` placeholder.
    """

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "doit"
    CATEGORY = "Salia/Detectors"

    @classmethod
    def INPUT_TYPES(cls):
        choices = [f"bbox/{name}" for name in folder_paths.get_filename_list("ultralytics_bbox")]
        choices += [f"segm/{name}" for name in folder_paths.get_filename_list("ultralytics_segm")]
        return {"required": {"model_name": (choices,)}}

    @staticmethod
    def _specialized_folder(model_name: str):
        """Return (folder key, name without prefix) for bbox/ or segm/ names, else (None, None)."""
        for prefix, key in (("bbox/", "ultralytics_bbox"), ("segm/", "ultralytics_segm")):
            if model_name.startswith(prefix):
                return key, model_name[len(prefix):]
        return None, None

    def doit(self, model_name: str):
        folder_key, bare_name = self._specialized_folder(model_name)

        # Prefer "<ultralytics root>/bbox/foo.pt"; then the bbox/segm-specific roots.
        model_path = folder_paths.get_full_path("ultralytics", model_name)
        if model_path is None and folder_key is not None:
            model_path = folder_paths.get_full_path(folder_key, bare_name)

        if model_path is None:
            searched = []
            try:
                searched.extend(folder_paths.get_folder_paths("ultralytics"))
                if folder_key is not None:
                    searched.extend(folder_paths.get_folder_paths(folder_key))
            except Exception:
                pass
            formatted = "\n\t".join(searched)
            raise ValueError(
                f"[Salia Ultralytics] model file '{model_name}' was not found.\n"
                f"Searched these folders:\n\t{formatted}\n"
                f"Tip: put bbox models in 'models/ultralytics/bbox' or segm models in 'models/ultralytics/segm'."
            )

        # Print + log the resolved on-disk path (and any duplicates).
        duplicates = _find_all_model_paths(model_name)
        _log_selected_model(model_name, os.path.abspath(model_path), duplicates)

        model = _load_yolo(model_path)
        bbox_detector = SaliaUltraBBoxDetector(model)
        if model_name.startswith("bbox/"):
            segm_detector = NO_SEGM_DETECTOR()
        else:
            segm_detector = SaliaUltraSegmDetector(model)
        return bbox_detector, segm_detector
# ComfyUI discovers nodes through these two module-level dicts (node id -> class / display name).
NODE_CLASS_MAPPINGS = dict(
    SaliaUltralyticsDetectorProvider2=SaliaUltralyticsDetectorProvider2,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    SaliaUltralyticsDetectorProvider2="Salia Ultralytics Detector 2 (Salia)",
)