"""Centralized model selection and device management for SAR-X AI application.

This module provides a unified interface for model loading, device management,
and model selection across all pages in the application.
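
Example:
    A minimal sketch of how a Streamlit page might use this module; the
    module import name and the key prefix below are illustrative:

        from model_manager import get_model_manager

        manager = get_model_manager()
        manager.render_device_info()
        model_label, model_key = manager.render_model_selection(key_prefix="detection")
        model = manager.load_model(model_key)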
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import streamlit as st
import torch
from deim_model import DeimHgnetV2MDrone
from yolo_model import YOLOModel


class ModelManager:
    """Centralized model manager for device and model selection."""

    def __init__(self):
        self.device = self._get_device()
        self.models_dir = Path(__file__).resolve().parent.parent / "models"
        self.model_entries = self._discover_model_entries()

    def _get_device(self) -> str:
        """Determine the best available device (CUDA or CPU)."""
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _discover_model_entries(self) -> List[Tuple[str, str]]:
        """Discover available models in the models directory."""
        entries: List[Tuple[str, str]] = [("DEIM Model", "deim")]

        if self.models_dir.exists():
            # Only add YOLOv8n base model
            yolov8n_file = self.models_dir / "yolov8n.pt"
            if yolov8n_file.exists():
                entries.append(("YOLOv8n Model", f"yolo:{yolov8n_file.resolve()}"))

        return entries

    def get_available_models(self) -> List[str]:
        """Get list of available model labels."""
        available_labels = []
        for label, key in self.model_entries:
            if key == "deim":
                available_labels.append(label)
                continue
            if not key.startswith("yolo:"):
                continue
            weight_path = Path(key.split(":", 1)[1])
            if weight_path.exists():
                available_labels.append(label)

        if not available_labels:
            available_labels = ["DEIM Model"]

        return available_labels

    def get_model_key(self, model_label: str) -> str:
        """Get model key from model label."""
        label_to_key: Dict[str, str] = {label: key for label, key in self.model_entries}
        return label_to_key.get(model_label, "deim")

    @st.cache_resource
    def load_model(
        _self, model_key: str, device: Optional[str] = None
    ) -> Union[DeimHgnetV2MDrone, YOLOModel]:
        """Load a model with caching for better performance.

        Args:
            model_key: The model identifier (e.g., "deim" or "yolo:/path/to/model.pt")
            device: Device to load the model on (defaults to auto-detected device)

        Returns:
            Loaded model instance
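
        Example:
            Illustrative calls; the YOLO weight path below is a placeholder,
            not an actual file shipped with this module:

                deim_model = model_manager.load_model("deim")
                yolo_model = model_manager.load_model("yolo:/path/to/yolov8n.pt")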

        """
        if device is None:
            device = _self.device

        if model_key == "deim":
            return DeimHgnetV2MDrone(device=device)
        elif model_key.startswith("yolo:"):
            model_path = model_key.split(":", 1)[1]
            return YOLOModel(model_path)
        else:
            raise ValueError(f"Invalid model key: {model_key}")

    def get_device_info(self) -> Dict[str, str]:
        """Get information about the current device."""
        device_info = {
            "device": self.device,
            "cuda_available": str(torch.cuda.is_available()),
        }

        if torch.cuda.is_available():
            device_info.update(
                {
                    "cuda_device_count": str(torch.cuda.device_count()),
                    "cuda_device_name": torch.cuda.get_device_name(0),
                    "cuda_memory_allocated": f"{torch.cuda.memory_allocated(0) / 1024**3:.2f} GB",
                    "cuda_memory_reserved": f"{torch.cuda.memory_reserved(0) / 1024**3:.2f} GB",
                }
            )

        return device_info

    def render_device_info(self):
        """Render device information in Streamlit sidebar."""
        device_info = self.get_device_info()

        st.sidebar.header("Device Information")

        # Device status
        if device_info["device"] == "cuda":
            st.sidebar.success(f"πŸš€ Using GPU: {device_info['cuda_device_name']}")
            st.sidebar.info(
                f"Memory (allocated / reserved): "
                f"{device_info['cuda_memory_allocated']} / {device_info['cuda_memory_reserved']}"
            )
        else:
            st.sidebar.warning("πŸ–₯️ Using CPU")

        # # Show device details in expander
        # with st.sidebar.expander("Device Details"):
        #     for key, value in device_info.items():
        #         st.text(f"{key}: {value}")

    def render_model_selection(self, key_prefix: str = "") -> Tuple[str, str]:
        """Render model selection UI in Streamlit sidebar.

        Args:
            key_prefix: Prefix for Streamlit widget keys to avoid conflicts

        Returns:
            Tuple of (model_label, model_key)
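
        Example:
            Typical page-side usage; the key prefix is illustrative:

                label, key = model_manager.render_model_selection(key_prefix="live")
                model = model_manager.load_model(key)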

        """
        st.sidebar.subheader("Model Selection")

        available_models = self.get_available_models()
        model_label = st.sidebar.selectbox(
            "Model", available_models, index=0, key=f"{key_prefix}_model_select"
        )
        model_key = self.get_model_key(model_label)

        return model_label, model_key


# Global instance for easy access
model_manager = ModelManager()


def get_model_manager() -> ModelManager:
    """Get the global model manager instance."""
    return model_manager


def load_model(
    model_key: str, device: Optional[str] = None
) -> Union[DeimHgnetV2MDrone, YOLOModel]:
    """Convenience function to load a model."""
    return model_manager.load_model(model_key, device)


def get_device() -> str:
    """Get the current device."""
    return model_manager.device


def get_available_models() -> List[str]:
    """Get list of available models."""
    return model_manager.get_available_models()


def get_model_key(model_label: str) -> str:
    """Get model key from model label."""
    return model_manager.get_model_key(model_label)