#!/usr/bin/env python3
"""
DEIM Debug Script for Interactive Bbox Detection and Visualization
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.

This script provides interactive debugging capabilities for DEIM models:
- Load a model from a config file and checkpoint
- Process images and videos
- Interactive OpenCV visualization with imshow
- Adjustable confidence thresholds
- Keyboard controls for video playback
"""
import argparse
import os
import sys
import time
from pathlib import Path

import cv2
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

# Add the project root to the Python path
sys.path.insert(0, str(Path(__file__).parent))
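# Note: this assumes the script lives in the DEIM repository root, next to the
# `engine/` package imported below.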
from engine.core import YAMLConfig

# Default class names (COCO) - overridden by the dataset configuration when available
DEFAULT_CLASSES = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorbike', 5: 'aeroplane',
    6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'trafficlight',
    11: 'firehydrant', 13: 'stopsign', 14: 'parkingmeter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep',
    21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe',
    27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie',
    33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
    37: 'sportsball', 38: 'kite', 39: 'baseballbat', 40: 'baseballglove',
    41: 'skateboard', 42: 'surfboard', 43: 'tennisracket', 44: 'bottle',
    46: 'wineglass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
    51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
    56: 'broccoli', 57: 'carrot', 58: 'hotdog', 59: 'pizza', 60: 'donut',
    61: 'cake', 62: 'chair', 63: 'sofa', 64: 'pottedplant', 65: 'bed',
    67: 'diningtable', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
    75: 'remote', 76: 'keyboard', 77: 'cellphone', 78: 'microwave',
    79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book',
    85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddybear',
    89: 'hairdrier', 90: 'toothbrush'
}
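# The gaps in the keys above (12, 26, 29, ...) mirror COCO's non-contiguous
# category IDs, so lookups by raw COCO category ID work without remapping.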

def load_class_names_from_config(cfg):
    """Load class names from the dataset configuration."""
    try:
        # Import here to avoid circular imports
        from engine.data.dataset.coco_dataset import mscoco_category2name

        # Check whether we can access the dataset configuration
        if hasattr(cfg, 'val_dataloader') and cfg.val_dataloader is not None:
            dataset_cfg = cfg.val_dataloader.dataset
            try:
                # Get the number of classes from the config
                num_classes = getattr(cfg, 'num_classes', 80)

                # Check whether COCO category remapping is enabled
                if getattr(cfg, 'remap_mscoco_category', False):
                    print("Using COCO class names (remapped)")
                    return mscoco_category2name

                # For COCO-style datasets, try to read category names from the annotation file
                if hasattr(dataset_cfg, 'ann_file') and dataset_cfg.ann_file:
                    try:
                        from pycocotools.coco import COCO
                        if os.path.exists(dataset_cfg.ann_file):
                            coco = COCO(dataset_cfg.ann_file)
                            categories = coco.dataset.get('categories', [])
                            if categories:
                                # Key by category ID for proper mapping
                                class_names = {cat['id']: cat['name'] for cat in categories}
                                print(f"Loaded {len(class_names)} class names from annotation file")
                                return class_names
                    except Exception as e:
                        print(f"Could not load classes from annotation file: {e}")

                # Fall back to generic class names based on the number of classes
                print(f"Generating generic class names for {num_classes} classes")
                if num_classes == 80:
                    return mscoco_category2name
                elif num_classes == 1:
                    return {1: 'object'}
                elif num_classes == 2:
                    return {1: 'person', 2: 'object'}  # Common for crowd detection
                elif num_classes == 20:
                    # Pascal VOC classes
                    return {
                        1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle',
                        6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow',
                        11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person',
                        16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'
                    }
                else:
                    # Generic class names
                    return {i + 1: f'class_{i + 1}' for i in range(num_classes)}
            except Exception as e:
                print(f"Could not instantiate dataset: {e}")
    except Exception as e:
        print(f"Could not load class names from config: {e}")

    # Fallback: default COCO classes
    print("Using default COCO class names")
    return DEFAULT_CLASSES

# Color palette for bounding boxes (BGR format for OpenCV)
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
    (0, 255, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0),
    (128, 0, 128), (0, 128, 128), (255, 128, 0), (255, 0, 128), (128, 255, 0),
    (0, 255, 128), (128, 0, 255), (0, 128, 255), (192, 192, 192), (64, 64, 64)
]
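# Each detection is colored by class via COLORS[label % len(COLORS)]
# (see DebugVisualizer.draw_detections below).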

class DEIMModel(nn.Module):
    """Wrapper for a DEIM model with postprocessing."""

    def __init__(self, config_path, checkpoint_path, device='cuda', input_size=640):
        super().__init__()
        self.device = device

        config_overrides = {'HGNetv2': {'pretrained': False}}
        self.cfg = YAMLConfig(config_path, resume=checkpoint_path, **config_overrides)

        print(f"Loading checkpoint from: {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
        self.cfg.model.load_state_dict(state_dict)

        self.model = self.cfg.model.eval().to(device)
        self.postprocessor = self.cfg.postprocessor.eval().to(device)
        self.class_names = load_class_names_from_config(self.cfg)
        self.num_classes = getattr(self.cfg, 'num_classes', len(self.class_names))

        print(f"Model loaded successfully on {device}")
        print(f"Model type: {type(self.model).__name__}")
        print(f"Number of classes: {self.num_classes}")
        print(f"Sample classes: {dict(list(self.class_names.items())[:5])}...")

    def forward(self, images, orig_sizes):
        """Forward pass through the model and postprocessor."""
        with torch.no_grad():
            outputs = self.model(images)
            results = self.postprocessor(outputs, orig_sizes)
        return results

    def get_class_name(self, class_id):
        """Get the class name for a given class ID."""
        return self.class_names.get(class_id, f'class_{class_id}')

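# Minimal programmatic use (a sketch; paths are placeholders). The postprocessor
# returns, per image, a dict with 'labels', 'boxes' (xyxy, rescaled to the
# original image via orig_sizes), and 'scores' - the keys draw_detections reads:
#   model = DEIMModel('configs/deim.yml', 'best.pth', device='cpu')
#   images = torch.rand(1, 3, 640, 640)                            # RGB in [0, 1]
#   orig_sizes = torch.tensor([[1280, 720]], dtype=torch.float32)  # (w, h)
#   results = model(images, orig_sizes)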

class DebugVisualizer:
    """Interactive visualizer built on OpenCV."""

    def __init__(self, model, confidence_threshold=0.5, window_name="DEIM Debug"):
        self.model = model
        self.confidence_threshold = confidence_threshold
        self.window_name = window_name
        self.paused = False
        self.show_info = True

        # Create the display window
        cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.window_name, 1200, 800)

        print("\n=== Debug Controls ===")
        print("SPACE: Pause/Resume video")
        print("'q' or ESC: Quit")
        print("'i': Toggle info display")
        print("'+'/'-': Increase/Decrease confidence threshold")
        print("'s': Save current frame")
        print("'n': Next file (in folder mode)")
        print("=====================\n")
    def draw_detections(self, image, results, frame_info=None):
        """Draw bounding boxes and labels on an image."""
        vis_image = image.copy()
        if len(results) == 0:
            return vis_image

        # Extract the results for the (single) image
        result = results[0] if isinstance(results, list) else results
        labels = result['labels'].cpu().numpy()
        boxes = result['boxes'].cpu().numpy()
        scores = result['scores'].cpu().numpy()

        # Filter by confidence threshold
        valid_indices = scores >= self.confidence_threshold
        labels = labels[valid_indices]
        boxes = boxes[valid_indices]
        scores = scores[valid_indices]

        # Draw bounding boxes
        for box, label, score in zip(boxes, labels, scores):
            x1, y1, x2, y2 = box.astype(int)

            # Get the class name and color
            class_name = self.model.get_class_name(label)
            color = COLORS[label % len(COLORS)]

            # Draw the bounding box
            cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)

            # Prepare the label text and measure it
            label_text = f'{class_name}: {score:.2f}'
            (text_w, text_h), baseline = cv2.getTextSize(
                label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

            # Draw the label background, then the text
            cv2.rectangle(vis_image, (x1, y1 - text_h - baseline),
                          (x1 + text_w, y1), color, -1)
            cv2.putText(vis_image, label_text, (x1, y1 - baseline),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Draw the info overlay
        if self.show_info:
            self._draw_info_overlay(vis_image, labels, scores, frame_info)
        return vis_image
    def _draw_info_overlay(self, image, labels, scores, frame_info=None):
        """Draw an information overlay on the image."""
        overlay_y = 30

        # Detection count and confidence info
        info_lines = [
            f"Detections: {len(labels)} (conf >= {self.confidence_threshold:.2f})",
            f"Avg Confidence: {scores.mean():.3f}" if len(scores) > 0 else "Avg Confidence: N/A"
        ]

        # Add frame info for videos; guard the FPS format, since a missing
        # 'fps' entry cannot be formatted with :.1f
        if frame_info:
            fps = frame_info.get('fps')
            info_lines.extend([
                f"Frame: {frame_info.get('frame_num', 'N/A')}",
                f"FPS: {fps:.1f}" if isinstance(fps, (int, float)) else "FPS: N/A",
                f"Status: {'PAUSED' if self.paused else 'PLAYING'}"
            ])
            # Add file progress if available
            if 'file_progress' in frame_info:
                info_lines.append(f"File: {frame_info['file_progress']}")

        # Draw the background box, then the text
        overlay_height = len(info_lines) * 25 + 20
        cv2.rectangle(image, (10, 10), (350, 10 + overlay_height), (0, 0, 0), -1)
        cv2.rectangle(image, (10, 10), (350, 10 + overlay_height), (255, 255, 255), 1)
        for i, line in enumerate(info_lines):
            cv2.putText(image, line, (20, overlay_y + i * 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
    def show_image(self, image, results, title=None):
        """Display a single image with detections; returns False if the user quit."""
        vis_image = self.draw_detections(image, results)
        if title:
            cv2.setWindowTitle(self.window_name, f"{self.window_name} - {title}")
        cv2.imshow(self.window_name, vis_image)

        # Wait for key presses until the user quits or advances
        while True:
            key = cv2.waitKey(0) & 0xFF
            if key == ord('q') or key == 27:  # 'q' or ESC
                return False
            elif key == ord('n'):  # Next file
                return True
            elif key == ord('s'):  # Save image
                save_path = f"debug_output_{int(time.time())}.jpg"
                cv2.imwrite(save_path, vis_image)
                print(f"Image saved as {save_path}")
            elif key == ord('i'):  # Toggle info
                self.show_info = not self.show_info
                vis_image = self.draw_detections(image, results)
                cv2.imshow(self.window_name, vis_image)
            elif key in (ord('+'), ord('=')):  # Increase threshold
                self.confidence_threshold = min(1.0, self.confidence_threshold + 0.05)
                print(f"Confidence threshold: {self.confidence_threshold:.2f}")
                vis_image = self.draw_detections(image, results)
                cv2.imshow(self.window_name, vis_image)
            elif key in (ord('-'), ord('_')):  # Decrease threshold
                self.confidence_threshold = max(0.0, self.confidence_threshold - 0.05)
                print(f"Confidence threshold: {self.confidence_threshold:.2f}")
                vis_image = self.draw_detections(image, results)
                cv2.imshow(self.window_name, vis_image)
            else:
                break
        return True
    def show_video_frame(self, image, results, frame_info):
        """Display a video frame; returns False to quit, 'next' to skip the file."""
        vis_image = self.draw_detections(image, results, frame_info)
        cv2.setWindowTitle(self.window_name,
                           f"{self.window_name} - Frame {frame_info.get('frame_num', 'N/A')}")
        cv2.imshow(self.window_name, vis_image)

        # Handle keyboard input, waiting roughly one frame interval while playing
        # and polling quickly while paused
        wait_time = 1 if self.paused else max(1, int(1000 / frame_info.get('fps', 30)))
        key = cv2.waitKey(wait_time) & 0xFF
        if key == ord('q') or key == 27:  # Quit
            return False
        elif key == ord('n'):  # Next file (skip the rest of this video)
            return 'next'
        elif key == ord(' '):  # Pause/Resume
            self.paused = not self.paused
            print("PAUSED" if self.paused else "RESUMED")
        elif key == ord('s'):  # Save frame
            save_path = f"debug_frame_{frame_info.get('frame_num', int(time.time()))}.jpg"
            cv2.imwrite(save_path, vis_image)
            print(f"Frame saved as {save_path}")
        elif key == ord('i'):  # Toggle info
            self.show_info = not self.show_info
        elif key in (ord('+'), ord('=')):  # Increase threshold
            self.confidence_threshold = min(1.0, self.confidence_threshold + 0.05)
            print(f"Confidence threshold: {self.confidence_threshold:.2f}")
        elif key in (ord('-'), ord('_')):  # Decrease threshold
            self.confidence_threshold = max(0.0, self.confidence_threshold - 0.05)
            print(f"Confidence threshold: {self.confidence_threshold:.2f}")
        return True
    def close(self):
        """Close all visualization windows."""
        cv2.destroyAllWindows()

def find_media_files(folder_path):
    """Recursively find all image and video files under a path."""
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp', '.gif'}
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.webm'}

    media_files = []
    folder_path = Path(folder_path)
    if folder_path.is_file():
        # A single file was provided
        if folder_path.suffix.lower() in image_extensions | video_extensions:
            media_files.append(folder_path)
    else:
        # Recursively collect all media files
        for file_path in folder_path.rglob('*'):
            if file_path.is_file() and file_path.suffix.lower() in image_extensions | video_extensions:
                media_files.append(file_path)

    # Sort for consistent ordering, then split images from videos
    media_files.sort()
    images = [f for f in media_files if f.suffix.lower() in image_extensions]
    videos = [f for f in media_files if f.suffix.lower() in video_extensions]
    print(f"Found {len(images)} images and {len(videos)} videos")
    return images, videos
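# Note: cv2.imread cannot decode .gif files, so GIFs collected here will be
# reported as load errors by process_image below rather than displayed.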

def process_image(model, image_path, visualizer, input_size=640, file_index=None, total_files=None):
    """Process a single image."""
    progress_str = f"[{file_index + 1}/{total_files}] " if file_index is not None else ""
    print(f"{progress_str}Processing image: {image_path}")

    # Load the image
    image = cv2.imread(str(image_path))
    if image is None:
        print(f"Error: Could not load image {image_path}")
        return False

    h, w = image.shape[:2]
    orig_size = torch.tensor([[w, h]], dtype=torch.float32).to(model.device)

    # Convert to PIL for the torchvision transforms
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    transforms = T.Compose([
        T.Resize((input_size, input_size)),
        T.ToTensor(),
    ])
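    # Note: Resize((input_size, input_size)) squashes the image to a square
    # without letterboxing; the postprocessor is given orig_size so it can map
    # boxes back to original-image pixels, which is why the detections drawn
    # below line up with the unresized frame.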
    tensor_image = transforms(pil_image).unsqueeze(0).to(model.device)

    # Run inference
    start_time = time.time()
    results = model(tensor_image, orig_size)
    inference_time = time.time() - start_time
    print(f"Inference time: {inference_time:.3f}s")

    # Show the results
    title = f"{progress_str}{Path(image_path).name} ({inference_time:.3f}s)"
    return visualizer.show_image(image, results, title)

def process_video(model, video_path, visualizer, input_size=640, file_index=None, total_files=None):
    """Process a video file."""
    progress_str = f"[{file_index + 1}/{total_files}] " if file_index is not None else ""
    print(f"{progress_str}Processing video: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return False

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Video FPS: {fps:.2f}")
    print(f"Total frames: {total_frames}")

    transforms = T.Compose([
        T.Resize((input_size, input_size)),
        T.ToTensor(),
    ])

    frame_num = 0
    start_time = time.time()
    frame, results, frame_info = None, None, None
    try:
        while cap.isOpened():
            # While paused, keep redisplaying the current frame so the keyboard
            # controls stay responsive, instead of reading new frames
            if not (visualizer.paused and frame is not None):
                ret, frame = cap.read()
                if not ret:
                    break
                frame_num += 1

                h, w = frame.shape[:2]
                orig_size = torch.tensor([[w, h]], dtype=torch.float32).to(model.device)

                # Convert to PIL for the torchvision transforms
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                tensor_frame = transforms(pil_frame).unsqueeze(0).to(model.device)

                # Run inference
                frame_start = time.time()
                results = model(tensor_frame, orig_size)
                inference_time = time.time() - frame_start

                # Running-average FPS over the whole video so far
                elapsed_time = time.time() - start_time
                avg_fps = frame_num / elapsed_time if elapsed_time > 0 else 0

                frame_info = {
                    'frame_num': frame_num,
                    'fps': avg_fps,
                    'inference_time': inference_time,
                    'total_frames': total_frames,
                    'file_progress': f"{progress_str}{Path(video_path).name}"
                }

                # Print progress periodically
                if frame_num % 30 == 0:
                    print(f"Processed {frame_num}/{total_frames} frames, "
                          f"Avg FPS: {avg_fps:.1f}, "
                          f"Inference: {inference_time:.3f}s")

            # Show the frame and handle keyboard input
            result = visualizer.show_video_frame(frame, results, frame_info)
            if result is False:
                break
            elif result == 'next':
                print("Skipping to next file...")
                break
    finally:
        cap.release()

    elapsed = time.time() - start_time
    print("\nVideo processing completed!")
    print(f"Total frames processed: {frame_num}")
    print(f"Average FPS: {frame_num / elapsed:.2f}")
    return True

def process_folder(model, folder_path, visualizer, input_size=640, process_images=True, process_videos=True):
    """Process all images and videos in a folder recursively."""
    print(f"Scanning folder: {folder_path}")

    # Find all media files
    images, videos = find_media_files(folder_path)
    if not images and not videos:
        print("No image or video files found!")
        return False

    all_files = []

    # Add images first, if requested
    if images and process_images:
        print(f"\nFound {len(images)} images:")
        for img in images[:10]:  # Show the first 10
            print(f"  {img}")
        if len(images) > 10:
            print(f"  ... and {len(images) - 10} more")
        all_files.extend([(img, 'image') for img in images])
    elif images and not process_images:
        print(f"\nSkipping {len(images)} images (videos-only mode)")

    # Add videos, if requested
    if videos and process_videos:
        print(f"\nFound {len(videos)} videos:")
        for vid in videos[:10]:  # Show the first 10
            print(f"  {vid}")
        if len(videos) > 10:
            print(f"  ... and {len(videos) - 10} more")
        all_files.extend([(vid, 'video') for vid in videos])
    elif videos and not process_videos:
        print(f"\nSkipping {len(videos)} videos (use --process-videos to include them)")

    if not all_files:
        print("No files to process!")
        return False

    print(f"\nProcessing {len(all_files)} files total...")
    print("Use SPACE to pause/resume, 'q' to quit, 'n' for next file")

    # Process all files
    for i, (file_path, file_type) in enumerate(all_files):
        print(f"\n{'=' * 60}")
        try:
            if file_type == 'image':
                success = process_image(model, file_path, visualizer, input_size, i, len(all_files))
            else:  # video
                success = process_video(model, file_path, visualizer, input_size, i, len(all_files))
            if not success:
                print("Stopping processing at user request or after an error")
                break
        except KeyboardInterrupt:
            print("\nProcessing interrupted by user")
            break
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            import traceback
            traceback.print_exc()
            # Ask the user whether to continue
            response = input("Continue with next file? (y/n): ")
            if response.lower() != 'y':
                break

    print(f"\nFinished processing folder: {folder_path}")
    return True

def main():
    parser = argparse.ArgumentParser(description="DEIM Debug Script")
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='Path to config file')
    parser.add_argument('-ckpt', '--checkpoint', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='Path to input image, video, or folder')
    parser.add_argument('-d', '--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--input-size', type=int, default=1600,
                        help='Input image size')
    parser.add_argument('--conf-threshold', type=float, default=0.3,
                        help='Confidence threshold for detections')
    parser.add_argument('--process-videos', action='store_true',
                        help='Process video files when scanning folders')
    parser.add_argument('--images-only', action='store_true',
                        help='Process only images (skip videos)')
    parser.add_argument('--videos-only', action='store_true',
                        help='Process only videos (skip images)')
    args = parser.parse_args()

    # Check that the input paths exist
    if not os.path.exists(args.config):
        print(f"Error: Config file not found: {args.config}")
        return
    if not os.path.exists(args.checkpoint):
        print(f"Error: Checkpoint file not found: {args.checkpoint}")
        return
    if not os.path.exists(args.input):
        print(f"Error: Input path not found: {args.input}")
        return

    # Check device availability
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA not available, using CPU")
        args.device = 'cpu'

    print("=== DEIM Debug Script ===")
    print(f"Config: {args.config}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Input: {args.input}")
    print(f"Device: {args.device}")
    print(f"Input size: {args.input_size}")
    print(f"Confidence threshold: {args.conf_threshold}")
    print("========================\n")

    try:
        # Initialize the model
        print("Loading model...")
        model = DEIMModel(args.config, args.checkpoint, args.device, args.input_size)

        # Initialize the visualizer
        visualizer = DebugVisualizer(model, args.conf_threshold)

        # Determine the input type and process accordingly
        input_path = Path(args.input)
        success = False
        if input_path.is_file():
            # Single file
            if input_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp']:
                if not args.videos_only:
                    success = process_image(model, args.input, visualizer, args.input_size)
                else:
                    print("Skipping image file (videos-only mode)")
                    success = True
            elif input_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.webm']:
                if not args.images_only:
                    success = process_video(model, args.input, visualizer, args.input_size)
                else:
                    print("Skipping video file (images-only mode)")
                    success = True
            else:
                print(f"Error: Unsupported file format: {input_path.suffix}")
        elif input_path.is_dir():
            # Folder mode: process recursively. Videos are included only with
            # --process-videos or --videos-only; --videos-only also skips images.
            if args.videos_only:
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_images=False, process_videos=True)
            elif args.images_only:
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_images=True, process_videos=False)
            else:
                success = process_folder(model, args.input, visualizer, args.input_size,
                                         process_images=True, process_videos=args.process_videos)
        else:
            print(f"Error: Input path does not exist: {args.input}")

        if success:
            print("Processing completed successfully!")
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Cleanup
        if 'visualizer' in locals():
            visualizer.close()
        print("Debug session ended.")


if __name__ == '__main__':
    main()