Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import subprocess | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| import sys | |
| import importlib.util | |
# Folder holding the detector / ReID weight files offered in the UI; created
# up front so later lookups never fail on a missing directory.
MODELS_DIR = Path("models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Folder where finished videos are kept after the per-run temp dir is gone.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def ensure_dependencies(packages=None):
    """Make sure every required package is importable, pip-installing missing ones.

    Args:
        packages: Iterable of package names to check. Each name is used both as
            the import name and as the pip requirement, so the two must match.
            Defaults to ``("ultralytics", "boxmot")``.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits with an error.
    """
    if packages is None:
        packages = ("ultralytics", "boxmot")

    for package in packages:
        try:
            importlib.import_module(package)
        except ImportError:
            pass
        else:
            print(f"✅ {package} is installed")
            continue

        print(f"⚠️ {package} is not installed, attempting to install...")
        # Install into the interpreter running this app, not whatever "pip" is on PATH.
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)

        # The import system caches directory listings; refresh them so a package
        # installed while this interpreter is running can actually be found.
        importlib.invalidate_caches()
        try:
            importlib.import_module(package)
        except ImportError as exc:
            # Best effort: report, but keep starting up as before.
            print(f"⚠️ {package} was installed but still cannot be imported: {exc}")
        else:
            print(f"✅ {package} installed successfully")
# Apply tracker patches if tracker_patch.py exists
def apply_patches():
    """Load ``tracker_patch.py`` from the working directory and run its ``patch_trackers()``.

    A missing file, a file no loader can handle, or a module without
    ``patch_trackers`` is reported with a warning and skipped. Exceptions raised
    while executing the patch module itself are not caught.
    """
    patch_path = Path("tracker_patch.py")
    if not patch_path.exists():
        print("⚠️ tracker_patch.py not found, skipping patches")
        return

    spec = importlib.util.spec_from_file_location("tracker_patch", patch_path)
    # Either the spec or its loader may be None; exec_module needs a real loader.
    if spec is None or spec.loader is None:
        print("⚠️ tracker_patch.py could not be loaded, skipping patches")
        return

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if hasattr(module, "patch_trackers"):
        module.patch_trackers()
        print("✅ Applied tracker patches")
    else:
        print("⚠️ tracker_patch.py exists but has no patch_trackers function")
def _parse_class_ids(class_ids):
    """Parse a comma-separated string of class IDs into a list of ints.

    ``None``, ``""`` or whitespace-only input yields ``[]`` (track all classes).
    Blank entries, such as the trailing one in ``"0,2,"``, are skipped.

    Raises:
        ValueError: if a non-blank entry is not an integer.
    """
    if not class_ids or not class_ids.strip():
        return []
    return [int(part.strip()) for part in class_ids.split(",") if part.strip()]


def _find_output_videos(output_dir):
    """Return every .mp4/.avi/.mov file below ``output_dir``, sorted by path."""
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(output_dir)
        for name in files
        if name.lower().endswith((".mp4", ".avi", ".mov"))
    )


def _convert_to_mp4(path):
    """Re-encode ``path`` to H.264/AAC MP4 with ffmpeg and delete the original.

    Returns the new ``.mp4`` path, or ``path`` unchanged if ffmpeg is missing or
    fails (best effort: the caller keeps the original file in that case).
    """
    mp4_path = os.path.splitext(path)[0] + ".mp4"
    try:
        print(f"Converting to MP4 format: {mp4_path}")
        # -y: never stop at ffmpeg's interactive "overwrite?" prompt.
        subprocess.run(
            [
                "ffmpeg", "-y", "-i", path,
                "-c:v", "libx264", "-preset", "fast",
                "-c:a", "aac", mp4_path,
            ],
            check=True,
            capture_output=True,
        )
        os.remove(path)  # Remove the original file
    except (OSError, subprocess.SubprocessError) as e:
        print(f"Failed to convert to MP4: {str(e)}")
        return path
    return mp4_path


def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
    """Run ``tracking/track.py`` on an uploaded video and keep the rendered result.

    Args:
        video_file: Path of the uploaded video.
        yolo_model: Detector weights file name, resolved inside ``MODELS_DIR``.
        reid_model: ReID weights file name, resolved inside ``MODELS_DIR``.
        tracking_method: Tracker name forwarded as ``--tracking-method``.
        class_ids: Comma-separated class IDs (e.g. ``"0,2,5"``); empty tracks all.
        conf_threshold: Confidence threshold forwarded as ``--conf``.

    Returns:
        ``(path, status)``. ``path`` is the annotated video copied into
        ``OUTPUT_DIR`` on success, or ``None`` with an error message otherwise.
    """
    # Validate the class filter before copying a potentially large video.
    try:
        class_list = _parse_class_ids(class_ids)
    except ValueError:
        return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."

    try:
        # Create temporary workspace
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, "input_video.mp4")
            shutil.copy(video_file, input_path)

            output_dir = os.path.join(temp_dir, "output")
            os.makedirs(output_dir, exist_ok=True)

            # Use the interpreter running this app (the one ensure_dependencies
            # installs into) rather than whatever "python" resolves to on PATH.
            cmd = [
                sys.executable, "tracking/track.py",
                "--yolo-model", str(MODELS_DIR / yolo_model),
                "--reid-model", str(MODELS_DIR / reid_model),
                "--tracking-method", tracking_method,
                "--source", input_path,
                "--conf", str(conf_threshold),
                "--save",
                "--project", output_dir,
                "--name", "track",
                "--exist-ok",
            ]
            if class_list:
                # Each class ID is passed as a separate argument after --classes.
                cmd.append("--classes")
                cmd.extend(str(c) for c in class_list)
            # Special handling for OcSort
            if tracking_method == "ocsort":
                cmd.append("--per-class")

            print(f"Executing command: {' '.join(cmd)}")
            process = subprocess.run(cmd, capture_output=True, text=True)

            if process.returncode != 0:
                error_message = process.stderr or process.stdout
                print(f"Process failed with return code {process.returncode}")
                print(f"Error: {error_message}")
                return None, f"Error in tracking process: {error_message}"
            print(f"Process completed with return code {process.returncode}")

            output_files = _find_output_videos(output_dir)
            print(f"Found output files: {output_files}")
            if not output_files:
                print("No output video files found")
                return None, "No output video was generated. Check if tracking was successful."

            output_file = output_files[0]
            print(f"Selected output file: {output_file}")
            if not os.path.isfile(output_file):
                print(f"Output file not found at {output_file}")
                return None, "Output file was referenced but doesn't exist on disk."

            file_size = os.path.getsize(output_file)
            print(f"Output file exists with size: {file_size} bytes")
            if file_size == 0:
                return None, "Output video was generated but has zero size."

            # Reserve a collision-free name in OUTPUT_DIR: keep the upload's stem
            # for readability but take the extension of the file actually
            # produced, so the MP4 check below reflects the real container.
            stem = os.path.splitext(os.path.basename(video_file))[0]
            extension = os.path.splitext(output_file)[1]
            fd, permanent_path = tempfile.mkstemp(
                prefix=f"output_{stem}_", suffix=extension, dir=OUTPUT_DIR
            )
            os.close(fd)
            shutil.copy(output_file, permanent_path)
            print(f"Copied output to permanent location: {permanent_path}")

            # Ensure the file is in MP4 format for better compatibility with Gradio
            if not permanent_path.lower().endswith(".mp4"):
                permanent_path = _convert_to_mp4(permanent_path)
            return permanent_path, "Processing completed successfully!"
    except Exception as e:
        # Top-level UI boundary: report any failure as a status message.
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
# Click handler wired to the "Process Video" button
def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
    """Validate the upload, delegate to ``run_tracking``, and shape the result for Gradio.

    Returns ``(absolute_output_path, status)`` on success, or ``(None, status)``
    when nothing was uploaded or tracking produced no output.
    """
    # Guard clause: nothing uploaded yet.
    if not video_path:
        return None, "Please upload a video file"

    print(f"Processing video: {video_path}")
    print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")

    result_path, message = run_tracking(
        video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold
    )

    if not result_path:
        print(f"No output path available. Status: {message}")
        return None, message

    print(f"Returning output path: {result_path}")
    # Hand Gradio an absolute path so it does not depend on the working directory.
    return os.path.abspath(result_path), message
# Dropdown choices; the selected weight names are looked up under MODELS_DIR.
yolo_models = [f"yolov8{size}.pt" for size in ("n", "s", "m")]
reid_models = ["osnet_x0_25_msmt17.pt"]
tracking_methods = "bytetrack botsort ocsort strongsort".split()
# Ensure dependencies and apply patches at startup
ensure_dependencies()
apply_patches()

# Text for the "YOLO Class Reference" accordion. Kept flush-left so Markdown does
# not render indented lines as a code block; the blank lines keep the closing
# link from being swallowed into the last list item.
_CLASS_REFERENCE_MD = """
# YOLO Class IDs Reference

Enter the class IDs as comma-separated numbers in the "Target Classes" field.
Leave empty to track all classes.

## Common Class IDs:

- 0: person
- 1: bicycle
- 2: car
- 3: motorcycle
- 5: bus
- 7: truck
- 16: dog
- 17: horse
- 67: cell phone

[See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco128.yaml)
"""


def check_environment():
    """Return a plain-text environment report for the "Debug Log" box."""
    info = [
        f"Python version: {sys.version}",
        f"Gradio version: {gr.__version__}",
    ]

    # ffmpeg is only needed for the optional MP4 re-encode of tracking output.
    try:
        probe = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError):
        info.append("ffmpeg: Not found")
    else:
        if probe.returncode == 0:
            info.append("ffmpeg: Installed")
        else:
            info.append(f"ffmpeg: Found but exited with code {probe.returncode}")

    if os.path.exists("tracking"):
        info.append("tracking directory: Found")
    else:
        info.append("tracking directory: Not found")

    info.append("Models:")
    if os.path.exists(MODELS_DIR):
        info.extend(f" - {model}" for model in sorted(os.listdir(MODELS_DIR)))
    return "\n".join(info)


# Create the Gradio interface
with gr.Blocks(title="🔍 Object Tracking") as app:
    gr.Markdown("# 🔍 Object Tracking")
    gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")

    # Add class reference information
    with gr.Accordion("YOLO Class Reference", open=False):
        gr.Markdown(_CLASS_REFERENCE_MD)

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Input Video", sources=["upload"])
            with gr.Group():
                yolo_model = gr.Dropdown(
                    choices=yolo_models,
                    value="yolov8n.pt",
                    label="YOLO Model",
                )
                reid_model = gr.Dropdown(
                    choices=reid_models,
                    value="osnet_x0_25_msmt17.pt",
                    label="ReID Model",
                )
                tracking_method = gr.Dropdown(
                    choices=tracking_methods,
                    value="bytetrack",
                    label="Tracking Method",
                )
                # Class ID input field
                class_ids = gr.Textbox(
                    value="",
                    label="Target Classes (comma-separated IDs, e.g. '0,2,5', leave empty for all classes)",
                    placeholder="e.g. 0,2,5",
                )
                conf_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Confidence Threshold",
                )
            process_btn = gr.Button("Process Video", variant="primary")
        with gr.Column(scale=1):
            output_video = gr.Video(label="Output Video with Tracking")
            status_text = gr.Textbox(label="Status", value="Ready to process video")

    process_btn.click(
        fn=process_video,
        inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
        outputs=[output_video, status_text],
    )

    # Add a debug section
    with gr.Accordion("Debug Information", open=False):
        debug_text = gr.Textbox(label="Debug Log", lines=10, interactive=False)
        check_btn = gr.Button("Check Environment")
        check_btn.click(fn=check_environment, outputs=debug_text)
# Start the server only when executed as a script, never on plain import.
if __name__ == "__main__":
    app.launch(share=True, debug=True)