| """Centralized model selection and device management for SAR-X AI application. | |
| This module provides a unified interface for model loading, device management, | |
| and model selection across all pages in the application. | |
| """ | |

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import streamlit as st
import torch

from deim_model import DeimHgnetV2MDrone
from yolo_model import YOLOModel


class ModelManager:
    """Centralized model manager for device and model selection."""

    def __init__(self):
        self.device = self._get_device()
        self.models_dir = Path(__file__).resolve().parent.parent / "models"
        self.model_entries = self._discover_model_entries()

    def _get_device(self) -> str:
        """Determine the best available device (CUDA or CPU)."""
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _discover_model_entries(self) -> List[Tuple[str, str]]:
        """Discover available models in the models directory."""
        entries: List[Tuple[str, str]] = [("DEIM Model", "deim")]
        if self.models_dir.exists():
            # Only add YOLOv8n base model
            yolov8n_file = self.models_dir / "yolov8n.pt"
            if yolov8n_file.exists():
                entries.append(("YOLOv8n Model", f"yolo:{yolov8n_file.resolve()}"))
        return entries

    def get_available_models(self) -> List[str]:
        """Get list of available model labels."""
        available_labels = []
        for label, key in self.model_entries:
            if key == "deim":
                available_labels.append(label)
                continue
            if not key.startswith("yolo:"):
                continue
            weight_path = Path(key.split(":", 1)[1])
            if weight_path.exists():
                available_labels.append(label)
        if not available_labels:
            available_labels = ["DEIM Model"]
        return available_labels

    def get_model_key(self, model_label: str) -> str:
        """Get model key from model label."""
        label_to_key: Dict[str, str] = {label: key for label, key in self.model_entries}
        return label_to_key.get(model_label, "deim")

    @st.cache_resource
    def load_model(
        _self, model_key: str, device: Optional[str] = None
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Load a model with caching for better performance.

        The leading underscore in ``_self`` tells Streamlit's cache to skip
        hashing the ModelManager instance.

        Args:
            model_key: The model identifier (e.g., "deim" or "yolo:/path/to/model.pt")
            device: Device to load the model on (defaults to the auto-detected device)

        Returns:
            Loaded model instance
        """
        if device is None:
            device = _self.device
        if model_key == "deim":
            return DeimHgnetV2MDrone(device=device)
        elif model_key.startswith("yolo:"):
            model_path = model_key.split(":", 1)[1]
            return YOLOModel(model_path)
        else:
            raise ValueError(f"Invalid model key: {model_key}")

    def get_device_info(self) -> Dict[str, str]:
        """Get information about the current device."""
        device_info = {
            "device": self.device,
            "cuda_available": str(torch.cuda.is_available()),
        }
        if torch.cuda.is_available():
            device_info.update(
                {
                    "cuda_device_count": str(torch.cuda.device_count()),
                    "cuda_device_name": torch.cuda.get_device_name(0),
                    "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / 1024**3:.2f} GB",
                    "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / 1024**3:.2f} GB",
                }
            )
        return device_info

    def render_device_info(self):
        """Render device information in the Streamlit sidebar."""
        device_info = self.get_device_info()
        st.sidebar.header("Device Information")
        # Device status
        if device_info["device"] == "cuda":
            st.sidebar.success(f"🚀 Using GPU: {device_info['cuda_device_name']}")
            st.sidebar.info(
                f"Memory: {device_info['cuda_memory_allocated']} allocated / "
                f"{device_info['cuda_memory_reserved']} reserved"
            )
        else:
            st.sidebar.warning("🖥️ Using CPU")
        # # Show device details in expander
        # with st.sidebar.expander("Device Details"):
        #     for key, value in device_info.items():
        #         st.text(f"{key}: {value}")

    def render_model_selection(self, key_prefix: str = "") -> Tuple[str, str]:
        """Render model selection UI in the Streamlit sidebar.

        Args:
            key_prefix: Prefix for Streamlit widget keys to avoid conflicts

        Returns:
            Tuple of (model_label, model_key)
        """
        st.sidebar.subheader("Model Selection")
        available_models = self.get_available_models()
        model_label = st.sidebar.selectbox(
            "Model", available_models, index=0, key=f"{key_prefix}_model_select"
        )
        model_key = self.get_model_key(model_label)
        return model_label, model_key
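

# Typical page-level usage (illustrative sketch; the import path and the
# "detect" key prefix below are assumptions, not part of this module):
#
#     from model_manager import get_model_manager
#
#     manager = get_model_manager()
#     manager.render_device_info()
#     model_label, model_key = manager.render_model_selection(key_prefix="detect")
#     model = manager.load_model(model_key)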

# Global instance for easy access
model_manager = ModelManager()


def get_model_manager() -> ModelManager:
    """Get the global model manager instance."""
    return model_manager


def load_model(
    model_key: str, device: Optional[str] = None
) -> Union[DeimHgnetV2MDrone, YOLOModel]:
    """Convenience function to load a model."""
    return model_manager.load_model(model_key, device)


def get_device() -> str:
    """Get the current device."""
    return model_manager.device


def get_available_models() -> List[str]:
    """Get list of available models."""
    return model_manager.get_available_models()


def get_model_key(model_label: str) -> str:
    """Get model key from model label."""
    return model_manager.get_model_key(model_label)
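

if __name__ == "__main__":
    # Illustrative smoke test when the module is run directly (outside a
    # Streamlit page): report the detected device and the models discovered
    # on disk without loading any weights.
    manager = get_model_manager()
    print("Device:", manager.device)
    print("Device info:", manager.get_device_info())
    print("Available models:", manager.get_available_models())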