import os
import logging
from typing import List, Tuple

import torch
import numpy as np
from ultralytics import YOLO

# Impact Pack (for SEG and SEGS helpers)
import impact.core as core
from impact.core import SEG

# Local helpers (your utils_salia)
try:
    # Package-style import (recommended inside a ComfyUI custom node package)
    from .utils_salia import (
        NODE_DIR,
        IMGSZ,
        list_local_pt_files,
        tensor_to_pil,
        make_crop_region,
        crop_image,
        crop_ndarray2,
        dilate_mask,
    )
except ImportError:
    # Fallback if utils_salia is importable directly (not as a package)
    from utils_salia import (
        NODE_DIR,
        IMGSZ,
        list_local_pt_files,
        tensor_to_pil,
        make_crop_region,
        crop_image,
        crop_ndarray2,
        dilate_mask,
    )

logger = logging.getLogger(__name__)


# -------------------------------------------------------------------------
# YOLO TensorRT-based BBOX_DETECTOR implementation
# -------------------------------------------------------------------------
class TRTYOLOBBoxDetector:
    """
    BBOX_DETECTOR interface compatible with Impact Pack / FaceDetailer.

    Required API:
      - setAux(x)
      - detect(image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None)
      - detect_combined(image, threshold, dilation)
    """

    def __init__(self, yolo_model: YOLO, device: str = "0"):
        """
        Wrap an already-loaded Ultralytics model.

        Args:
            yolo_model: model object that is called directly for inference.
            device: device string forwarded to every inference call; an
                empty/falsy value falls back to "0".
        """
        self.bbox_model = yolo_model
        self.device = device or "0"
        # aux is used as a class name filter, e.g. FaceDetailer calls setAux('face')
        self.aux: str | None = None

    # ------------------------------------------------------------------
    # API: setAux
    # ------------------------------------------------------------------
    def setAux(self, x: str):
        """
        Store auxiliary info (typically a class filter like 'face').
        FaceDetailer calls setAux('face') before detect() and setAux(None) after.
        """
        self.aux = x

    @staticmethod
    def _label_for(names, cls_id: int) -> str:
        """Return the class label for ``cls_id``; fall back to the id as text.

        ``names`` may be a list/tuple indexed by class id or a dict-like
        ``{class_index: "name"}``. The result is always coerced to ``str`` so
        the case-insensitive aux comparison cannot fail on a non-string entry.
        """
        if isinstance(names, (list, tuple)):
            if 0 <= cls_id < len(names):
                return str(names[cls_id])
            return str(cls_id)
        return str(names.get(cls_id, cls_id))

    # ------------------------------------------------------------------
    # API: detect
    # ------------------------------------------------------------------
    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook=None,
    ) -> Tuple[Tuple[int, int], List[SEG]]:
        """
        Main detection method used by FaceDetailer.

        Args:
            image: ComfyUI IMAGE tensor [B, H, W, C] in 0..1.
            threshold: confidence threshold for detections.
            dilation: mask dilation/erosion size in pixels (>0 dilate, <0 erode).
            crop_factor: expansion factor for bbox when computing crop_region.
            drop_size: minimum bbox width/height to keep.
            detailer_hook: optional hook with post_crop_region / post_detection.

        Returns:
            SEGS tuple: ( (H, W), [SEG, SEG, ...] ). An empty batch yields
            ( (H, W), [] ) without running inference.

        Raises:
            ValueError: if ``image`` does not have exactly 4 dimensions.
        """
        if image.dim() != 4:
            raise ValueError(
                "[TRTYOLOBBoxDetector] Expected IMAGE tensor with 4 dims [B, H, W, C]."
            )

        # Original image size
        h, w = int(image.shape[1]), int(image.shape[2])
        shape = (h, w)

        batch_size = int(image.shape[0])
        if batch_size == 0:
            # Nothing to run inference on; report "no detections" instead of
            # emitting the "Batch > 1" warning and failing in the PIL conversion.
            logger.warning(
                "[TRTYOLOBBoxDetector] Empty batch received; returning no detections."
            )
            return (shape, [])
        if batch_size > 1:
            # Impact Pack detectors typically only use the first image in a batch.
            logger.warning(
                "[TRTYOLOBBoxDetector] Batch > 1 detected; using only the first image for detection."
            )
            image = image[:1]

        # Convert tensor to PIL for Ultralytics inference
        pil_img = tensor_to_pil(image)

        # Run YOLO model prediction with given threshold on the chosen device
        pred_list = self.bbox_model(
            pil_img, conf=threshold, device=self.device, verbose=False
        )
        if len(pred_list) == 0:
            return (shape, [])

        pred = pred_list[0]
        boxes = pred.boxes
        if boxes is None or boxes.xyxy is None or boxes.xyxy.shape[0] == 0:
            return (shape, [])

        xyxy = boxes.xyxy.cpu().numpy()               # [N, 4] (x1, y1, x2, y2)
        confs = boxes.conf.cpu().numpy()              # [N] confidence
        clses = boxes.cls.cpu().numpy().astype(int)   # [N] class indices
        names = pred.names                            # class names (list/tuple or dict)

        # Loop-invariant values: the lowered aux filter and the optional hook.
        aux_filter = self.aux.lower() if self.aux and isinstance(self.aux, str) else None
        post_crop_region = getattr(detailer_hook, "post_crop_region", None)

        seg_items: List[SEG] = []
        for (x1, y1, x2, y2), conf, cls_id in zip(xyxy, confs, clses):
            score = float(conf)
            label = self._label_for(names, int(cls_id))

            # Aux filter (e.g. only keep 'face'): skip detections of other classes
            if aux_filter is not None and label.lower() != aux_filter:
                continue

            # Drop tiny boxes
            if (x2 - x1) <= drop_size or (y2 - y1) <= drop_size:
                continue

            # Clamp bbox to image bounds (integer pixel coords)
            x1_i = max(int(np.floor(x1)), 0)
            y1_i = max(int(np.floor(y1)), 0)
            x2_i = min(int(np.ceil(x2)), w)
            y2_i = min(int(np.ceil(y2)), h)
            if x2_i <= x1_i or y2_i <= y1_i:
                continue

            # Create full-image mask from bbox as float32 in [0, 1]
            mask = np.zeros((h, w), dtype=np.float32)
            mask[y1_i:y2_i, x1_i:x2_i] = 1.0

            # Optional dilation / erosion via GPU-aware helper.
            # IMPORTANT: dilate_mask must return float32 [0,1] as well.
            if dilation:
                mask = dilate_mask(mask, dilation)

            # Impact core uses bbox as [x1, y1, x2, y2]
            item_bbox = [float(x1), float(y1), float(x2), float(y2)]

            # Compute crop region (expanded bbox) in xyxy format
            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            if post_crop_region is not None:
                crop_region = post_crop_region(w, h, item_bbox, crop_region)

            # Crop image + mask
            cropped_image = crop_image(image, crop_region)    # torch [1, h', w', C]
            cropped_mask = crop_ndarray2(mask, crop_region)   # np.float32 [h', w'] in [0,1]

            # Build SEG object for this detection
            seg_items.append(
                SEG(
                    cropped_image,
                    cropped_mask,
                    score,
                    crop_region,
                    item_bbox,
                    label,
                    None,  # control_net_wrapper
                )
            )

        segs = (shape, seg_items)

        # Optional post-detection hook
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    # ------------------------------------------------------------------
    # API: detect_combined
    # ------------------------------------------------------------------
    def detect_combined(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
    ) -> torch.Tensor:
        """
        Optional combined-mask API: returns a single MASK tensor covering all detections.

        Runs detect() with crop_factor=1.0, drop_size=1 and no hook, then merges
        the resulting SEGS with impact.core.segs_to_combined_mask.
        """
        # NOTE(review): detect() crops each dilated mask to its crop_region. If
        # make_crop_region with crop_factor=1.0 returns the bbox itself, positive
        # dilation may be clipped away here — confirm against utils_salia.
        shape, seg_list = self.detect(
            image=image,
            threshold=threshold,
            dilation=dilation,
            crop_factor=1.0,
            drop_size=1,
            detailer_hook=None,
        )
        return core.segs_to_combined_mask((shape, seg_list))


# -------------------------------------------------------------------------
# NODE 1: TRTYOLOEngineBuilder
#   - Builds a TensorRT engine from a .pt file in the node folder.
# -------------------------------------------------------------------------
class TRTYOLOEngineBuilder:
    """ComfyUI node that exports a local YOLO ``.pt`` file to a TensorRT engine."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: one combo listing the local ``.pt`` files."""
        pt_files = list_local_pt_files()
        default_name = pt_files[0] if pt_files else "face.pt"
        return {
            "required": {
                "pt_model_name": (
                    pt_files,
                    {
                        "default": default_name,
                        "tooltip": (
                            "Select a YOLO .pt file that lives in the SAME folder as this node file."
                        ),
                    },
                ),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("engine_path",)
    FUNCTION = "build"
    CATEGORY = "ImpactPack/TensorRT"

    def build(self, pt_model_name: str):
        """
        Export ``pt_model_name`` (resolved against NODE_DIR) to a TensorRT engine.

        Returns:
            One-element tuple holding the engine path as a string.

        Raises:
            FileNotFoundError: if the ``.pt`` file does not exist.
            RuntimeError: if the export reports no output path.
        """
        # Resolve .pt path relative to this node file
        pt_path = os.path.join(NODE_DIR, pt_model_name)
        if not os.path.isfile(pt_path):
            raise FileNotFoundError(
                f"[TRTYOLOEngineBuilder] .pt model not found: {pt_path}"
            )

        logger.info(
            "[TRTYOLOEngineBuilder] Exporting TensorRT engine from '%s' "
            "with imgsz=%s (H,W), batch=1, half=True, device='0', exist_ok=True",
            pt_path,
            IMGSZ,
        )

        # Export the model to TensorRT engine format
        export_kwargs = {
            "format": "engine",
            "imgsz": IMGSZ,
            "half": True,
            "device": "0",
        }
        try:
            result = YOLO(pt_path).export(exist_ok=True, **export_kwargs)
        except TypeError:
            # Fallback for older Ultralytics versions without 'exist_ok' or similar args
            result = YOLO(pt_path).export(**export_kwargs)

        # Handle return type (path string, Path, or list/tuple of them)
        if isinstance(result, (list, tuple)):
            result = result[0] if len(result) > 0 else None

        # Test for "nothing returned" BEFORE str(): str(None) is the non-empty
        # text "None" and would otherwise be passed on as if it were a path.
        engine_path = "" if result is None else str(result)
        if not engine_path:
            raise RuntimeError(
                "[TRTYOLOEngineBuilder] Engine export failed (empty output path)."
            )

        # If Ultralytics returned a relative path, try to resolve it robustly.
        if not os.path.isabs(engine_path):
            # 1) Next to the .pt model, 2) relative to NODE_DIR.
            for base_dir in (os.path.dirname(pt_path), NODE_DIR):
                candidate = os.path.join(base_dir, engine_path)
                if os.path.isfile(candidate):
                    engine_path = candidate
                    break
            # If still not found, we leave engine_path as-is; user may have a runs/... path.

        if not os.path.isfile(engine_path):
            logger.warning(
                "[TRTYOLOEngineBuilder] Reported engine path does not exist on disk: %s",
                engine_path,
            )

        logger.info("[TRTYOLOEngineBuilder] Export complete. Engine path: %s", engine_path)
        return (engine_path,)


# -------------------------------------------------------------------------
# NODE 2: TRTYOLOBBoxDetectorProvider
#   - Loads the TensorRT engine and provides a BBOX_DETECTOR object.
# -------------------------------------------------------------------------
class TRTYOLOBBoxDetectorProvider:
    """ComfyUI node that loads a TensorRT ``.engine`` file as a BBOX_DETECTOR."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a single STRING holding the engine path."""
        return {
            "required": {
                "engine_path": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": (
                            "Path to the TensorRT .engine file.\n"
                            "Can be an absolute path or relative to this node's folder.\n"
                            "Typically use the output of TRTYOLOEngineBuilder."
                        ),
                    },
                ),
            }
        }

    RETURN_TYPES = ("BBOX_DETECTOR",)
    RETURN_NAMES = ("bbox_detector",)
    FUNCTION = "load"
    CATEGORY = "ImpactPack/TensorRT"

    def load(self, engine_path: str):
        """
        Load the engine at ``engine_path`` and wrap it in a TRTYOLOBBoxDetector.

        Relative paths are resolved against NODE_DIR.

        Returns:
            One-element tuple holding the detector.

        Raises:
            ValueError: if ``engine_path`` is empty or whitespace-only.
            FileNotFoundError: if the resolved engine file does not exist.
        """
        # Strip BEFORE validating so a whitespace-only value is reported as
        # "empty" instead of surfacing as a confusing "file not found" for the
        # node directory itself.
        engine_path = (engine_path or "").strip()
        if not engine_path:
            raise ValueError(
                "[TRTYOLOBBoxDetectorProvider] 'engine_path' is empty. "
                "Provide a valid path or connect from TRTYOLOEngineBuilder."
            )

        # Resolve relative paths against this node's folder
        if not os.path.isabs(engine_path):
            engine_path = os.path.join(NODE_DIR, engine_path)

        if not os.path.isfile(engine_path):
            raise FileNotFoundError(
                f"[TRTYOLOBBoxDetectorProvider] Engine file not found: {engine_path}"
            )

        logger.info(
            "[TRTYOLOBBoxDetectorProvider] Loading YOLO TensorRT engine from '%s' on device '0'",
            engine_path,
        )

        # Load the TensorRT engine with Ultralytics (TensorRT backend)
        yolo_model = YOLO(engine_path)
        detector = TRTYOLOBBoxDetector(yolo_model, device="0")
        return (detector,)


# -------------------------------------------------------------------------
# ComfyUI node registration
# -------------------------------------------------------------------------
NODE_CLASS_MAPPINGS = {
    "TRTYOLOEngineBuilder": TRTYOLOEngineBuilder,
    "TRTYOLOBBoxDetectorProvider": TRTYOLOBBoxDetectorProvider,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TRTYOLOEngineBuilder": "TensorRT YOLO Engine Builder (1344x768)",
    "TRTYOLOBBoxDetectorProvider": "TensorRT YOLO BBox Detector",
}