"""Centralized model selection and device management for SAR-X AI application.

This module provides a unified interface for model loading, device management,
and model selection across all pages in the application.
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import streamlit as st
import torch

from deim_model import DeimHgnetV2MDrone
from yolo_model import YOLOModel


class ModelManager:
    """Centralized model manager for device and model selection.

    On construction it picks a compute device and scans a ``models`` directory
    (a sibling of the directory containing this file) for weights files.
    Models are identified by a *key*:

    * ``"deim"`` -- the built-in DEIM model (always listed).
    * ``"yolo:<absolute path>"`` -- a YOLO weights file on disk.
    """

    def __init__(self):
        self.device = self._get_device()
        self.models_dir = Path(__file__).resolve().parent.parent / "models"
        # List of (display label, model key) pairs, in display order.
        self.model_entries = self._discover_model_entries()

    def _get_device(self) -> str:
        """Return ``"cuda"`` if a CUDA device is available, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _discover_model_entries(self) -> List[Tuple[str, str]]:
        """Build the list of ``(label, key)`` pairs for selectable models.

        The DEIM model is always included first. Of the files in
        ``models_dir``, only ``yolov8n.pt`` is currently registered.
        """
        entries: List[Tuple[str, str]] = [("DEIM Model", "deim")]

        if self.models_dir.exists():
            # Only add YOLOv8n base model
            yolov8n_file = self.models_dir / "yolov8n.pt"
            if yolov8n_file.exists():
                entries.append(("YOLOv8n Model", f"yolo:{yolov8n_file.resolve()}"))

        return entries

    def get_available_models(self) -> List[str]:
        """Return the labels of models that can currently be loaded.

        DEIM is always available. YOLO entries are re-checked on every call,
        so a weights file removed after start-up is no longer offered.

        Returns:
            Model labels in discovery order; never empty.
        """
        available_labels: List[str] = []
        for label, key in self.model_entries:
            if key == "deim":
                available_labels.append(label)
            elif key.startswith("yolo:") and Path(key.split(":", 1)[1]).exists():
                available_labels.append(label)

        # Defensive: DEIM is always discovered, so this fallback should not
        # normally trigger, but it guarantees the selectbox is never empty.
        return available_labels or ["DEIM Model"]

    def get_model_key(self, model_label: str) -> str:
        """Map a display label to its model key.

        Unknown labels fall back to ``"deim"`` rather than raising.
        """
        return dict(self.model_entries).get(model_label, "deim")

    @st.cache_resource
    def load_model(
        _self, model_key: str, device: Optional[str] = None
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Load a model with caching for better performance.

        The leading underscore on ``_self`` excludes the manager instance from
        Streamlit's cache hashing, so the cache key is ``(model_key, device)``.
        ``device=None`` and an explicit device are therefore distinct cache
        entries; pass a concrete device (the module-level ``load_model`` does
        this) to avoid loading the same model twice.

        Args:
            model_key: The model identifier (e.g., "deim" or "yolo:/path/to/model.pt")
            device: Device to load the model on (defaults to auto-detected
                device). Only forwarded to the DEIM model; ``YOLOModel`` is
                constructed from the weights path alone.

        Returns:
            Loaded model instance

        Raises:
            ValueError: If ``model_key`` is neither ``"deim"`` nor ``"yolo:..."``.
        """
        if device is None:
            device = _self.device

        if model_key == "deim":
            return DeimHgnetV2MDrone(device=device)
        if model_key.startswith("yolo:"):
            # split once so Windows drive letters ("C:\\...") survive
            model_path = model_key.split(":", 1)[1]
            # NOTE(review): ``device`` is not passed to YOLOModel -- confirm
            # whether it should be honoured for YOLO weights too.
            return YOLOModel(model_path)
        raise ValueError(f"Invalid model key: {model_key}")

    def get_device_info(self) -> Dict[str, str]:
        """Return the current device and, when CUDA is available, GPU details.

        All values are strings; memory figures are for device index 0 and are
        formatted in GB.
        """
        device_info = {
            "device": self.device,
            "cuda_available": str(torch.cuda.is_available()),
        }

        if torch.cuda.is_available():
            device_info.update(
                {
                    "cuda_device_count": str(torch.cuda.device_count()),
                    "cuda_device_name": torch.cuda.get_device_name(0),
                    "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / 1024**3:.2f} GB",
                    "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / 1024**3:.2f} GB",
                }
            )

        return device_info

    def render_device_info(self):
        """Render device information in Streamlit sidebar."""
        device_info = self.get_device_info()

        st.sidebar.header("Device Information")

        # Device status
        if device_info["device"] == "cuda":
            st.sidebar.success(f"🚀 Using GPU: {device_info['cuda_device_name']}")
            # Shown as "allocated / reserved"
            st.sidebar.info(
                f"Memory: {device_info['cuda_memory_allocated']} / {device_info['cuda_memory_reserved']}"
            )
        else:
            st.sidebar.warning("🖥️ Using CPU")

    def render_model_selection(self, key_prefix: str = "") -> Tuple[str, str]:
        """Render model selection UI in Streamlit sidebar.

        Args:
            key_prefix: Prefix for Streamlit widget keys to avoid conflicts

        Returns:
            Tuple of (model_label, model_key)
        """
        st.sidebar.subheader("Model Selection")

        available_models = self.get_available_models()
        model_label = st.sidebar.selectbox(
            "Model", available_models, index=0, key=f"{key_prefix}_model_select"
        )

        model_key = self.get_model_key(model_label)
        return model_label, model_key


# Global instance for easy access
model_manager = ModelManager()


def get_model_manager() -> ModelManager:
    """Get the global model manager instance."""
    return model_manager


def load_model(
    model_key: str, device: Optional[str] = None
) -> Union[DeimHgnetV2MDrone, YOLOModel]:
    """Convenience function to load a model.

    The device is resolved here, before reaching the cached method, so that
    ``load_model(key)`` and ``load_model(key, get_device())`` share a single
    cache entry instead of loading the model twice.
    """
    if device is None:
        device = model_manager.device
    return model_manager.load_model(model_key, device)


def get_device() -> str:
    """Get the current device."""
    return model_manager.device


def get_available_models() -> List[str]:
    """Get list of available models."""
    return model_manager.get_available_models()


def get_model_key(model_label: str) -> str:
    """Get model key from model label."""
    return model_manager.get_model_key(model_label)