usiddiquee786's picture
Update app.py
86bb37d verified
import os
import gradio as gr
import subprocess
import tempfile
import shutil
from pathlib import Path
import sys
import importlib.util
# Directory holding the YOLO / ReID weight files that run_tracking passes to
# tracking/track.py (as MODELS_DIR / <file name>).
MODELS_DIR = Path("models")
# Persistent home for finished videos: each run works inside a
# TemporaryDirectory, so results are copied here before it is deleted.
OUTPUT_DIR = Path("outputs")

# Create both directories up front (no-op if they already exist).
for _required_dir in (MODELS_DIR, OUTPUT_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)
def ensure_dependencies(required_packages=("ultralytics", "boxmot")):
    """Import each required package, pip-installing any that are missing.

    Args:
        required_packages: Importable module names. Each name is also used as
            the pip distribution name, so this only works for packages whose
            import name and pip name coincide (true for the defaults).

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails for a package.
    """
    for package in required_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f"⚠️ {package} is not installed, attempting to install...")
        else:
            print(f"✅ {package} is installed")
            continue

        # Install into the interpreter that is running this app.
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
        # The import system caches directory listings; without this, a package
        # installed while the process is running may still not be found by a
        # later import.
        importlib.invalidate_caches()
# Apply tracker patches if tracker_patch.py exists
def apply_patches(patch_file="tracker_patch.py"):
    """Load an optional patch script and call its ``patch_trackers()`` hook.

    Args:
        patch_file: Path to the patch script. The default is resolved
            relative to the current working directory.

    A missing file, a file that cannot be loaded as a Python module, or a
    module without a callable ``patch_trackers`` only prints a warning.
    Exceptions raised while executing the patch module or its hook propagate
    to the caller.
    """
    patch_path = Path(patch_file)
    if not patch_path.exists():
        print(f"⚠️ {patch_path} not found, skipping patches")
        return

    spec = importlib.util.spec_from_file_location("tracker_patch", patch_path)
    # spec is None when no loader recognises the file (e.g. wrong suffix);
    # spec.loader can also be None, and exec_module would then fail.
    if spec is None or spec.loader is None:
        print(f"⚠️ Could not load {patch_path} as a Python module, skipping patches")
        return

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    patch_trackers = getattr(module, "patch_trackers", None)
    if callable(patch_trackers):
        patch_trackers()
        print("✅ Applied tracker patches")
    else:
        print(f"⚠️ {patch_path} exists but has no patch_trackers function")
def _parse_class_ids(class_ids):
    """Parse a comma-separated class-ID string into a list of ints.

    Blank entries are skipped, so ``"0, 2,,5"`` gives ``[0, 2, 5]``; ``None``
    or a blank string gives ``[]`` (meaning "track all classes").

    Raises:
        ValueError: if a non-blank entry is not an integer literal.
    """
    if not class_ids or not class_ids.strip():
        return []
    return [int(part.strip()) for part in class_ids.split(",") if part.strip()]


def _build_track_command(input_path, output_dir, yolo_model, reid_model,
                         tracking_method, class_list, conf_threshold):
    """Return the argv list that runs ``tracking/track.py`` on *input_path*.

    Output is saved with ``--project output_dir --name track``; run_tracking
    later searches *output_dir* for the rendered video.
    """
    cmd = [
        # Use the interpreter running this app (the same one ensure_dependencies
        # installs into) rather than whatever "python" is first on PATH.
        sys.executable, "tracking/track.py",
        "--yolo-model", str(MODELS_DIR / yolo_model),
        "--reid-model", str(MODELS_DIR / reid_model),
        "--tracking-method", tracking_method,
        "--source", input_path,
        "--conf", str(conf_threshold),
        "--save",
        "--project", output_dir,
        "--name", "track",
        "--exist-ok",
    ]
    if class_list:
        # Each class ID is passed as a separate argument after --classes.
        cmd += ["--classes", *(str(c) for c in class_list)]
    # Special handling for OcSort
    if tracking_method == "ocsort":
        cmd.append("--per-class")
    return cmd


def _find_output_video(output_dir):
    """Return the first .mp4/.avi/.mov file under *output_dir* (sorted by path), or None."""
    found = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(output_dir)
        for name in files
        if name.lower().endswith((".mp4", ".avi", ".mov"))
    )
    print(f"Found output files: {found}")
    return found[0] if found else None


def _publish_output(output_file, video_file):
    """Copy *output_file* into OUTPUT_DIR under a unique name; return the new path.

    The extension is taken from the tracker's actual output, not from the
    upload. Non-MP4 output (e.g. ``.avi``) is re-encoded with ffmpeg to
    H.264/yuv420p MP4 so browsers can play it in Gradio; if ffmpeg is missing
    or fails, the original container is copied instead.
    """
    import uuid

    # A random suffix keeps repeated or concurrent runs on identically named
    # uploads from overwriting each other's results.
    base = os.path.join(OUTPUT_DIR, f"output_{Path(video_file).stem}_{uuid.uuid4().hex[:8]}")
    suffix = Path(output_file).suffix.lower()

    if suffix != ".mp4":
        mp4_path = base + ".mp4"
        try:
            print(f"Converting to MP4 format: {mp4_path}")
            subprocess.run(
                [
                    "ffmpeg", "-y", "-i", output_file,
                    # yuv420p needs even frame dimensions; round down if odd.
                    "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
                    "-c:v", "libx264", "-preset", "fast", "-pix_fmt", "yuv420p",
                    "-c:a", "aac",
                    "-movflags", "+faststart",
                    mp4_path,
                ],
                check=True,
                capture_output=True,
            )
            return mp4_path
        except (OSError, subprocess.CalledProcessError) as e:
            print(f"Failed to convert to MP4: {str(e)}")
            # Drop any partial file, then fall back to the original container.
            if os.path.exists(mp4_path):
                os.remove(mp4_path)

    permanent_path = base + suffix
    shutil.copy(output_file, permanent_path)
    print(f"Copied output to permanent location: {permanent_path}")
    return permanent_path


def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
    """Run object tracking on the uploaded video.

    Copies *video_file* into a temporary workspace, runs ``tracking/track.py``
    in a subprocess, and publishes the rendered video to OUTPUT_DIR.

    Args:
        video_file: Path of the uploaded video.
        yolo_model: Detector weight file name, resolved under MODELS_DIR.
        reid_model: ReID weight file name, resolved under MODELS_DIR.
        tracking_method: Tracker name passed to ``--tracking-method``.
        class_ids: Comma-separated class IDs to keep; blank or None keeps all.
        conf_threshold: Confidence threshold passed to ``--conf``.

    Returns:
        ``(output_path, status_message)``. ``output_path`` is None on failure.
        Errors are reported through the status message rather than raised.
    """
    try:
        # Validate the class filter before copying a possibly large video.
        try:
            class_list = _parse_class_ids(class_ids)
        except ValueError:
            return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."

        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, "input_video.mp4")
            shutil.copy(video_file, input_path)
            output_dir = os.path.join(temp_dir, "output")
            os.makedirs(output_dir, exist_ok=True)

            cmd = _build_track_command(
                input_path, output_dir, yolo_model, reid_model,
                tracking_method, class_list, conf_threshold,
            )
            print(f"Executing command: {' '.join(cmd)}")
            process = subprocess.run(cmd, capture_output=True, text=True)
            if process.returncode != 0:
                error_message = process.stderr or process.stdout
                print(f"Process failed with return code {process.returncode}")
                print(f"Error: {error_message}")
                return None, f"Error in tracking process: {error_message}"
            print(f"Process completed with return code {process.returncode}")

            output_file = _find_output_video(output_dir)
            if output_file is None:
                print("No output video files found")
                return None, "No output video was generated. Check if tracking was successful."
            print(f"Selected output file: {output_file}")

            if not os.path.exists(output_file):
                print(f"Output file not found at {output_file}")
                return None, "Output file was referenced but doesn't exist on disk."
            file_size = os.path.getsize(output_file)
            print(f"Output file exists with size: {file_size} bytes")
            if file_size == 0:
                return None, "Output video was generated but has zero size."

            # Publish while the temporary workspace still exists.
            permanent_path = _publish_output(output_file, video_file)
            return permanent_path, "Processing completed successfully!"
    except Exception as e:
        # Top-level boundary for the UI: log the traceback, surface a message.
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
# Click handler for the "Process Video" button: validates the upload and
# delegates the actual work to run_tracking.
def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
    """Validate UI inputs, run tracking, and return ``(video_path_or_None, status)``.

    On success the returned path is made absolute before it goes back to Gradio.
    """
    # Guard clause: nothing uploaded yet.
    if not video_path:
        return None, "Please upload a video file"

    print(f"Processing video: {video_path}")
    print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")

    result_path, message = run_tracking(
        video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold
    )

    if not result_path:
        print(f"No output path available. Status: {message}")
        return None, message

    print(f"Returning output path: {result_path}")
    # Hand Gradio an absolute path rather than one relative to the cwd.
    return os.path.abspath(result_path), message
# Choices offered by the UI dropdowns. Model entries are weight-file names
# that run_tracking resolves under MODELS_DIR; tracker entries are passed
# straight to --tracking-method.
yolo_models = [
    "yolov8n.pt",
    "yolov8s.pt",
    "yolov8m.pt",
]
reid_models = [
    "osnet_x0_25_msmt17.pt",
]
tracking_methods = [
    "bytetrack",
    "botsort",
    "ocsort",
    "strongsort",
]
# Startup side effects, run once at import time and in this order: make sure
# the tracking packages are importable, then apply the optional tracker patch
# (which presumably relies on those packages — TODO confirm).
for _startup_step in (ensure_dependencies, apply_patches):
    _startup_step()
# Create the Gradio interface

# Markdown for the "YOLO Class Reference" accordion. The text is flush-left so
# Markdown does not render indented lines as a code block. The blank lines keep
# the trailing link from being merged into the last list item.
CLASS_REFERENCE_MD = """\
# YOLO Class IDs Reference

Enter the class IDs as comma-separated numbers in the "Target Classes" field.
Leave empty to track all classes.

## Common Class IDs:

- 0: person
- 1: bicycle
- 2: car
- 3: motorcycle
- 5: bus
- 7: truck
- 16: dog
- 17: horse
- 67: cell phone

[See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml)
"""


def check_environment():
    """Return a plain-text report used by the "Check Environment" debug button.

    The report covers:
    - the Python and Gradio versions;
    - whether ``ffmpeg`` is on PATH (used to convert non-MP4 tracking output);
    - whether the ``tracking`` directory and its ``track.py`` script exist;
    - the sorted contents of MODELS_DIR.
    """
    info = [
        f"Python version: {sys.version}",
        f"Gradio version: {gr.__version__}",
        # shutil.which only inspects PATH: no process spawn, no exception to swallow.
        "ffmpeg: Installed" if shutil.which("ffmpeg") else "ffmpeg: Not found",
        "tracking directory: Found" if os.path.exists("tracking") else "tracking directory: Not found",
        # run_tracking invokes this script directly, so report it explicitly.
        "tracking/track.py: Found" if os.path.isfile(os.path.join("tracking", "track.py")) else "tracking/track.py: Not found",
        "Models:",
    ]
    if os.path.isdir(MODELS_DIR):
        info.extend(f" - {model}" for model in sorted(os.listdir(MODELS_DIR)))
    return "\n".join(info)


with gr.Blocks(title="🚀 Object Tracking") as app:
    gr.Markdown("# 🚀 Object Tracking")
    gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")

    # Add class reference information
    with gr.Accordion("YOLO Class Reference", open=False):
        gr.Markdown(CLASS_REFERENCE_MD)

    with gr.Row():
        # Left column: input video and tracking settings.
        with gr.Column(scale=1):
            input_video = gr.Video(label="Input Video", sources=["upload"])
            with gr.Group():
                yolo_model = gr.Dropdown(
                    choices=yolo_models,
                    value="yolov8n.pt",
                    label="YOLO Model",
                )
                reid_model = gr.Dropdown(
                    choices=reid_models,
                    value="osnet_x0_25_msmt17.pt",
                    label="ReID Model",
                )
                tracking_method = gr.Dropdown(
                    choices=tracking_methods,
                    value="bytetrack",
                    label="Tracking Method",
                )
                # Class ID input field
                class_ids = gr.Textbox(
                    value="",
                    label="Target Classes (comma-separated IDs, e.g. '0,2,5', leave empty for all classes)",
                    placeholder="e.g. 0,2,5",
                )
                conf_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Confidence Threshold",
                )
            process_btn = gr.Button("Process Video", variant="primary")

        # Right column: tracked result and status message.
        with gr.Column(scale=1):
            output_video = gr.Video(label="Output Video with Tracking")
            status_text = gr.Textbox(label="Status", value="Ready to process video")

    process_btn.click(
        fn=process_video,
        inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
        outputs=[output_video, status_text],
    )

    # Add a debug section
    with gr.Accordion("Debug Information", open=False):
        debug_text = gr.Textbox(label="Debug Log", lines=10, interactive=False)
        check_btn = gr.Button("Check Environment")
        check_btn.click(fn=check_environment, outputs=debug_text)
# Entry point: only launch the UI when executed as a script, not on import.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link; debug=True runs the
    # server in the foreground with error output.
    app.launch(share=True, debug=True)