import json
import os
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from yolo_model import BaseModel
_TORCH_MIN_VERSION = (2, 5)


def _parse_version(version_str: str) -> Tuple[int, ...]:
    """Parse a version string like "2.5.1+cu121" into a numeric tuple, e.g. (2, 5, 1)."""
    parts = []
    for piece in version_str.split("+")[0].split("."):
        try:
            parts.append(int(piece))
        except ValueError:
            break
    return tuple(parts)
class DeimHgnetV2MDrone(BaseModel):
    def __init__(self, device: str, version: str = "v2"):
        self.device = device
        repo_root = Path(__file__).resolve().parents[1]
        default_rel = (
            Path("app_service") / "models" / f"model_deimhgnetV2m_{device}_{version}.pt"
        )
        # Allow explicit override via env var
        override = (
            Path(os.environ["DEIM_WEIGHTS_PATH"])
            if "DEIM_WEIGHTS_PATH" in os.environ
            else None
        )
        candidate_paths = [
            override,
            repo_root / default_rel,
            Path(__file__).resolve().parent
            / "models"
            / f"model_deimhgnetV2m_{device}_{version}.pt",
            Path.cwd() / "services" / default_rel,
            Path("/app") / "services" / default_rel,
        ]
        weights_path = next((p for p in candidate_paths if p and p.exists()), None)
        if weights_path is None:
            models_dir = Path(__file__).resolve().parent / "models"
            alt_models_dir = repo_root / "app_service" / "models"
            available = []
            for d in [models_dir, alt_models_dir]:
                try:
                    if d.exists():
                        available.extend(str(p.name) for p in d.glob("*.pt"))
                except Exception:
                    pass
            searched = [str(p) for p in candidate_paths if p]
            raise FileNotFoundError(
                "Model weights not found. Looked in: "
                + "; ".join(searched)
                + ". Available .pt files: "
                + (", ".join(sorted(set(available))) or "<none>")
            )
        cfg_path = weights_path.with_suffix(".json")
        if not cfg_path.exists():
            raise FileNotFoundError(
                f"Config JSON not found next to weights: {cfg_path}"
            )
        version_tuple = _parse_version(torch.__version__)
        if version_tuple < _TORCH_MIN_VERSION:
            raise RuntimeError(
                "PyTorch {} is too old for these weights. "
                "Please upgrade to >= {}.{} (e.g. set torch==2.5.1 in the Dockerfile).".format(
                    torch.__version__, *_TORCH_MIN_VERSION
                )
            )
        size_bytes = weights_path.stat().st_size
        if size_bytes < 1_000_000:
            raise RuntimeError(
                f"Weights file at {weights_path} is only {size_bytes} bytes. "
                "This usually means Git LFS pointers were copied instead of the binary file. "
                "Run `git lfs pull` before building the container to fetch the real weights."
            )
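        # The JSON config next to the weights is expected (judging from the keys
        # read below; this is not a formal schema) to provide at least:
        #   "target_size"          - [height, width] model input resolution
        #   "categories"           - class-index -> category-name mapping
        #   "confs_by_categories"  - per-category confidence thresholds for drawing
        #   "nms_iou_thr"          - IoU threshold used by class-wise NMS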
        with open(cfg_path, "r") as f:
            self.cfg = json.load(f)
        self._target_h, self._target_w = (
            int(self.cfg["target_size"][0]),
            int(self.cfg["target_size"][1]),
        )
        self._categories = self.cfg["categories"]
        self._confs_by_categories = self.cfg["confs_by_categories"]
        print(f"Loading model from: {weights_path}")
        print(f"Model device: {self.device}")
        self.model = torch.jit.load(weights_path, map_location=self.device).eval()
        print(f"Model loaded successfully on device: {self.device}")
    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Resize a PIL image to the model input size and return a batched tensor on the device."""
        transforms = T.Compose(
            [
                T.Resize((self._target_h, self._target_w)),
                T.ToTensor(),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)
    def _postprocess_detections(
        self, scores, bboxes, min_confidence: float, wh: Tuple[int, int]
    ) -> np.ndarray:
        """Convert raw model outputs to an (N, 6) array of [x1, y1, x2, y2, score, class_id]."""
        w, h = wh
        b_np = bboxes[0].cpu().numpy()
        s_np = scores.sigmoid()[0].cpu().numpy()
        mask = (s_np.max(axis=1) >= min_confidence).squeeze()
        if not mask.any():
            return np.zeros((0, 6), dtype=np.float32)
        valid = b_np[mask]
        # Boxes are normalized (cx, cy, w, h); convert to absolute (x1, y1, x2, y2).
        cx, cy, box_w, box_h = valid[:, 0], valid[:, 1], valid[:, 2], valid[:, 3]
        x1 = cx - box_w / 2
        y1 = cy - box_h / 2
        x2 = cx + box_w / 2
        y2 = cy + box_h / 2
        valid_xyxy = np.stack([x1, y1, x2, y2], axis=1) * [w, h, w, h]
        return np.concatenate(
            [
                valid_xyxy,
                s_np[mask].max(axis=1, keepdims=True),
                s_np[mask].argmax(axis=1, keepdims=True),
            ],
            axis=1,
        )
    def _nms(self, dets: np.ndarray) -> np.ndarray:
        """Apply non-maximum suppression independently per class."""
        if dets.shape[0] == 0 or self.cfg["nms_iou_thr"] <= 0:
            return dets
        class_ids = np.unique(dets[:, 5].astype(int))
        keep_all = []
        for class_id in class_ids:
            class_mask = dets[:, 5] == class_id
            class_dets = dets[class_mask]
            if class_dets.shape[0] == 0:
                continue
            class_keep = self._nms_single_class(class_dets)
            original_indices = np.where(class_mask)[0]
            keep_all.extend(original_indices[class_keep])
        return dets[keep_all] if keep_all else np.zeros((0, 6), dtype=np.float32)
    def _nms_single_class(self, dets: np.ndarray):
        """Greedy NMS over a single class; returns indices of boxes to keep."""
        if dets.shape[0] == 0:
            return []
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            # Keep the highest-scoring remaining box, then drop boxes that overlap it too much.
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            iou = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(iou <= self.cfg["nms_iou_thr"])[0]
            order = order[inds + 1]
        return keep
    def _draw_detections_on_np(
        self, image_np: np.ndarray, dets: np.ndarray
    ) -> np.ndarray:
        """Draw boxes and labels in place, honoring per-category confidence thresholds."""
        for bbox in dets:
            x1, y1, x2, y2, confidence, category_id = bbox
            category_name = self._categories[int(category_id)]
            conf_by_this_cat = self._confs_by_categories.get(category_name, 0.0)
            if confidence < conf_by_this_cat:
                continue
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            cv2.rectangle(image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{category_name} {confidence:.2f}"
            label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
            cv2.rectangle(
                image_np,
                (x1, y1 - label_size[1] - 10),
                (x1 + label_size[0], y1),
                (0, 255, 0),
                -1,
            )
            cv2.putText(
                image_np,
                label,
                (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 0),
                1,
            )
        return image_np
    def _preprocess_frame_fast(self, frame_bgr: np.ndarray) -> torch.Tensor:
        """Convert BGR numpy frame to normalized tensor on target device."""
        frame = np.ascontiguousarray(frame_bgr)
        if frame.shape[0] != self._target_h or frame.shape[1] != self._target_w:
            frame = cv2.resize(frame, (self._target_w, self._target_h))
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).contiguous()
        tensor = tensor.to(self.device, dtype=torch.float32).unsqueeze(0)
        tensor = tensor.div(255.0)
        return tensor
    def annotate_frame_bgr(self, frame_bgr: np.ndarray, min_confidence: float) -> np.ndarray:
        """Run inference on a BGR frame and return annotated frame in BGR space."""
        tensor = self._preprocess_frame_fast(frame_bgr)
        with torch.inference_mode():
            scores, bboxes = self.model(tensor)
        dets = self._postprocess_detections(
            scores, bboxes, min_confidence, (frame_bgr.shape[1], frame_bgr.shape[0])
        )
        dets = self._nms(dets)
        annotated = frame_bgr.copy()
        return self._draw_detections_on_np(annotated, dets)
    def predict_image(self, image: Image.Image, min_confidence: float) -> Image.Image:
        """Run inference on a PIL image and return an annotated copy."""
        tensor = self._preprocess_image(image.copy())
        with torch.no_grad():
            labels, bboxes = self.model(tensor)
        dets = self._postprocess_detections(labels, bboxes, min_confidence, image.size)
        dets = self._nms(dets)
        image_np: np.ndarray = np.array(image)
        image_np = self._draw_detections_on_np(image_np, dets)
        return Image.fromarray(image_np)
    def predict_video(
        self, video, min_confidence: float, target_dir_name="annotated_video"
    ) -> str:
        """Annotate every frame of a video and return the path of the written file."""
        input_path = str(video)
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {input_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        input_p = Path(input_path)
        out_dir = Path(target_dir_name)
        out_dir.mkdir(parents=True, exist_ok=True)
        # Write a simple AVI container with the MJPG codec (widely compatible).
        out_path = out_dir / f"{input_p.stem}_annotated.avi"
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))
        if not writer.isOpened():
            # Fall back to XVID if MJPG fails.
            print("MJPG codec failed, trying XVID...")
            fourcc = cv2.VideoWriter_fourcc(*"XVID")
            writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(
                "Could not initialize video writer with MJPG or XVID codec"
            )
        print(f"DEIM Model: Processing video {input_p.name} ({width}x{height}, {fps:.1f} FPS)")
        print(f"DEIM Model: Output will be saved to {out_path}")
        frame_count = 0
        while True:
            ret, frame_bgr = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            tensor = self._preprocess_image(pil_img.copy())
            with torch.no_grad():
                labels, bboxes = self.model(tensor)
            dets = self._postprocess_detections(
                labels, bboxes, min_confidence, (width, height)
            )
            dets = self._nms(dets)
            annotated_frame = self._draw_detections_on_np(frame_bgr.copy(), dets)
            writer.write(annotated_frame)
            frame_count += 1
            print(f"processed {frame_count} frames...")
        cap.release()
        if writer is not None:
            writer.release()
        return str(out_path)
# if __name__ == "__main__":
#     model = DeimHgnetV2MDrone(version="v3", device="cpu")
#     output_path = model.predict_video("./resources/videos/raw/sample2.mp4", 0.3)
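# Minimal usage sketch (assumes weights and sample media are available locally;
# the file names below are illustrative, not shipped with the repo):
#
#     model = DeimHgnetV2MDrone(device="cpu", version="v2")
#     annotated = model.predict_image(Image.open("sample.jpg"), min_confidence=0.3)
#     annotated.save("sample_annotated.jpg")
#     out_path = model.predict_video("sample.mp4", min_confidence=0.3)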