Spaces:
Sleeping
Sleeping
File size: 5,880 Bytes
a65508e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
"""Centralized model selection and device management for SAR-X AI application.
This module provides a unified interface for model loading, device management,
and model selection across all pages in the application.
"""
from pathlib import Path
from typing import Dict, List, Tuple, Union
import streamlit as st
import torch
from deim_model import DeimHgnetV2MDrone
from yolo_model import YOLOModel
class ModelManager:
    """Centralized manager for compute-device detection and model selection.

    One instance resolves the torch device, discovers which models can be
    offered, renders the related Streamlit sidebar widgets, and loads model
    instances on demand (cached with ``st.cache_resource``).

    Attributes:
        device: ``"cuda"`` if ``torch.cuda.is_available()`` was true when the
            manager was constructed, otherwise ``"cpu"``.
        models_dir: The ``models`` directory located two levels above this file.
        model_entries: ``(label, key)`` pairs discovered at construction time.
    """

    def __init__(self):
        self.device = self._get_device()
        # Resolved relative to this file (not the CWD) so lookups work no
        # matter where the app is launched from.
        self.models_dir = Path(__file__).resolve().parent.parent / "models"
        self.model_entries = self._discover_model_entries()

    def _get_device(self) -> str:
        """Return ``"cuda"`` when a CUDA device is available, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _discover_model_entries(self) -> List[Tuple[str, str]]:
        """Build the selectable ``(label, key)`` model entries.

        The DEIM model is always listed with key ``"deim"``. The YOLOv8n base
        weights are added with key ``"yolo:<absolute path>"`` only when
        ``models/yolov8n.pt`` exists.
        """
        entries: List[Tuple[str, str]] = [("DEIM Model", "deim")]
        # Only the YOLOv8n base model is exposed; other weight files in
        # models_dir are not listed. (If models_dir itself is missing, this
        # existence check is simply false.)
        yolov8n_file = self.models_dir / "yolov8n.pt"
        if yolov8n_file.exists():
            entries.append(("YOLOv8n Model", f"yolo:{yolov8n_file.resolve()}"))
        return entries

    @staticmethod
    def _yolo_weights_exist(model_key: str) -> bool:
        """Return True if *model_key* is a ``yolo:<path>`` key whose file exists."""
        prefix = "yolo:"
        if not model_key.startswith(prefix):
            return False
        return Path(model_key[len(prefix):]).exists()

    def get_available_models(self) -> List[str]:
        """Return the labels of models that can currently be loaded.

        The DEIM entry is always included. YOLO entries are included only if
        their weight file still exists, since it may have been removed after
        discovery. Entries with any other key are skipped.

        Returns:
            Model labels in discovery order, or ``["DEIM Model"]`` if none
            qualify.
        """
        available_labels = [
            label
            for label, key in self.model_entries
            if key == "deim" or self._yolo_weights_exist(key)
        ]
        # Defensive fallback; normally unreachable because "deim" is always
        # among the discovered entries.
        return available_labels or ["DEIM Model"]

    def get_model_key(self, model_label: str) -> str:
        """Map a display label to its model key.

        Unknown labels fall back to ``"deim"``.
        """
        label_to_key: Dict[str, str] = dict(self.model_entries)
        return label_to_key.get(model_label, "deim")

    @st.cache_resource
    def load_model(
        _self, model_key: str, device: Union[str, None] = None
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Load a model, cached across Streamlit reruns.

        ``st.cache_resource`` keys the cache on the call arguments. The leading
        underscore on ``_self`` tells Streamlit not to hash the manager
        instance.

        Args:
            model_key: The model identifier, e.g. ``"deim"`` or
                ``"yolo:/path/to/model.pt"``.
            device: Device for the DEIM model; ``None`` uses ``_self.device``.

        Returns:
            The loaded model instance.

        Raises:
            ValueError: If ``model_key`` is neither ``"deim"`` nor a ``yolo:`` key.
        """
        if device is None:
            device = _self.device
        if model_key == "deim":
            return DeimHgnetV2MDrone(device=device)
        if model_key.startswith("yolo:"):
            model_path = model_key.split(":", 1)[1]
            # NOTE(review): ``device`` is not forwarded to YOLOModel; confirm
            # whether YOLOModel chooses its own device.
            return YOLOModel(model_path)
        raise ValueError(f"Invalid model key: {model_key}")

    def get_device_info(self) -> Dict[str, str]:
        """Return string-valued information about the compute device.

        Always contains ``device`` and ``cuda_available``. When CUDA is
        available, it also reports the device count, plus the name and the
        allocated/reserved memory (in units of 1024**3 bytes, labelled "GB")
        of CUDA device 0.
        """
        cuda_available = torch.cuda.is_available()
        device_info = {
            "device": self.device,
            "cuda_available": str(cuda_available),
        }
        if cuda_available:
            bytes_per_gb = 1024**3
            device_info.update(
                {
                    "cuda_device_count": str(torch.cuda.device_count()),
                    "cuda_device_name": torch.cuda.get_device_name(0),
                    "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / bytes_per_gb:.2f} GB",
                    "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / bytes_per_gb:.2f} GB",
                }
            )
        return device_info

    def render_device_info(self):
        """Render the current device status in the Streamlit sidebar."""
        device_info = self.get_device_info()
        st.sidebar.header("Device Information")
        if device_info["device"] == "cuda":
            st.sidebar.success(f"🚀 Using GPU: {device_info['cuda_device_name']}")
            # Displayed as "<allocated> / <reserved>" for CUDA device 0.
            st.sidebar.info(
                f"Memory: {device_info['cuda_memory_allocated']} / {device_info['cuda_memory_reserved']}"
            )
        else:
            st.sidebar.warning("🖥️ Using CPU")

    def render_model_selection(self, key_prefix: str = "") -> Tuple[str, str]:
        """Render the model selection UI in the Streamlit sidebar.

        Args:
            key_prefix: Prefix for the Streamlit widget key, so that several
                pages can each render their own selector without key clashes.

        Returns:
            Tuple of ``(model_label, model_key)`` for the current selection.
        """
        st.sidebar.subheader("Model Selection")
        available_models = self.get_available_models()
        model_label = st.sidebar.selectbox(
            "Model", available_models, index=0, key=f"{key_prefix}_model_select"
        )
        model_key = self.get_model_key(model_label)
        return model_label, model_key
# Shared module-level instance, created once at import time; the functions
# below are thin convenience wrappers around it.
model_manager = ModelManager()


def get_model_manager() -> ModelManager:
    """Return the shared module-level ``ModelManager`` instance."""
    return model_manager


def load_model(
    model_key: str, device: Union[str, None] = None
) -> Union[DeimHgnetV2MDrone, YOLOModel]:
    """Load a model through the shared manager.

    Args:
        model_key: The model identifier, e.g. ``"deim"`` or
            ``"yolo:/path/to/model.pt"``.
        device: Target device; ``None`` uses the manager's auto-detected device.

    Returns:
        The loaded model instance.
    """
    return model_manager.load_model(model_key, device)


def get_device() -> str:
    """Return the shared manager's device (``"cuda"`` or ``"cpu"``)."""
    return model_manager.device


def get_available_models() -> List[str]:
    """Return the labels of models the shared manager can currently load."""
    return model_manager.get_available_models()


def get_model_key(model_label: str) -> str:
    """Map a display label to its model key (unknown labels map to ``"deim"``)."""
    return model_manager.get_model_key(model_label)
|