Spaces:
Sleeping
Sleeping
| # pages/lost_at_sea.py | |
| import io | |
| import time | |
| import queue | |
| import threading | |
| import tempfile | |
| from pathlib import Path | |
| from contextlib import contextmanager | |
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image | |
| # Torch (optional) | |
| try: | |
| import torch | |
| except Exception: | |
| torch = None | |
| from utils.model_manager import get_model_manager, load_model | |
# ================== CONFIG ==================
# Defaults shared by the sidebar widgets and the video pipelines below.

# --- Detection / preprocessing ---
DEFAULT_CONF_THRESHOLD = 0.3            # initial value of the confidence slider
DEFAULT_TARGET_SHORT_SIDE = 960         # short edge (px) frames are resized to
DEFAULT_PROCESS_STRIDE = 1              # process every Nth frame (1 = all)

# --- Preview / pacing ---
DEFAULT_MAX_PREVIEW_FPS = 30            # cap on UI preview refreshes per second
DEFAULT_DROP_IF_BEHIND = False          # drop queued frames when lagging
DEFAULT_QUEUE_SIZE = 24                 # capacity of the reader -> consumer queue

# --- Output file ---
DEFAULT_WRITER_CODEC = "mp4v"           # fourcc chosen to avoid the OpenH264 issue
DEFAULT_TMP_EXT = ".mp4"                # extension of the temp input/output files

# --- Slider bounds ---
DEFAULT_MAX_SLIDER_SHORT_SIDE = 1080    # max short side slider
DEFAULT_MIN_SLIDER_SHORT_SIDE = 256     # min short side slider
DEFAULT_MIN_FPS_SLIDER = 1              # min preview FPS slider
DEFAULT_MAX_FPS_SLIDER = 30             # max preview FPS slider
# ============================================
# ============== Session state (stop flag) ==============
# Seed the stop flag the first time the key is absent from the session state.
if "stop_video" not in st.session_state:
    st.session_state["stop_video"] = False

# ================== Page setup ==================
st.set_page_config(
    page_title="Lost At Sea",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Two centred headings rendered as raw HTML: the product name and the page title.
# NOTE(review): the trailing glyph of the second heading looks mis-encoded
# (probably an emoji lost in transit) — confirm against the original file.
st.markdown(
    "".join(
        (
            "<h2 style='text-align:center;margin-top:0'>SAR-X<sup>ai</sup></h2>",
            "<h2 style='text-align:center;margin-top:0'>Lost At Sea π</h2>",
        )
    ),
    unsafe_allow_html=True,
)
# ================== Sidebar ==================
# Builds the whole sidebar: navigation links, upload widgets, action buttons,
# tuning parameters and the model selector. The names bound here (img_file,
# run_img, vid_file, run_vid_plain, run_vid, conf_thr, target_short_side,
# max_preview_fps, drop_if_behind, process_stride, model_key) are consumed by
# the "Main Actions" section at the bottom of the file.
# NOTE(review): the leading glyphs of two button labels look mis-encoded
# (probably emoji lost in transit) — confirm against the original file.
with st.sidebar:
    # --- Navigation ---
    st.page_link("app.py", label="Home")
    st.page_link("pages/bushland_beacon.py", label="Bushland Beacon")
    st.page_link("pages/lost_at_sea.py", label="Lost At Sea")
    st.page_link("pages/signal_watch.py", label="Signal Watch")
    st.markdown("---")
    st.page_link("pages/task_satellite.py", label="Task Satellite")
    st.page_link("pages/task_drone.py", label="Task Drone")
    st.markdown("---")

    # --- Image input ---
    st.sidebar.header("Image Detection")
    img_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        key="img_up",
    )
    run_img = st.button("π Run Image Detection", use_container_width=True)

    # --- Video input ---
    st.sidebar.header("Video")
    vid_file = st.file_uploader(
        "Upload a video",
        type=["mp4", "mov", "avi", "mkv"],
        key="vid_up",
    )
    run_vid_plain = st.button("Play Video", use_container_width=True)
    run_vid = st.button("π₯ Run Detection", use_container_width=True)
    stop_vid = st.button("Stop", use_container_width=True)
    if stop_vid:
        # Raise the flag that the video loops poll between frames.
        st.session_state["stop_video"] = True
    st.sidebar.markdown("---")

    # --- Tunable parameters ---
    st.sidebar.header("Parameters")
    conf_thr = st.slider(
        "Minimum confidence threshold",
        min_value=0.05,
        max_value=0.95,
        value=DEFAULT_CONF_THRESHOLD,
        step=0.01,
    )
    target_short_side = st.select_slider(
        "Target short-side (downscale)",
        options=[256, 320, 384, 448, 512, 640, 720, 800, 864, 960, 1080],
        value=DEFAULT_TARGET_SHORT_SIDE,
        help="Resize so the shorter edge equals this value. Smaller = faster.",
    )
    max_preview_fps = st.slider(
        "Max preview FPS",
        min_value=DEFAULT_MIN_FPS_SLIDER,
        max_value=DEFAULT_MAX_FPS_SLIDER,
        value=DEFAULT_MAX_PREVIEW_FPS,
        help="Throttles UI updates for smoother preview.",
    )
    drop_if_behind = st.toggle(
        "Drop frames if behind",
        value=DEFAULT_DROP_IF_BEHIND,
        help="Drop frames to maintain smooth preview.",
    )
    process_stride = st.slider(
        "Process every Nth frame",
        min_value=1,
        max_value=5,
        value=DEFAULT_PROCESS_STRIDE,
        help="1 = every frame; higher values reuse last result.",
    )
    st.sidebar.markdown("---")

    # --- Model selection and device info (rendered by the model manager) ---
    model_manager = get_model_manager()
    model_label, model_key = model_manager.render_model_selection(key_prefix="lost_at_sea")
    st.sidebar.markdown("---")
    model_manager.render_device_info()
# ================== Perf knobs for OpenCV ==================
# Best-effort tuning: limit OpenCV to a single thread and switch OpenCL off.
# Each knob is applied independently and any failure (including a missing
# attribute on this OpenCV build) is ignored.
for _apply_knob in (
    lambda: cv2.setNumThreads(1),
    lambda: cv2.ocl.setUseOpenCL(False),
):
    try:
        _apply_knob()
    except Exception:
        pass
del _apply_knob
| # ================== Helper functions ================== | |
| def _resize_keep_aspect(img_bgr: np.ndarray, short_side: int) -> np.ndarray: | |
| h, w = img_bgr.shape[:2] | |
| if min(h, w) == short_side: | |
| return img_bgr | |
| if h < w: | |
| new_h = short_side | |
| new_w = int(round(w * (short_side / h))) | |
| else: | |
| new_w = short_side | |
| new_h = int(round(h * (short_side / w))) | |
| return cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA) | |
| def _should_force_cpu_for_model(model_key: str) -> bool: | |
| return (model_key or "").lower() == "deim" | |
def _choose_device(model_key: str) -> str:
    """Pick the device string ("cuda" or "cpu") to run ``model_key`` on.

    Models flagged by ``_should_force_cpu_for_model`` always get "cpu".
    Otherwise "cuda" is returned when torch was imported and reports an
    available CUDA device, and "cpu" in every other case.
    """
    if not _should_force_cpu_for_model(model_key):
        # CUDA availability is only queried for models allowed to use it.
        if torch is not None and torch.cuda.is_available():
            return "cuda"
    return "cpu"
| def _warmup_model(model, model_key: str, shape=(720, 1280, 3), conf: float = 0.25): | |
| dummy = np.zeros(shape, dtype=np.uint8) | |
| try: | |
| if (model_key or "").lower() == "deim": | |
| pil = Image.fromarray(cv2.cvtColor(dummy, cv2.COLOR_BGR2RGB)) | |
| model.predict_image(pil, min_confidence=conf) | |
| else: | |
| model.predict_and_visualize(dummy, min_confidence=conf, show_score=False) | |
| except Exception: | |
| pass | |
@contextmanager
def maybe_autocast(enabled: bool):
    """Context manager that enables CUDA autocast only when it can be used.

    The body of the ``with`` statement runs under ``torch.cuda.amp.autocast()``
    when ``enabled`` is true, torch was imported and CUDA is available;
    otherwise it runs unchanged.

    The ``@contextmanager`` decorator is required: without it this is a plain
    generator function, and ``with maybe_autocast(...)`` fails because a
    generator object does not implement the context-manager protocol.
    """
    if enabled and torch is not None and torch.cuda.is_available():
        with torch.cuda.amp.autocast():
            yield
    else:
        yield
| def _device_hint() -> str: | |
| if torch is None: | |
| return "cpu" | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
| # ================== Passthrough (no model, no boxes) ================== | |
def run_video_passthrough(
    vid_bytes: bytes,
    target_short_side: int = DEFAULT_TARGET_SHORT_SIDE,
    max_preview_fps: int = DEFAULT_MAX_PREVIEW_FPS,
    drop_if_behind: bool = DEFAULT_DROP_IF_BEHIND,
):
    """Play the uploaded video with scaling & pacing only (no inference, no overlays).

    The bytes are written to a temp file, decoded on a background reader
    thread, resized, written to a temp output file paced to the source FPS,
    and previewed in the page. When the run finishes the output is shown with
    a download button.

    Args:
        vid_bytes: Raw bytes of the uploaded video.
        target_short_side: Short edge (px) every frame is resized to.
        max_preview_fps: Upper bound on preview refreshes per second.
        drop_if_behind: When true, the oldest queued frame is discarded
            whenever the queue is full; when false the reader waits for
            space so that no frame is lost.

    Returns:
        None. Results and errors are reported through Streamlit widgets.
    """
    ts = int(time.time() * 1000)
    tmp_in = Path(tempfile.gettempdir()) / f"in_{ts}{DEFAULT_TMP_EXT}"
    with open(tmp_in, "wb") as f:
        f.write(vid_bytes)

    def _discard_input() -> None:
        # The temp copy of the upload is only needed while decoding.
        try:
            tmp_in.unlink()
        except OSError:
            pass

    cap = cv2.VideoCapture(str(tmp_in), cv2.CAP_FFMPEG)
    if not cap.isOpened():
        cap.release()
        _discard_input()
        st.error("Failed to open the uploaded video.")
        return
    try:
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 2)
    except Exception:
        pass

    src_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    if not src_fps > 0:
        # Covers negative and NaN metadata as well as zero.
        src_fps = 25.0
    total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0))

    # UI placeholders
    frame_ph = st.empty()
    info_ph = st.empty()
    prog = st.progress(0.0, text="Preparing…")

    # Reader thread -> bounded queue. The event is set by this function when
    # it is done so that the reader never outlives the playback loop.
    q: "queue.Queue[tuple[int, np.ndarray] | None]" = queue.Queue(maxsize=DEFAULT_QUEUE_SIZE)
    stop_evt = threading.Event()

    def _enqueue(item) -> None:
        # Retry until the item is queued or playback has stopped. Frames are
        # only discarded when the caller opted in via drop_if_behind.
        while not stop_evt.is_set():
            if drop_if_behind and q.full():
                try:
                    q.get_nowait()
                except queue.Empty:
                    pass
            try:
                q.put(item, timeout=0.05)
                return
            except queue.Full:
                continue

    def reader():
        idx = 0
        try:
            while not stop_evt.is_set():
                ok, frm = cap.read()
                if not ok:
                    break
                _enqueue((idx, frm))
                idx += 1
        finally:
            # Sentinel: tells the consumer that no more frames will arrive.
            _enqueue(None)

    # Writer (optional export)
    tmp_out = Path(tempfile.gettempdir()) / f"out_{ts}{DEFAULT_TMP_EXT}"
    writer = None

    # Pacing and preview throttle
    min_preview_interval = 1.0 / max(float(max_preview_fps), 1e-3)
    last_preview_ts = 0.0
    frame_interval = 1.0 / float(src_fps)
    next_write_ts = time.perf_counter() + frame_interval
    frames_done = 0
    t0 = time.perf_counter()

    reader_th = threading.Thread(target=reader, daemon=True)
    reader_th.start()
    try:
        with st.spinner("Playing video…"):
            while True:
                if st.session_state.get("stop_video", False):
                    break
                item = q.get()
                if item is None:
                    break
                idx, frame_bgr = item

                # Downscale for speed/preview
                vis_bgr = _resize_keep_aspect(frame_bgr, short_side=target_short_side)

                # Init writer lazily, once the output frame size is known
                if writer is None:
                    H, W = vis_bgr.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*DEFAULT_WRITER_CODEC)
                    writer = cv2.VideoWriter(str(tmp_out), fourcc, src_fps, (W, H))

                # Pace writing to match source
                now = time.perf_counter()
                if now < next_write_ts:
                    time.sleep(max(0.0, next_write_ts - now))
                writer.write(vis_bgr)
                next_write_ts += frame_interval
                frames_done += 1

                # UI updates (throttled)
                now = time.perf_counter()
                if (now - last_preview_ts) >= min_preview_interval:
                    frame_ph.image(
                        cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB),
                        use_container_width=True,
                        output_format="JPEG",
                        channels="RGB",
                    )
                    elapsed = now - t0
                    fps_est = frames_done / max(elapsed, 1e-6)
                    info_ph.info(
                        f"Frames: {frames_done}/{total_frames or '?'} • "
                        f"Throughput: {fps_est:.1f} FPS • Source FPS: {src_fps:.1f} • "
                        f"Mode: Passthrough"
                    )
                    last_preview_ts = now

                # Progress. Clamped because the container's frame count can be
                # lower than the number of frames actually decoded.
                if total_frames > 0:
                    progress = (idx + 1) / total_frames
                else:
                    progress = min(frames_done / (frames_done + 30), 0.99)
                prog.progress(
                    min(max(progress, 0.0), 1.0),
                    text=f"Playing frame {idx + 1}{'/' + str(total_frames) if total_frames>0 else ''}…",
                )
    except Exception as exc:
        st.error(f"Video playback failed: {exc}")
        return
    finally:
        # Stop the reader before releasing the capture it is reading from.
        stop_evt.set()
        reader_th.join(timeout=2.0)
        try:
            cap.release()
        except Exception:
            pass
        if writer is not None:
            try:
                writer.release()
            except Exception:
                pass
        _discard_input()

    # Reset stop flag after finishing
    st.session_state["stop_video"] = False
    st.success("Done!")
    if tmp_out.exists():
        st.video(str(tmp_out))
        with open(tmp_out, "rb") as f:
            st.download_button(
                "Download video",
                data=f.read(),
                file_name=tmp_out.name,
                mime="video/mp4",
            )
    else:
        st.error("Playback completed but output file was not created.")
| # ================== Detection routines ================== | |
def run_image_detection(uploaded_file, conf_thr: float = 0.5, model_key: str = "deim"):
    """Run detection on one uploaded image and render the result in the page.

    Args:
        uploaded_file: Object whose ``getvalue()`` returns the image bytes.
        conf_thr: Value forwarded to the model as ``min_confidence``.
        model_key: Identifier passed to ``load_model``. "deim"
            (case-insensitive) uses ``predict_image`` only; any other key
            falls back to ``predict_and_visualize`` if ``predict_image`` fails.

    Returns:
        None. The uploaded image, the annotated image and any error are
        shown through Streamlit widgets; exceptions are reported, not raised.
    """
    try:
        data = uploaded_file.getvalue()
        img = Image.open(io.BytesIO(data)).convert("RGB")
        st.image(img, caption="Uploaded Image", use_container_width=True)
    except Exception as e:
        st.error(f"Error loading image: {e}")
        return

    try:
        model = load_model(model_key)
        device = _choose_device(model_key)
        if torch is not None:
            try:
                model.to(device)
            except Exception:
                # Best effort: a failure to move the model is ignored.
                pass
        _warmup_model(model, model_key=model_key, shape=(img.height, img.width, 3), conf=conf_thr)

        force_cpu = _should_force_cpu_for_model(model_key)
        use_amp = (device == "cuda") and not force_cpu
        is_deim = (model_key or "").lower() == "deim"

        with st.spinner(f"Running detection on {device.upper()}…"):
            with maybe_autocast(use_amp):
                try:
                    annotated = model.predict_image(img, min_confidence=conf_thr)
                except Exception:
                    if is_deim:
                        # DEIM has no alternative entry point; report the error.
                        raise
                    # Other models: retry through the BGR array interface.
                    np_bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
                    _, vis = model.predict_and_visualize(np_bgr, min_confidence=conf_thr, show_score=True)
                    annotated = Image.fromarray(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))

        st.subheader("🎯 Detection Results")
        st.image(annotated, caption="Detections", use_container_width=True)
        if force_cpu:
            st.info("DEIM runs on CPU to avoid TorchScript device mismatch.")
    except Exception as e:
        st.error(f"Error during detection: {e}")
def run_video_detection(
    vid_bytes: bytes,
    conf_thr: float = 0.5,
    model_key: str = "deim",
    target_short_side: int = DEFAULT_TARGET_SHORT_SIDE,
    max_preview_fps: int = DEFAULT_MAX_PREVIEW_FPS,
    drop_if_behind: bool = DEFAULT_DROP_IF_BEHIND,
    process_stride: int = DEFAULT_PROCESS_STRIDE,
):
    """Run detection over an uploaded video with a live, throttled preview.

    Frames are decoded on a background reader thread, resized, passed through
    the model (every ``process_stride``-th consumed frame; the frames in
    between reuse the last annotated frame), written to a temp output file
    paced to the source FPS, and previewed in the page. When the run finishes
    the output is shown with a download button.

    Args:
        vid_bytes: Raw bytes of the uploaded video.
        conf_thr: Value forwarded to the model as ``min_confidence``.
        model_key: Identifier passed to ``load_model``.
        target_short_side: Short edge (px) every frame is resized to.
        max_preview_fps: Upper bound on preview refreshes per second.
        drop_if_behind: When true, the oldest queued frame is discarded
            whenever the queue is full; when false the reader waits for
            space so that no frame is lost.
        process_stride: Run the model on every Nth consumed frame (1 = all).

    Returns:
        None. Results and errors are reported through Streamlit widgets.
    """
    # Load model & choose device (before touching the disk, so a failed load
    # leaves no temp file behind).
    model = load_model(model_key)
    device = _choose_device(model_key)
    if torch is not None:
        try:
            model.to(device)
        except Exception:
            # Best effort: a failure to move the model is ignored.
            pass
    force_cpu = _should_force_cpu_for_model(model_key)
    is_deim = (model_key or "").lower() == "deim"
    use_amp = (device == "cuda") and not force_cpu

    # Save upload to a temp file
    ts = int(time.time() * 1000)
    tmp_in = Path(tempfile.gettempdir()) / f"in_{ts}{DEFAULT_TMP_EXT}"
    with open(tmp_in, "wb") as f:
        f.write(vid_bytes)

    def _discard_input() -> None:
        # The temp copy of the upload is only needed while decoding.
        try:
            tmp_in.unlink()
        except OSError:
            pass

    # Capture
    cap = cv2.VideoCapture(str(tmp_in), cv2.CAP_FFMPEG)
    if not cap.isOpened():
        cap.release()
        _discard_input()
        st.error("Failed to open the uploaded video.")
        return
    try:
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 2)
    except Exception:
        pass

    src_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    if not src_fps > 0:
        # Covers negative and NaN metadata as well as zero.
        src_fps = 25.0
    total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0))
    src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    frame_ph = st.empty()
    info_ph = st.empty()
    prog = st.progress(0.0, text="Preparing…")

    # Warm up at the resolution and orientation the model will actually see
    # (the resized frame), not at the source resolution.
    if src_h > 0 and src_w > 0 and target_short_side > 0:
        scale = target_short_side / min(src_h, src_w)
        warm_shape = (
            max(1, int(round(src_h * scale))),
            max(1, int(round(src_w * scale))),
            3,
        )
        _warmup_model(model, model_key=model_key, shape=warm_shape, conf=conf_thr)
    else:
        # Frame size unknown: use the warm-up helper's default shape.
        _warmup_model(model, model_key=model_key, conf=conf_thr)

    # Reader thread -> bounded queue. The event is set by this function when
    # it is done so that the reader never outlives the processing loop.
    q: "queue.Queue[tuple[int, np.ndarray] | None]" = queue.Queue(maxsize=DEFAULT_QUEUE_SIZE)
    stop_evt = threading.Event()

    def _enqueue(item) -> None:
        # Retry until the item is queued or processing has stopped. Frames are
        # only discarded when the caller opted in via drop_if_behind.
        while not stop_evt.is_set():
            if drop_if_behind and q.full():
                # drop the oldest frame to keep things moving
                try:
                    q.get_nowait()
                except queue.Empty:
                    pass
            try:
                q.put(item, timeout=0.05)
                return
            except queue.Full:
                continue

    def reader():
        idx = 0
        try:
            while not stop_evt.is_set():
                ok, frm = cap.read()
                if not ok:
                    break
                _enqueue((idx, frm))
                idx += 1
        finally:
            # Sentinel: tells the consumer that no more frames will arrive.
            _enqueue(None)

    tmp_out = Path(tempfile.gettempdir()) / f"out_{ts}{DEFAULT_TMP_EXT}"
    writer = None

    # Preview throttle
    min_preview_interval = 1.0 / max(float(max_preview_fps), 1e-3)
    last_preview_ts = 0.0

    # Source pacing
    frame_interval = 1.0 / float(src_fps)
    next_write_ts = time.perf_counter() + frame_interval
    frames_done = 0
    t0 = time.perf_counter()
    last_vis_bgr = None  # for stride reuse

    # Loop-invariant pieces of the status line.
    device_msg = f"{device.upper()}" if device != "cuda" else f"{device.upper()} ({_device_hint().upper()})"
    total_label = total_frames if total_frames > 0 else "?"

    reader_th = threading.Thread(target=reader, daemon=True)
    reader_th.start()
    try:
        with st.spinner(f"Processing video on {device.upper()} with live preview…"):
            while True:
                if st.session_state.get("stop_video", False):
                    break
                item = q.get()
                if item is None:
                    break
                idx, frame_bgr = item

                # Downscale for speed
                proc_bgr = _resize_keep_aspect(frame_bgr, short_side=target_short_side)

                # Count consumed frames rather than source indices: indices
                # have gaps when frames are dropped, which could otherwise
                # skip inference for long stretches.
                run_infer = (process_stride <= 1) or ((frames_done % process_stride) == 0)
                if run_infer:
                    # Run model
                    if is_deim:
                        img_rgb = cv2.cvtColor(proc_bgr, cv2.COLOR_BGR2RGB)
                        pil_img = Image.fromarray(img_rgb)
                        annotated_pil = model.predict_image(pil_img, min_confidence=conf_thr)
                        vis_bgr = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)
                    else:
                        with maybe_autocast(use_amp):
                            try:
                                _, vis_bgr = model.predict_and_visualize(
                                    proc_bgr, min_confidence=conf_thr, show_score=True
                                )
                            except Exception:
                                pil = Image.fromarray(cv2.cvtColor(proc_bgr, cv2.COLOR_BGR2RGB))
                                annotated = model.predict_image(pil, min_confidence=conf_thr)
                                vis_bgr = cv2.cvtColor(np.array(annotated), cv2.COLOR_RGB2BGR)
                    last_vis_bgr = vis_bgr
                else:
                    # Reuse last visualised frame to avoid visible “skips”
                    vis_bgr = last_vis_bgr if last_vis_bgr is not None else proc_bgr

                # Init writer when first output frame is ready
                if writer is None:
                    H, W = vis_bgr.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*DEFAULT_WRITER_CODEC)  # avoids OpenH264 issues
                    out_fps = src_fps  # preserve source FPS in output
                    writer = cv2.VideoWriter(str(tmp_out), fourcc, out_fps, (W, H))

                # Pace writing to match the source timeline
                now = time.perf_counter()
                if now < next_write_ts:
                    time.sleep(max(0.0, next_write_ts - now))
                writer.write(vis_bgr)
                next_write_ts += frame_interval
                frames_done += 1

                # UI updates (throttled)
                now = time.perf_counter()
                if (now - last_preview_ts) >= min_preview_interval:
                    frame_ph.image(
                        cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB),
                        use_container_width=True,
                        output_format="JPEG",
                        channels="RGB",
                    )
                    elapsed = now - t0
                    fps_est = frames_done / max(elapsed, 1e-6)
                    info_text = (
                        f"Processed: {frames_done} / {total_label} • "
                        f"Throughput: {fps_est:.1f} FPS • "
                        f"Source FPS: {src_fps:.1f} • Device: {device_msg} • "
                        f"Stride: {process_stride}x"
                    )
                    if force_cpu:
                        info_text += " • Note: DEIM forced to CPU."
                    info_ph.info(info_text)
                    last_preview_ts = now

                # Progress bar. Clamped because the container's frame count
                # can be lower than the number of frames actually decoded.
                if total_frames > 0:
                    progress = (idx + 1) / total_frames
                else:
                    progress = min(frames_done / (frames_done + 30), 0.99)
                prog.progress(
                    min(max(progress, 0.0), 1.0),
                    text=f"Processing frame {idx + 1}{'/' + str(total_frames) if total_frames>0 else ''}…",
                )
    except Exception as exc:
        st.error(f"Video detection failed: {exc}")
        return
    finally:
        # Stop the reader before releasing the capture it is reading from.
        stop_evt.set()
        reader_th.join(timeout=2.0)
        try:
            cap.release()
        except Exception:
            pass
        if writer is not None:
            try:
                writer.release()
            except Exception:
                pass
        _discard_input()

    # Reset stop flag after finishing
    st.session_state["stop_video"] = False
    st.success("Done!")
    if tmp_out.exists():
        st.video(str(tmp_out))
        with open(tmp_out, "rb") as f:
            st.download_button(
                "Download processed video",
                data=f.read(),
                file_name=tmp_out.name,
                mime="video/mp4",
            )
    else:
        st.error("Video processing completed but output file was not created.")
| # ================== Main Actions ================== | |
# Dispatch on whichever sidebar button triggered this script run.
if run_img:
    if img_file is None:
        st.warning("Please upload an image first.")
    else:
        run_image_detection(img_file, conf_thr=conf_thr, model_key=model_key)

# New: Passthrough mode
if run_vid_plain:
    if vid_file is None:
        st.warning("Please upload a video first.")
    else:
        # Clear a stale stop request before starting a new run.
        st.session_state["stop_video"] = False
        run_video_passthrough(
            # getvalue() returns the whole upload regardless of the stream
            # position (read() only returns what is left after the cursor),
            # and matches how run_image_detection reads its upload.
            vid_bytes=vid_file.getvalue(),
            target_short_side=target_short_side,
            max_preview_fps=max_preview_fps,
            drop_if_behind=drop_if_behind,
        )

# Original: Detection mode
if run_vid:
    if vid_file is None:
        st.warning("Please upload a video first.")
    else:
        # Clear a stale stop request before starting a new run.
        st.session_state["stop_video"] = False
        run_video_detection(
            vid_bytes=vid_file.getvalue(),
            conf_thr=conf_thr,
            model_key=model_key,
            target_short_side=target_short_side,
            max_preview_fps=max_preview_fps,
            drop_if_behind=drop_if_behind,
            process_stride=process_stride,
        )