#!/usr/bin/env python3
"""
DEIM Debug Script for Interactive Bbox Detection and Visualization
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
This script provides interactive debugging capabilities for DEIM models:
- Load model from config and checkpoint
- Process images and videos
- Interactive OpenCV visualization with imshow
- Adjustable confidence thresholds
- Keyboard controls for video playback
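
Example invocation (script name and paths are illustrative):
    python deim_debug.py -c configs/deim.yml -ckpt model.pth -i samples/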
"""
import argparse
import os
import sys
import time
from pathlib import Path
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
# Add the project root to Python path
sys.path.insert(0, str(Path(__file__).parent))
from engine.core import YAMLConfig
# Default class names - will be overridden by dataset configuration
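# NOTE: keys follow the original COCO category IDs, which are non-contiguous
# (e.g., IDs 12, 26, 29 and 30 are unused in the official annotations).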
DEFAULT_CLASSES = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorbike', 5: 'aeroplane',
    6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
    11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep',
    21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe',
    27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie',
    33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
    37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
    41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
    46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
    51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
    56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
    61: 'cake', 62: 'chair', 63: 'sofa', 64: 'potted plant', 65: 'bed',
    67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
    75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave',
    79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book',
    85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear',
    89: 'hair drier', 90: 'toothbrush'
}
def load_class_names_from_config(cfg):
"""Load class names from dataset configuration"""
try:
# Import here to avoid circular imports
from engine.data.dataset.coco_dataset import mscoco_category2name
# Check if we can access dataset configuration
if hasattr(cfg, 'val_dataloader') and cfg.val_dataloader is not None:
dataset_cfg = cfg.val_dataloader.dataset
# Try to instantiate dataset to get class names
try:
# Get the number of classes from config
num_classes = getattr(cfg, 'num_classes', 80)
# Check if using COCO remapping
remap_mscoco = getattr(cfg, 'remap_mscoco_category', False)
if remap_mscoco:
print(f"Using COCO class names (remapped)")
return mscoco_category2name
# Try to create dataset instance to get category names
if hasattr(dataset_cfg, 'ann_file') and dataset_cfg.ann_file:
# For COCO-style datasets, try to load annotations
try:
from pycocotools.coco import COCO
if os.path.exists(dataset_cfg.ann_file):
coco = COCO(dataset_cfg.ann_file)
categories = coco.dataset.get('categories', [])
if categories:
class_names = {}
                                for cat in categories:
                                    # Use the category ID as the key for proper mapping
                                    class_names[cat['id']] = cat['name']
print(f"Loaded {len(class_names)} class names from annotation file")
return class_names
except Exception as e:
print(f"Could not load classes from annotation file: {e}")
# Generate generic class names based on number of classes
print(f"Generating generic class names for {num_classes} classes")
if num_classes == 80:
return mscoco_category2name
elif num_classes == 1:
return {1: 'object'}
elif num_classes == 2:
return {1: 'person', 2: 'object'} # Common for crowd detection
elif num_classes == 20:
# VOC classes
voc_classes = {
1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle',
6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow',
11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person',
16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'
}
return voc_classes
else:
# Generic class names
return {i + 1: f'class_{i + 1}' for i in range(num_classes)}
except Exception as e:
print(f"Could not instantiate dataset: {e}")
except Exception as e:
print(f"Could not load class names from config: {e}")
# Fallback to default COCO classes
print("Using default COCO class names")
return DEFAULT_CLASSES
# Color palette for bounding boxes (BGR format for OpenCV)
COLORS = [
(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
(0, 255, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0),
(128, 0, 128), (0, 128, 128), (255, 128, 0), (255, 0, 128), (128, 255, 0),
(0, 255, 128), (128, 0, 255), (0, 128, 255), (192, 192, 192), (64, 64, 64)
]
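# 20 distinct BGR colors; draw_detections() maps class IDs onto this palette
# with a modulo, so any number of classes gets a stable color.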
class DEIMModel(nn.Module):
"""Wrapper for DEIM model with postprocessing"""
def __init__(self, config_path, checkpoint_path, device='cuda', input_size=640):
super().__init__()
self.device = device
config_overrides = {'HGNetv2': {'pretrained': False}}
self.cfg = YAMLConfig(config_path, resume=checkpoint_path, **config_overrides)
print(f"Loading checkpoint from: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
self.cfg.model.load_state_dict(state_dict)
self.model = self.cfg.model.eval().to(device)
self.postprocessor = self.cfg.postprocessor.eval().to(device)
self.class_names = load_class_names_from_config(self.cfg)
self.num_classes = getattr(self.cfg, 'num_classes', len(self.class_names))
print(f"Model loaded successfully on {device}")
print(f"Model type: {type(self.model).__name__}")
print(f"Number of classes: {self.num_classes}")
print(f"Sample classes: {dict(list(self.class_names.items())[:5])}...")
def forward(self, images, orig_sizes):
"""Forward pass through model and postprocessor"""
with torch.no_grad():
outputs = self.model(images)
results = self.postprocessor(outputs, orig_sizes)
return results
def get_class_name(self, class_id):
"""Get class name for given class ID"""
return self.class_names.get(class_id, f'class_{class_id}')
class DebugVisualizer:
"""Interactive visualizer with OpenCV"""
def __init__(self, model, confidence_threshold=0.5, window_name="DEIM Debug"):
self.model = model
self.confidence_threshold = confidence_threshold
self.window_name = window_name
self.paused = False
self.show_info = True
# Create window
cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
cv2.resizeWindow(self.window_name, 1200, 800)
print("\n=== Debug Controls ===")
print("SPACE: Pause/Resume video")
print("'q' or ESC: Quit")
print("'i': Toggle info display")
print("'+'/'-': Increase/Decrease confidence threshold")
print("'s': Save current frame")
print("'n': Next file (in folder mode)")
print("=====================\n")
def draw_detections(self, image, results, frame_info=None):
"""Draw bounding boxes and labels on image"""
vis_image = image.copy()
if len(results) == 0:
return vis_image
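        # Expected postprocessor output (per image): a dict with 'labels',
        # 'boxes' (xyxy in original-image pixel coordinates) and 'scores'.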
# Extract results
result = results[0] if isinstance(results, list) else results
labels = result['labels'].cpu().numpy()
boxes = result['boxes'].cpu().numpy()
scores = result['scores'].cpu().numpy()
# Filter by confidence threshold
valid_indices = scores >= self.confidence_threshold
labels = labels[valid_indices]
boxes = boxes[valid_indices]
scores = scores[valid_indices]
# Draw bounding boxes
for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
x1, y1, x2, y2 = box.astype(int)
# Get class name and color
class_name = self.model.get_class_name(label)
color = COLORS[label % len(COLORS)]
# Draw bounding box
cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)
# Prepare label text
label_text = f'{class_name}: {score:.2f}'
# Get text size
(text_w, text_h), baseline = cv2.getTextSize(
label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
# Draw label background
cv2.rectangle(vis_image, (x1, y1 - text_h - baseline),
(x1 + text_w, y1), color, -1)
# Draw label text
cv2.putText(vis_image, label_text, (x1, y1 - baseline),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# Draw info overlay
if self.show_info:
self._draw_info_overlay(vis_image, labels, scores, frame_info)
return vis_image
def _draw_info_overlay(self, image, labels, scores, frame_info=None):
"""Draw information overlay on image"""
        overlay_y = 30
# Detection count and confidence info
info_lines = [
f"Detections: {len(labels)} (conf >= {self.confidence_threshold:.2f})",
f"Avg Confidence: {scores.mean():.3f}" if len(scores) > 0 else "Avg Confidence: N/A"
]
# Add frame info for videos
        if frame_info:
            # Guard the fps field: formatting a missing value with :.1f would raise
            fps_val = frame_info.get('fps')
            info_lines.extend([
                f"Frame: {frame_info.get('frame_num', 'N/A')}",
                f"FPS: {fps_val:.1f}" if isinstance(fps_val, (int, float)) else "FPS: N/A",
                f"Status: {'PAUSED' if self.paused else 'PLAYING'}"
            ])
# Add file progress if available
if 'file_progress' in frame_info:
info_lines.append(f"File: {frame_info['file_progress']}")
# Draw background
overlay_height = len(info_lines) * 25 + 20
cv2.rectangle(image, (10, 10), (350, 10 + overlay_height),
(0, 0, 0), -1)
cv2.rectangle(image, (10, 10), (350, 10 + overlay_height),
(255, 255, 255), 1)
# Draw text
for i, line in enumerate(info_lines):
cv2.putText(image, line, (20, overlay_y + i * 25),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
def show_image(self, image, results, title=None):
"""Display single image with detections"""
vis_image = self.draw_detections(image, results)
if title:
cv2.setWindowTitle(self.window_name, f"{self.window_name} - {title}")
cv2.imshow(self.window_name, vis_image)
# Wait for key press
while True:
key = cv2.waitKey(0) & 0xFF
if key == ord('q') or key == 27: # 'q' or ESC
return False
elif key == ord('n'): # Next file
return True
elif key == ord('s'): # Save image
save_path = f"debug_output_{int(time.time())}.jpg"
cv2.imwrite(save_path, vis_image)
print(f"Image saved as {save_path}")
elif key == ord('i'): # Toggle info
self.show_info = not self.show_info
vis_image = self.draw_detections(image, results)
cv2.imshow(self.window_name, vis_image)
elif key == ord('+') or key == ord('='): # Increase threshold
self.confidence_threshold = min(1.0, self.confidence_threshold + 0.05)
print(f"Confidence threshold: {self.confidence_threshold:.2f}")
vis_image = self.draw_detections(image, results)
cv2.imshow(self.window_name, vis_image)
elif key == ord('-') or key == ord('_'): # Decrease threshold
self.confidence_threshold = max(0.0, self.confidence_threshold - 0.05)
print(f"Confidence threshold: {self.confidence_threshold:.2f}")
vis_image = self.draw_detections(image, results)
cv2.imshow(self.window_name, vis_image)
else:
break
return True
def show_video_frame(self, image, results, frame_info):
"""Display video frame with detections"""
vis_image = self.draw_detections(image, results, frame_info)
cv2.setWindowTitle(self.window_name,
f"{self.window_name} - Frame {frame_info.get('frame_num', 'N/A')}")
cv2.imshow(self.window_name, vis_image)
        # Handle keyboard input; poll quickly when paused, otherwise pace playback to the measured FPS
        wait_time = 1 if self.paused else max(1, int(1000 / (frame_info.get('fps') or 30)))
        key = cv2.waitKey(wait_time) & 0xFF
if key == ord('q') or key == 27: # Quit
return False
elif key == ord('n'): # Next file (skip rest of video)
return 'next'
elif key == ord(' '): # Pause/Resume
self.paused = not self.paused
print("PAUSED" if self.paused else "RESUMED")
elif key == ord('s'): # Save frame
save_path = f"debug_frame_{frame_info.get('frame_num', int(time.time()))}.jpg"
cv2.imwrite(save_path, vis_image)
print(f"Frame saved as {save_path}")
elif key == ord('i'): # Toggle info
self.show_info = not self.show_info
elif key == ord('+') or key == ord('='): # Increase threshold
self.confidence_threshold = min(1.0, self.confidence_threshold + 0.05)
print(f"Confidence threshold: {self.confidence_threshold:.2f}")
elif key == ord('-') or key == ord('_'): # Decrease threshold
self.confidence_threshold = max(0.0, self.confidence_threshold - 0.05)
print(f"Confidence threshold: {self.confidence_threshold:.2f}")
return True
def close(self):
"""Close visualization windows"""
cv2.destroyAllWindows()
def find_media_files(folder_path):
"""Recursively find all image and video files in folder"""
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp', '.gif'}
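    # NOTE: .gif is accepted here, but cv2.imread cannot decode GIFs, so such
    # files will fail with a clear error message in process_image().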
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.webm'}
media_files = []
folder_path = Path(folder_path)
if folder_path.is_file():
# Single file provided
if folder_path.suffix.lower() in image_extensions | video_extensions:
media_files.append(folder_path)
else:
# Recursively find all media files
for file_path in folder_path.rglob('*'):
if file_path.is_file() and file_path.suffix.lower() in image_extensions | video_extensions:
media_files.append(file_path)
# Sort files for consistent ordering
media_files.sort()
# Separate images and videos
images = [f for f in media_files if f.suffix.lower() in image_extensions]
videos = [f for f in media_files if f.suffix.lower() in video_extensions]
print(f"Found {len(images)} images and {len(videos)} videos")
return images, videos
def process_image(model, image_path, visualizer, input_size=640, file_index=None, total_files=None):
"""Process single image"""
progress_str = f"[{file_index + 1}/{total_files}] " if file_index is not None else ""
print(f"{progress_str}Processing image: {image_path}")
# Load and preprocess image
image = cv2.imread(str(image_path))
if image is None:
print(f"Error: Could not load image {image_path}")
return False
h, w = image.shape[:2]
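    # Pass the original (w, h) so the postprocessor can scale boxes back to
    # the unresized image for drawing.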
orig_size = torch.tensor([[w, h]], dtype=torch.float32).to(model.device)
# Convert to PIL for transforms
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# Apply transforms
transforms = T.Compose([
T.Resize((input_size, input_size)),
T.ToTensor(),
])
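    # Only Resize + ToTensor here (pixels scaled to [0, 1]); any further
    # normalization is assumed to be handled inside the model itself.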
tensor_image = transforms(pil_image).unsqueeze(0).to(model.device)
# Run inference
start_time = time.time()
results = model(tensor_image, orig_size)
inference_time = time.time() - start_time
print(f"Inference time: {inference_time:.3f}s")
# Show results
title = f"{progress_str}{Path(image_path).name} ({inference_time:.3f}s)"
return visualizer.show_image(image, results, title)
def process_video(model, video_path, visualizer, input_size=640, file_index=None, total_files=None):
"""Process video file"""
progress_str = f"[{file_index + 1}/{total_files}] " if file_index is not None else ""
print(f"{progress_str}Processing video: {video_path}")
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
print(f"Error: Could not open video {video_path}")
return False
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"Video FPS: {fps:.2f}")
print(f"Total frames: {total_frames}")
# Apply transforms
transforms = T.Compose([
T.Resize((input_size, input_size)),
T.ToTensor(),
])
frame_num = 0
start_time = time.time()
try:
        while cap.isOpened():
            # Advance to a new frame only while playing; when paused, keep
            # re-displaying the last frame so keyboard controls stay responsive.
            if not visualizer.paused or frame_num == 0:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_num += 1
                h, w = frame.shape[:2]
                orig_size = torch.tensor([[w, h]], dtype=torch.float32).to(model.device)
                # Convert to PIL for transforms
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                tensor_frame = transforms(pil_frame).unsqueeze(0).to(model.device)
                # Run inference
                frame_start = time.time()
                results = model(tensor_frame, orig_size)
                inference_time = time.time() - frame_start
                # Calculate average FPS
                elapsed_time = time.time() - start_time
                avg_fps = frame_num / elapsed_time if elapsed_time > 0 else 0
                # Prepare frame info
                frame_info = {
                    'frame_num': frame_num,
                    'fps': avg_fps,
                    'inference_time': inference_time,
                    'total_frames': total_frames,
                    'file_progress': f"{progress_str}{Path(video_path).name}"
                }
                # Print progress periodically
                if frame_num % 30 == 0:
                    print(f"Processed {frame_num}/{total_frames} frames, "
                          f"Avg FPS: {avg_fps:.1f}, "
                          f"Inference: {inference_time:.3f}s")
            # Show frame and handle keyboard input
            result = visualizer.show_video_frame(frame, results, frame_info)
            if result is False:
                break
            elif result == 'next':
                print("Skipping to next file...")
                break
finally:
cap.release()
print(f"\nVideo processing completed!")
print(f"Total frames processed: {frame_num}")
print(f"Average FPS: {frame_num / (time.time() - start_time):.2f}")
return True
def process_folder(model, folder_path, visualizer, input_size=640, process_videos=True, process_images=True):
    """Process all images and videos in a folder recursively"""
print(f"Scanning folder: {folder_path}")
# Find all media files
images, videos = find_media_files(folder_path)
if not images and not videos:
print("No image or video files found!")
return False
    all_files = []
    # Add images first (unless skipped in videos-only mode)
    if images and process_images:
        print(f"\nFound {len(images)} images:")
        for img in images[:10]:  # Show first 10
            print(f"  {img}")
        if len(images) > 10:
            print(f"  ... and {len(images) - 10} more")
        all_files.extend([(img, 'image') for img in images])
    elif images and not process_images:
        print(f"\nSkipping {len(images)} images (videos-only mode)")
# Add videos if requested
if videos and process_videos:
print(f"\nFound {len(videos)} videos:")
for vid in videos[:10]: # Show first 10
print(f" {vid}")
if len(videos) > 10:
print(f" ... and {len(videos) - 10} more")
all_files.extend([(vid, 'video') for vid in videos])
elif videos and not process_videos:
print(f"\nSkipping {len(videos)} videos (use --process-videos to include them)")
if not all_files:
print("No files to process!")
return False
print(f"\nProcessing {len(all_files)} files total...")
print("Use SPACE to pause/resume, 'q' to quit, 'n' for next file")
# Process all files
for i, (file_path, file_type) in enumerate(all_files):
print(f"\n{'=' * 60}")
try:
if file_type == 'image':
success = process_image(model, file_path, visualizer, input_size, i, len(all_files))
else: # video
success = process_video(model, file_path, visualizer, input_size, i, len(all_files))
            if not success:
                print("Stopping processing at user request or error")
                break
        except KeyboardInterrupt:
            print("\nProcessing interrupted by user")
break
except Exception as e:
print(f"Error processing {file_path}: {e}")
import traceback
traceback.print_exc()
# Ask user if they want to continue
response = input("Continue with next file? (y/n): ")
if response.lower() != 'y':
break
print(f"\nFinished processing folder: {folder_path}")
return True
def main():
parser = argparse.ArgumentParser(description="DEIM Debug Script")
parser.add_argument('-c', '--config', type=str, required=True,
help='Path to config file')
parser.add_argument('-ckpt', '--checkpoint', type=str, required=True,
help='Path to model checkpoint')
parser.add_argument('-i', '--input', type=str, required=True,
help='Path to input image, video, or folder')
parser.add_argument('-d', '--device', type=str, default='cuda',
help='Device to use (cuda/cpu)')
parser.add_argument('--input-size', type=int, default=1600,
help='Input image size')
parser.add_argument('--conf-threshold', type=float, default=0.3,
help='Confidence threshold for detections')
parser.add_argument('--process-videos', action='store_true',
help='Process video files when scanning folders')
parser.add_argument('--images-only', action='store_true',
help='Process only images (skip videos)')
parser.add_argument('--videos-only', action='store_true',
help='Process only videos (skip images)')
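    # NOTE: --images-only and --videos-only are not enforced as mutually
    # exclusive by argparse; if both are passed, --videos-only takes
    # precedence in the folder-processing branch below.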
args = parser.parse_args()
# Check if files exist
if not os.path.exists(args.config):
print(f"Error: Config file not found: {args.config}")
return
if not os.path.exists(args.checkpoint):
print(f"Error: Checkpoint file not found: {args.checkpoint}")
return
    if not os.path.exists(args.input):
        print(f"Error: Input path not found: {args.input}")
        return
# Check device availability
if args.device == 'cuda' and not torch.cuda.is_available():
print("Warning: CUDA not available, using CPU")
args.device = 'cpu'
print("=== DEIM Debug Script ===")
print(f"Config: {args.config}")
print(f"Checkpoint: {args.checkpoint}")
print(f"Input: {args.input}")
print(f"Device: {args.device}")
print(f"Input size: {args.input_size}")
print(f"Confidence threshold: {args.conf_threshold}")
print("========================\n")
try:
# Initialize model
print("Loading model...")
model = DEIMModel(args.config, args.checkpoint, args.device, args.input_size)
# Initialize visualizer
visualizer = DebugVisualizer(model, args.conf_threshold)
# Determine input type and process
input_path = Path(args.input)
if input_path.is_file():
# Single file
if input_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp']:
if not args.videos_only:
success = process_image(model, args.input, visualizer, args.input_size)
else:
print("Skipping image file (videos-only mode)")
success = True
elif input_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.webm']:
if not args.images_only:
success = process_video(model, args.input, visualizer, args.input_size)
else:
print("Skipping video file (images-only mode)")
success = True
else:
print(f"Error: Unsupported file format: {input_path.suffix}")
success = False
        elif input_path.is_dir():
            # Folder - process recursively
            if args.videos_only:
                # Only process videos, skip images
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_videos=True, process_images=False)
            elif args.images_only:
                # Only process images, skip videos
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_videos=False)
            else:
                # Videos are included only when --process-videos is given
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_videos=args.process_videos)
else:
print(f"Error: Input path does not exist: {args.input}")
success = False
if success:
print("Processing completed successfully!")
except Exception as e:
print(f"Error during processing: {e}")
import traceback
traceback.print_exc()
finally:
# Cleanup
if 'visualizer' in locals():
visualizer.close()
print("Debug session ended.")
if __name__ == '__main__':
main()