lucid-hf's picture
CI: deploy Docker/PDM Space
a65508e verified
"""Centralized model selection and device management for SAR-X AI application.
This module provides a unified interface for model loading, device management,
and model selection across all pages in the application.
"""
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import streamlit as st
import torch

from deim_model import DeimHgnetV2MDrone
from yolo_model import YOLOModel
class ModelManager:
    """Centralized model manager for device and model selection.

    On construction it picks a compute device and scans the ``models``
    directory (located in the parent of this file's directory) for selectable
    models. Model keys are either ``"deim"`` or ``"yolo:<absolute weight path>"``.
    """

    def __init__(self):
        # "cuda" when torch reports an available CUDA device, otherwise "cpu".
        self.device = self._get_device()
        # Resolved relative to this source file, not the working directory.
        self.models_dir = Path(__file__).resolve().parent.parent / "models"
        # (label, key) pairs offered in the model selector.
        self.model_entries = self._discover_model_entries()

    def _get_device(self) -> str:
        """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _discover_model_entries(self) -> List[Tuple[str, str]]:
        """Build the (label, key) list of selectable models.

        The DEIM model is always listed. The YOLOv8n base model is listed only
        when ``yolov8n.pt`` exists in ``self.models_dir``; its key embeds the
        resolved absolute weight path after a ``yolo:`` prefix.
        """
        entries: List[Tuple[str, str]] = [("DEIM Model", "deim")]
        # Only the YOLOv8n base model is exposed; other weight files are ignored.
        # (A missing models_dir simply means the file does not exist.)
        yolov8n_file = self.models_dir / "yolov8n.pt"
        if yolov8n_file.exists():
            entries.append(("YOLOv8n Model", f"yolo:{yolov8n_file.resolve()}"))
        return entries

    def get_available_models(self) -> List[str]:
        """Return labels of models that can currently be loaded.

        DEIM is always available. YOLO entries are included only if their
        weight file still exists on disk, because it may have been removed
        after discovery. Keys with any other prefix are skipped.
        """
        available_labels: List[str] = []
        for label, key in self.model_entries:
            if key == "deim":
                available_labels.append(label)
            elif key.startswith("yolo:") and Path(key.split(":", 1)[1]).exists():
                available_labels.append(label)
        # Defensive fallback so the selector is never empty.
        return available_labels or ["DEIM Model"]

    def get_model_key(self, model_label: str) -> str:
        """Map a selector label to its model key; unknown labels map to "deim"."""
        return dict(self.model_entries).get(model_label, "deim")

    def load_model(
        self, model_key: str, device: Optional[str] = None
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Load a model, reusing a cached instance when possible.

        Args:
            model_key: The model identifier (e.g. "deim" or "yolo:/path/to/model.pt").
            device: Device to load the model on; ``None`` means the
                auto-detected ``self.device``.

        Returns:
            Loaded model instance.

        Raises:
            ValueError: If ``model_key`` is not a recognised identifier.
        """
        # Resolve the default *before* reaching the cache. Otherwise
        # load_model(key) and load_model(key, self.device) would be two
        # distinct cache entries and load the same weights twice.
        resolved_device = self.device if device is None else device
        return self._load_model_cached(model_key, resolved_device)

    @staticmethod
    @st.cache_resource
    def _load_model_cached(
        model_key: str, device: str
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Construct the model for ``model_key`` on ``device``.

        Wrapped by ``st.cache_resource``. Streamlit memoizes the result keyed on
        the arguments ``(model_key, device)``.
        """
        if model_key == "deim":
            return DeimHgnetV2MDrone(device=device)
        if model_key.startswith("yolo:"):
            # maxsplit=1 keeps any ':' inside the path (e.g. Windows drive letters).
            model_path = model_key.split(":", 1)[1]
            # NOTE(review): ``device`` is not forwarded to YOLOModel here;
            # confirm whether YOLOModel selects its own device.
            return YOLOModel(model_path)
        raise ValueError(f"Invalid model key: {model_key}")

    def get_device_info(self) -> Dict[str, str]:
        """Return string-valued information about the current device.

        Always contains "device" and "cuda_available". When CUDA is available it
        also contains the device count, the name of device 0, and the memory
        allocated/reserved on device 0 formatted in GB.
        """
        cuda_available = torch.cuda.is_available()
        device_info = {
            "device": self.device,
            "cuda_available": str(cuda_available),
        }
        if cuda_available:
            gib = 1024**3
            device_info.update(
                {
                    "cuda_device_count": str(torch.cuda.device_count()),
                    "cuda_device_name": torch.cuda.get_device_name(0),
                    "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / gib:.2f} GB",
                    "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / gib:.2f} GB",
                }
            )
        return device_info

    def render_device_info(self):
        """Render the device status in the Streamlit sidebar.

        On GPU it shows the device name and the allocated/reserved memory.
        On CPU it shows a warning.
        """
        device_info = self.get_device_info()
        st.sidebar.header("Device Information")
        if device_info["device"] == "cuda":
            st.sidebar.success(f"🚀 Using GPU: {device_info['cuda_device_name']}")
            st.sidebar.info(
                f"Memory: {device_info['cuda_memory_allocated']} / {device_info['cuda_memory_reserved']}"
            )
        else:
            st.sidebar.warning("🖥️ Using CPU")

    def render_model_selection(self, key_prefix: str = "") -> Tuple[str, str]:
        """Render the model selection UI in the Streamlit sidebar.

        Args:
            key_prefix: Prefix for the Streamlit widget key, so several pages
                can each host their own selector without key conflicts.

        Returns:
            Tuple of (model_label, model_key).
        """
        st.sidebar.subheader("Model Selection")
        available_models = self.get_available_models()
        model_label = st.sidebar.selectbox(
            "Model", available_models, index=0, key=f"{key_prefix}_model_select"
        )
        model_key = self.get_model_key(model_label)
        return model_label, model_key
# Module-level instance created at import time; the helper functions below
# delegate to it.
model_manager = ModelManager()


def get_model_manager() -> ModelManager:
    """Return the module-level ``ModelManager`` instance."""
    manager = model_manager
    return manager
def load_model(
    model_key: str, device: Optional[str] = None
) -> Union[DeimHgnetV2MDrone, YOLOModel]:
    """Load a model through the module-level ``model_manager``.

    Args:
        model_key: Model identifier, e.g. ``"deim"`` or ``"yolo:/path/to/model.pt"``.
        device: Target device; ``None`` uses the manager's auto-detected device.

    Returns:
        The model instance returned by ``ModelManager.load_model``.
    """
    return model_manager.load_model(model_key, device)
def get_device() -> str:
    """Return the device string chosen by the module-level manager."""
    current_device = model_manager.device
    return current_device
def get_available_models() -> List[str]:
    """Return the labels of currently loadable models from the module-level manager."""
    labels = model_manager.get_available_models()
    return labels
def get_model_key(model_label: str) -> str:
    """Map a selector label to its model key via the module-level manager."""
    resolved_key = model_manager.get_model_key(model_label)
    return resolved_key