| """ | |
| Multi-Modal Knowledge Distillation Web Application | |
| A FastAPI-based web application for creating new AI models through knowledge distillation | |
| from multiple pre-trained models across different modalities. | |
| """ | |
| import os | |
| import asyncio | |
| import logging | |
| import uuid | |
| from typing import List, Dict, Any, Optional, Union | |
| from pathlib import Path | |
| import json | |
| import shutil | |
| from datetime import datetime | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| from src.model_loader import ModelLoader | |
| from src.distillation import KnowledgeDistillationTrainer | |
| from src.utils import setup_logging, validate_file, cleanup_temp_files, get_system_info | |
| # Import new core components | |
| from src.core.memory_manager import AdvancedMemoryManager | |
| from src.core.chunk_loader import AdvancedChunkLoader | |
| from src.core.cpu_optimizer import CPUOptimizer | |
| from src.core.token_manager import TokenManager | |
| # Import medical components | |
| from src.medical.medical_datasets import MedicalDatasetManager | |
| from src.medical.dicom_handler import DicomHandler | |
| from src.medical.medical_preprocessing import MedicalPreprocessor | |
| # Import database components | |
| from database.database import DatabaseManager | |
| # Setup logging with error handling | |
| try: | |
| setup_logging() | |
| logger = logging.getLogger(__name__) | |
| except Exception as e: | |
| # Fallback to basic logging if setup fails | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| logger.warning(f"Failed to setup advanced logging: {e}") | |
| # Initialize FastAPI app | |
| app = FastAPI( | |
| title="Multi-Modal Knowledge Distillation", | |
| description="Create new AI models through knowledge distillation from multiple pre-trained models", | |
| version="2.1.0", | |
| docs_url="/docs", | |
| redoc_url="/redoc" | |
| ) | |
| # Add CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Mount static files and templates | |
| app.mount("/static", StaticFiles(directory="static"), name="static") | |
| templates = Jinja2Templates(directory="templates") | |
| # Global variables for tracking training sessions | |
| training_sessions: Dict[str, Dict[str, Any]] = {} | |
| active_connections: Dict[str, WebSocket] = {} | |
| def serialize_session_for_websocket(session_data: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Clean session data for WebSocket JSON serialization | |
| Converts Path objects and other non-serializable types to strings | |
| """ | |
| cleaned_data = {} | |
| for key, value in session_data.items(): | |
| try: | |
| if isinstance(value, Path): | |
| # Convert Path objects to strings | |
| cleaned_data[key] = str(value) | |
| elif isinstance(value, (list, tuple)): | |
| # Clean lists/tuples recursively | |
| cleaned_data[key] = [ | |
| str(item) if isinstance(item, Path) else | |
| serialize_session_for_websocket(item) if isinstance(item, dict) else | |
| item for item in value | |
| ] | |
| elif isinstance(value, dict): | |
| # Clean nested dictionaries recursively | |
| cleaned_data[key] = serialize_session_for_websocket(value) | |
| elif hasattr(value, '__dict__') and not isinstance(value, (str, int, float, bool, type(None))): | |
| # Convert complex objects to string representation | |
| cleaned_data[key] = str(value) | |
| else: | |
| # Keep simple types as-is | |
| cleaned_data[key] = value | |
| except Exception as e: | |
| # If anything fails, convert to string | |
| logger.warning(f"Error serializing session key '{key}': {e}") | |
| cleaned_data[key] = str(value) if value is not None else None | |
| return cleaned_data | |
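| # Usage sketch (illustrative only): given a session dict that mixes JSON-safe values | |
| # with Path objects and nested dicts, the helper returns a structure that | |
| # send_json / json.dumps can handle. The example values below are hypothetical: | |
| #   session = {"status": "completed", "model_path": Path("models/distilled_model_abc"), | |
| #              "config": {"output_dir": Path("models")}} | |
| #   serialize_session_for_websocket(session) | |
| #   # -> {"status": "completed", "model_path": "models/distilled_model_abc", | |
| #   #     "config": {"output_dir": "models"}} | |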
| # Pydantic models for API | |
| class TrainingConfig(BaseModel): | |
| session_id: str = Field(..., description="Unique session identifier") | |
| teacher_models: List[Union[str, Dict[str, Any]]] = Field(..., description="List of teacher model paths/URLs or model configs") | |
| student_config: Dict[str, Any] = Field(default_factory=dict, description="Student model configuration") | |
| training_params: Dict[str, Any] = Field(default_factory=dict, description="Training parameters") | |
| distillation_strategy: str = Field(default="ensemble", description="Distillation strategy") | |
| hf_token: Optional[str] = Field(default=None, description="Hugging Face token") | |
| trust_remote_code: bool = Field(default=False, description="Trust remote code execution") | |
| existing_student_model: Optional[str] = Field(default=None, description="Path to existing trained student model for retraining") | |
| incremental_training: bool = Field(default=False, description="Whether this is incremental training") | |
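| # Minimal request payload sketch for TrainingConfig (field values are hypothetical; | |
| # teacher_models accepts either plain repo strings or per-model config dicts): | |
| #   { | |
| #     "session_id": "demo-session-001", | |
| #     "teacher_models": ["bert-base-uncased", {"path": "google/flan-t5-base", "trust_remote_code": false}], | |
| #     "student_config": {}, | |
| #     "training_params": {"max_steps": 1000, "learning_rate": 5e-5}, | |
| #     "distillation_strategy": "ensemble" | |
| #   } | |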
| class TrainingStatus(BaseModel): | |
| session_id: str | |
| status: str | |
| progress: float | |
| current_step: int | |
| total_steps: int | |
| loss: Optional[float] = None | |
| eta: Optional[str] = None | |
| message: str = "" | |
| class ModelInfo(BaseModel): | |
| name: str | |
| size: int | |
| format: str | |
| modality: str | |
| architecture: Optional[str] = None | |
| # Initialize components | |
| model_loader = ModelLoader() | |
| distillation_trainer = KnowledgeDistillationTrainer() | |
| # Initialize new advanced components | |
| memory_manager = AdvancedMemoryManager(max_memory_gb=14.0) # 14GB for 16GB systems | |
| chunk_loader = AdvancedChunkLoader(memory_manager) | |
| cpu_optimizer = CPUOptimizer(memory_manager) | |
| token_manager = TokenManager() | |
| database_manager = DatabaseManager() | |
| # Initialize medical components | |
| medical_dataset_manager = MedicalDatasetManager(memory_manager) | |
| dicom_handler = DicomHandler(memory_limit_mb=1000.0) | |
| medical_preprocessor = MedicalPreprocessor() | |
| async def startup_event(): | |
| """Initialize application on startup""" | |
| logger.info("Starting Multi-Modal Knowledge Distillation application") | |
| # Create necessary directories with error handling | |
| for directory in ["uploads", "models", "temp", "logs"]: | |
| try: | |
| Path(directory).mkdir(exist_ok=True) | |
| logger.info(f"Created/verified directory: {directory}") | |
| except PermissionError: | |
| logger.warning(f"Cannot create directory {directory}, using temp directory") | |
| except Exception as e: | |
| logger.warning(f"Error creating directory {directory}: {e}") | |
| # Log system information | |
| try: | |
| system_info = get_system_info() | |
| logger.info(f"System info: {system_info}") | |
| except Exception as e: | |
| logger.warning(f"Could not get system info: {e}") | |
| async def shutdown_event(): | |
| """Cleanup on application shutdown""" | |
| logger.info("Shutting down application") | |
| cleanup_temp_files() | |
| async def read_root(request: Request): | |
| """Serve the main web interface""" | |
| return templates.TemplateResponse("index.html", {"request": request}) | |
| async def health_check(): | |
| """Health check endpoint for Docker and monitoring""" | |
| try: | |
| # Get system information | |
| memory_info = memory_manager.get_memory_info() | |
| # Check if default token is available | |
| default_token = token_manager.get_token() | |
| return { | |
| "status": "healthy", | |
| "version": "2.0.0", | |
| "timestamp": datetime.now().isoformat(), | |
| "memory": { | |
| "usage_percent": memory_info.get("process_memory_percent", 0), | |
| "available_gb": memory_info.get("system_memory_available_gb", 0), | |
| "status": memory_manager.check_memory_status() | |
| }, | |
| "tokens": { | |
| "default_available": bool(default_token), | |
| "total_tokens": len(token_manager.list_tokens()) | |
| }, | |
| "features": { | |
| "memory_management": True, | |
| "chunk_loading": True, | |
| "cpu_optimization": True, | |
| "medical_datasets": True, | |
| "token_management": True | |
| }, | |
| "system_info": get_system_info() | |
| } | |
| except Exception as e: | |
| logger.error(f"Health check failed: {e}") | |
| return { | |
| "status": "unhealthy", | |
| "error": str(e), | |
| "timestamp": datetime.now().isoformat(), | |
| "version": "2.0.0" | |
| } | |
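| # Probe sketch (the route path depends on how this handler is registered; | |
| # "/health" is only an assumption, not taken from this file): | |
| #   import requests | |
| #   print(requests.get("http://localhost:8000/health", timeout=5).json()["status"]) | |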
| async def test_token(): | |
| """Test if HF token is working""" | |
| hf_token = ( | |
| os.getenv('HF_TOKEN') or | |
| os.getenv('HUGGINGFACE_TOKEN') or | |
| os.getenv('HUGGINGFACE_HUB_TOKEN') | |
| ) | |
| if not hf_token: | |
| return { | |
| "token_available": False, | |
| "message": "No HF token found in environment variables" | |
| } | |
| try: | |
| # Test token by trying to access a gated model's config | |
| from transformers import AutoConfig | |
| config = AutoConfig.from_pretrained("google/gemma-2b", token=hf_token) | |
| return { | |
| "token_available": True, | |
| "token_valid": True, | |
| "message": "Token is working correctly" | |
| } | |
| except Exception as e: | |
| return { | |
| "token_available": True, | |
| "token_valid": False, | |
| "message": f"Token validation failed: {str(e)}" | |
| } | |
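| # Equivalent standalone check sketch using huggingface_hub (assumes the package | |
| # is installed; whoami() raises on an invalid token): | |
| #   from huggingface_hub import HfApi | |
| #   try: | |
| #       user = HfApi(token=hf_token).whoami() | |
| #       print("token belongs to", user.get("name")) | |
| #   except Exception as err: | |
| #       print("token rejected:", err) | |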
| async def test_model_loading(request: Dict[str, Any]): | |
| """Test loading a specific model""" | |
| try: | |
| model_path = request.get('model_path') | |
| trust_remote_code = request.get('trust_remote_code', False) | |
| if not model_path: | |
| return {"success": False, "error": "model_path is required"} | |
| # Get appropriate token based on access type | |
| access_type = request.get('access_type', 'read') | |
| hf_token = request.get('token') | |
| if not hf_token or hf_token == 'auto': | |
| # Get appropriate token for the access type | |
| hf_token = token_manager.get_token_for_task(access_type) | |
| if hf_token: | |
| logger.info(f"Using {access_type} token for model testing") | |
| else: | |
| logger.warning(f"No suitable token found for {access_type} access") | |
| # Fallback to environment variables | |
| hf_token = ( | |
| os.getenv('HF_TOKEN') or | |
| os.getenv('HUGGINGFACE_TOKEN') or | |
| os.getenv('HUGGINGFACE_HUB_TOKEN') | |
| ) | |
| # Test model loading | |
| model_info = await model_loader.get_model_info(model_path) | |
| return { | |
| "success": True, | |
| "model_info": model_info, | |
| "message": f"Model {model_path} can be loaded" | |
| } | |
| except Exception as e: | |
| error_msg = str(e) | |
| suggestions = [] | |
| if 'trust_remote_code' in error_msg.lower(): | |
| suggestions.append("فعّل 'Trust Remote Code' للنماذج التي تتطلب كود مخصص") | |
| elif 'gated' in error_msg.lower(): | |
| suggestions.append("النموذج يتطلب إذن وصول خاص - استخدم رمز مخصص") | |
| elif 'siglip' in error_msg.lower(): | |
| suggestions.append("جرب تفعيل 'Trust Remote Code' لنماذج SigLIP") | |
| elif '401' in error_msg or 'authentication' in error_msg.lower(): | |
| suggestions.append("تحقق من رمز Hugging Face الخاص بك") | |
| suggestions.append("تأكد من أن الرمز له صلاحية الوصول لهذا النموذج") | |
| elif '404' in error_msg or 'not found' in error_msg.lower(): | |
| suggestions.append("تحقق من اسم مستودع النموذج") | |
| suggestions.append("تأكد من وجود النموذج على Hugging Face") | |
| return { | |
| "success": False, | |
| "error": error_msg, | |
| "suggestions": suggestions | |
| } | |
| async def upload_model( | |
| background_tasks: BackgroundTasks, | |
| files: List[UploadFile] = File(...), | |
| model_names: List[str] = Form(...) | |
| ): | |
| """Upload model files""" | |
| try: | |
| uploaded_models = [] | |
| for file, name in zip(files, model_names): | |
| # Validate file | |
| validation_result = validate_file(file) | |
| if not validation_result["valid"]: | |
| raise HTTPException(status_code=400, detail=validation_result["error"]) | |
| # Generate unique filename | |
| file_id = str(uuid.uuid4()) | |
| file_extension = Path(file.filename).suffix | |
| safe_filename = f"{file_id}{file_extension}" | |
| file_path = Path("uploads") / safe_filename | |
| # Save file | |
| with open(file_path, "wb") as buffer: | |
| content = await file.read() | |
| buffer.write(content) | |
| # Get model info | |
| model_info = await model_loader.get_model_info(str(file_path)) | |
| uploaded_models.append({ | |
| "id": file_id, | |
| "name": name, | |
| "filename": file.filename, | |
| "path": str(file_path), | |
| "size": len(content), | |
| "info": model_info | |
| }) | |
| logger.info(f"Uploaded model: {name} ({file.filename})") | |
| # Schedule cleanup of old files | |
| background_tasks.add_task(cleanup_temp_files, max_age_hours=24) | |
| return { | |
| "success": True, | |
| "models": uploaded_models, | |
| "message": f"Successfully uploaded {len(uploaded_models)} model(s)" | |
| } | |
| except Exception as e: | |
| logger.error(f"Error uploading models: {str(e)}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
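| # Client-side sketch for this handler (the "/api/upload" mount path is an assumption, | |
| # not taken from this file; file and model names are illustrative): | |
| #   import requests | |
| #   files = [("files", open("model.safetensors", "rb"))] | |
| #   data = [("model_names", "my-teacher")] | |
| #   resp = requests.post("http://localhost:8000/api/upload", files=files, data=data) | |
| #   print(resp.json()["message"]) | |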
| async def start_training( | |
| background_tasks: BackgroundTasks, | |
| config: TrainingConfig | |
| ): | |
| """Start knowledge distillation training""" | |
| try: | |
| session_id = config.session_id | |
| # Handle existing sessions intelligently | |
| if session_id in training_sessions: | |
| existing_session = training_sessions[session_id] | |
| status = existing_session.get("status", "unknown") | |
| # If session is completed or failed, allow reuse by cleaning it up | |
| if status in ["completed", "failed"]: | |
| logger.info(f"Cleaning up previous session {session_id} with status: {status}") | |
| del training_sessions[session_id] | |
| # Also clean up WebSocket connection if exists | |
| if session_id in active_connections: | |
| try: | |
| await active_connections[session_id].close() | |
| except Exception: | |
| pass | |
| del active_connections[session_id] | |
| else: | |
| # Session is still active | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Training session already exists with status: {status}. Please wait for completion or use a different session ID." | |
| ) | |
| # Set HF token from environment if available | |
| hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN') | |
| if hf_token: | |
| os.environ['HF_TOKEN'] = hf_token | |
| logger.info("Using Hugging Face token from environment") | |
| # Check for large models and warn | |
| large_models = [] | |
| for model_info in config.teacher_models: | |
| model_path = model_info if isinstance(model_info, str) else model_info.get('path', '') | |
| if any(size_indicator in model_path.lower() for size_indicator in ['27b', '70b', '13b']): | |
| large_models.append(model_path) | |
| # Initialize training session | |
| training_sessions[session_id] = { | |
| "status": "initializing", | |
| "progress": 0.0, | |
| "current_step": 0, | |
| "total_steps": config.training_params.get("max_steps", 1000), | |
| "config": config.dict(), | |
| "start_time": None, | |
| "end_time": None, | |
| "model_path": None, | |
| "logs": [], | |
| "large_models": large_models, | |
| "message": "Initializing training session..." + ( | |
| f" (Large models detected: {', '.join(large_models)})" if large_models else "" | |
| ) | |
| } | |
| # Start training in background | |
| background_tasks.add_task(run_training, session_id, config) | |
| logger.info(f"Started training session: {session_id}") | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "message": "Training started successfully" | |
| } | |
| except Exception as e: | |
| logger.error(f"Error starting training: {str(e)}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def run_training(session_id: str, config: TrainingConfig): | |
| """Run knowledge distillation training in background""" | |
| try: | |
| session = training_sessions[session_id] | |
| session["status"] = "running" | |
| session["start_time"] = asyncio.get_event_loop().time() | |
| # Set timeout for the entire operation (30 minutes) | |
| timeout_seconds = 30 * 60 | |
| # Set HF token for this session - prioritize config token | |
| config_token = getattr(config, 'hf_token', None) | |
| env_token = ( | |
| os.getenv('HF_TOKEN') or | |
| os.getenv('HUGGINGFACE_TOKEN') or | |
| os.getenv('HUGGINGFACE_HUB_TOKEN') | |
| ) | |
| hf_token = config_token or env_token | |
| if hf_token: | |
| logger.info(f"Using Hugging Face token from {'config' if config_token else 'environment'}") | |
| # Set token in environment for this session | |
| os.environ['HF_TOKEN'] = hf_token | |
| else: | |
| logger.warning("No Hugging Face token found - private models may fail") | |
| # Handle existing student model for incremental training | |
| existing_student = None | |
| if config.existing_student_model and config.incremental_training: | |
| try: | |
| await update_training_status(session_id, "loading_student", 0.05, "Loading existing student model...") | |
| # Determine student source and load accordingly | |
| student_source = getattr(config, 'student_source', 'local') | |
| student_path = config.existing_student_model | |
| if student_source == 'huggingface' or ('/' in student_path and not Path(student_path).exists()): | |
| logger.info(f"Loading student model from Hugging Face: {student_path}") | |
| existing_student = await model_loader.load_trained_student(student_path) | |
| elif student_source == 'space': | |
| logger.info(f"Loading student model from Hugging Face Space: {student_path}") | |
| # For spaces, we'll try to load from the space's models directory | |
| space_model_path = f"spaces/{student_path}/models" | |
| existing_student = await model_loader.load_trained_student_from_space(student_path) | |
| else: | |
| logger.info(f"Loading student model from local path: {student_path}") | |
| existing_student = await model_loader.load_trained_student(student_path) | |
| logger.info(f"Successfully loaded existing student model: {existing_student.get('type', 'unknown')}") | |
| # Merge original teachers with new teachers | |
| original_teachers = existing_student.get('original_teachers', []) | |
| new_teachers = [ | |
| model_info if isinstance(model_info, str) else model_info.get('path', '') | |
| for model_info in config.teacher_models | |
| ] | |
| # Combine teachers (avoid duplicates) | |
| all_teachers = original_teachers.copy() | |
| for teacher in new_teachers: | |
| if teacher not in all_teachers: | |
| all_teachers.append(teacher) | |
| logger.info(f"Incremental training: Original teachers: {original_teachers}") | |
| logger.info(f"Incremental training: New teachers: {new_teachers}") | |
| logger.info(f"Incremental training: All teachers: {all_teachers}") | |
| # Update config with all teachers | |
| config.teacher_models = all_teachers | |
| except Exception as e: | |
| logger.error(f"Error loading existing student model: {e}") | |
| await update_training_status(session_id, "failed", session.get("progress", 0), f"Failed to load existing student: {str(e)}") | |
| return | |
| # Load teacher models | |
| await update_training_status(session_id, "loading_models", 0.1, "Loading teacher models...") | |
| teacher_models = [] | |
| trust_remote_code = config.training_params.get('trust_remote_code', False) | |
| total_models = len(config.teacher_models) | |
| for i, model_info in enumerate(config.teacher_models): | |
| try: | |
| # Handle both old format (string) and new format (dict) | |
| if isinstance(model_info, str): | |
| model_path = model_info | |
| model_token = hf_token | |
| model_trust_code = trust_remote_code | |
| else: | |
| model_path = model_info.get('path', model_info) | |
| model_token = model_info.get('token') or hf_token | |
| model_trust_code = model_info.get('trust_remote_code', trust_remote_code) | |
| # Update progress | |
| progress = 0.1 + (i * 0.3 / total_models) # 0.1 to 0.4 | |
| await update_training_status( | |
| session_id, | |
| "loading_models", | |
| progress, | |
| f"Loading model {i+1}/{total_models}: {model_path}..." | |
| ) | |
| logger.info(f"Loading model {model_path} with trust_remote_code={model_trust_code}") | |
| # Special handling for known problematic models | |
| if model_path == 'Wan-AI/Wan2.2-TI2V-5B': | |
| logger.info(f"Detected ti2v model {model_path}, forcing trust_remote_code=True") | |
| model_trust_code = True | |
| elif model_path == 'deepseek-ai/DeepSeek-V3.1-Base': | |
| logger.warning(f"Skipping {model_path}: Requires GPU with FP8 quantization support") | |
| await update_training_status( | |
| session_id, | |
| "loading_models", | |
| progress, | |
| f"Skipping {model_path}: Requires GPU with FP8 quantization" | |
| ) | |
| continue | |
| model = await model_loader.load_model( | |
| model_path, | |
| token=model_token, | |
| trust_remote_code=model_trust_code | |
| ) | |
| teacher_models.append(model) | |
| logger.info(f"Successfully loaded model: {model_path}") | |
| # Update progress after successful load | |
| progress = 0.1 + ((i + 1) * 0.3 / total_models) | |
| await update_training_status( | |
| session_id, | |
| "loading_models", | |
| progress, | |
| f"Loaded {i+1}/{total_models} models successfully" | |
| ) | |
| except Exception as e: | |
| error_msg = f"Failed to load model {model_path}: {str(e)}" | |
| logger.error(error_msg) | |
| # Provide helpful suggestions based on the error | |
| suggestions = [] | |
| error_str = str(e).lower() | |
| # Check if we should retry with trust_remote_code=True | |
| if not model_trust_code and ('ti2v' in error_str or 'does not recognize this architecture' in error_str): | |
| try: | |
| logger.info(f"Retrying {model_path} with trust_remote_code=True") | |
| await update_training_status( | |
| session_id, | |
| "loading_models", | |
| progress, | |
| f"Retrying {model_path} with trust_remote_code=True..." | |
| ) | |
| model = await model_loader.load_model( | |
| model_path, | |
| token=model_token, | |
| trust_remote_code=True | |
| ) | |
| teacher_models.append(model) | |
| logger.info(f"Successfully loaded model on retry: {model_path}") | |
| # Update progress after successful retry | |
| progress = 0.1 + ((i + 1) * 0.3 / total_models) | |
| await update_training_status( | |
| session_id, | |
| "loading_models", | |
| progress, | |
| f"Loaded {i+1}/{total_models} models successfully (retry)" | |
| ) | |
| continue | |
| except Exception as retry_e: | |
| logger.error(f"Retry also failed for {model_path}: {str(retry_e)}") | |
| error_msg = f"Failed even with trust_remote_code=True: {str(retry_e)}" | |
| if 'trust_remote_code' in error_str: | |
| suggestions.append("Try enabling 'Trust Remote Code' option") | |
| elif 'gated' in error_str or 'access' in error_str: | |
| suggestions.append("This model requires access permission and a valid HF token") | |
| elif 'siglip' in error_str or 'unknown' in error_str: | |
| suggestions.append("This model may require special loading. Try enabling 'Trust Remote Code'") | |
| elif 'connection' in error_str or 'network' in error_str: | |
| suggestions.append("Check your internet connection") | |
| elif 'ti2v' in error_str: | |
| suggestions.append("This ti2v model requires trust_remote_code=True") | |
| if suggestions: | |
| error_msg += f". Suggestions: {'; '.join(suggestions)}" | |
| await update_training_status(session_id, "failed", session.get("progress", 0), error_msg) | |
| return | |
| # Initialize student model | |
| await update_training_status(session_id, "initializing_student", 0.2, "Initializing student model...") | |
| student_model = await distillation_trainer.create_student_model( | |
| teacher_models, config.student_config | |
| ) | |
| # Run distillation training | |
| await update_training_status(session_id, "training", 0.3, "Starting knowledge distillation...") | |
| async def progress_callback(step: int, total_steps: int, loss: float, metrics: Dict[str, Any]): | |
| progress = 0.3 + (step / total_steps) * 0.6 # 30% to 90% | |
| await update_training_status( | |
| session_id, "training", progress, | |
| f"Training step {step}/{total_steps}, Loss: {loss:.4f}", | |
| current_step=step, loss=loss | |
| ) | |
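| # Progress mapping sketch: with total_steps=1000, step=500 gives | |
| # 0.3 + (500 / 1000) * 0.6 = 0.6, i.e. 60% overall progress mid-training. | |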
| trained_model = await distillation_trainer.train( | |
| student_model, teacher_models, config.training_params, progress_callback | |
| ) | |
| # Save trained model with metadata | |
| await update_training_status(session_id, "saving", 0.9, "Saving trained model...") | |
| # Create model directory with proper structure | |
| model_dir = Path("models") / f"distilled_model_{session_id}" | |
| model_dir.mkdir(parents=True, exist_ok=True) | |
| model_path = model_dir / "pytorch_model.safetensors" | |
| # Prepare training metadata for saving | |
| training_metadata = { | |
| 'session_id': session_id, | |
| 'teacher_models': [ | |
| model_info if isinstance(model_info, str) else model_info.get('path', '') | |
| for model_info in config.teacher_models | |
| ], | |
| 'strategy': config.distillation_strategy, | |
| 'training_params': config.training_params, | |
| 'incremental_training': config.incremental_training, | |
| 'existing_student_model': config.existing_student_model | |
| } | |
| await distillation_trainer.save_model(trained_model, str(model_path), training_metadata) | |
| # Complete training | |
| session["status"] = "completed" | |
| session["progress"] = 1.0 | |
| session["end_time"] = asyncio.get_event_loop().time() | |
| session["model_path"] = model_path | |
| session["training_metadata"] = training_metadata | |
| await update_training_status(session_id, "completed", 1.0, "Training completed successfully!") | |
| logger.info(f"Training session {session_id} completed successfully") | |
| except Exception as e: | |
| logger.error(f"Training session {session_id} failed: {str(e)}") | |
| session = training_sessions.get(session_id, {}) | |
| session["status"] = "failed" | |
| session["error"] = str(e) | |
| await update_training_status(session_id, "failed", session.get("progress", 0), f"Training failed: {str(e)}") | |
| async def update_training_status( | |
| session_id: str, | |
| status: str, | |
| progress: float, | |
| message: str, | |
| current_step: Optional[int] = None, | |
| loss: Optional[float] = None | |
| ): | |
| """Update training status and notify connected clients""" | |
| if session_id in training_sessions: | |
| session = training_sessions[session_id] | |
| session["status"] = status | |
| session["progress"] = progress | |
| session["message"] = message | |
| if current_step is not None: | |
| session["current_step"] = current_step | |
| if loss is not None: | |
| session["loss"] = loss | |
| # Calculate ETA | |
| if session.get("start_time") and progress > 0: | |
| elapsed = asyncio.get_event_loop().time() - session["start_time"] | |
| if progress < 1.0: | |
| eta_seconds = (elapsed / progress) * (1.0 - progress) | |
| eta = f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s" | |
| session["eta"] = eta | |
| # Notify WebSocket clients with cleaned data | |
| if session_id in active_connections: | |
| try: | |
| # Clean session data for JSON serialization | |
| clean_session_data = serialize_session_for_websocket(session) | |
| await active_connections[session_id].send_json({ | |
| "type": "training_update", | |
| "data": clean_session_data | |
| }) | |
| except Exception as ws_error: | |
| logger.warning(f"WebSocket error for session {session_id}: {ws_error}") | |
| # Remove disconnected client | |
| if session_id in active_connections: | |
| del active_connections[session_id] | |
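| # ETA sketch: the estimate assumes a constant rate, eta = (elapsed / progress) * (1 - progress). | |
| # Example: 120s elapsed at progress 0.4 -> (120 / 0.4) * 0.6 = 180s, formatted as "3m 0s". | |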
| async def get_training_progress(session_id: str): | |
| """Get training progress for a session""" | |
| if session_id not in training_sessions: | |
| raise HTTPException(status_code=404, detail="Training session not found") | |
| session = training_sessions[session_id] | |
| return TrainingStatus( | |
| session_id=session_id, | |
| status=session["status"], | |
| progress=session["progress"], | |
| current_step=session["current_step"], | |
| total_steps=session["total_steps"], | |
| loss=session.get("loss"), | |
| eta=session.get("eta"), | |
| message=session.get("message", "") | |
| ) | |
| async def download_model(session_id: str): | |
| """Download trained model""" | |
| try: | |
| if session_id not in training_sessions: | |
| raise HTTPException(status_code=404, detail="Training session not found") | |
| session = training_sessions[session_id] | |
| if session["status"] != "completed": | |
| raise HTTPException(status_code=400, detail="Training not completed") | |
| model_path = session.get("model_path") | |
| if not model_path: | |
| # Try to find model in models directory | |
| models_dir = Path("models") | |
| possible_paths = [ | |
| models_dir / f"distilled_model_{session_id}", | |
| models_dir / f"distilled_model_{session_id}.safetensors", | |
| models_dir / f"model_{session_id}", | |
| models_dir / f"student_model_{session_id}" | |
| ] | |
| for path in possible_paths: | |
| if path.exists(): | |
| model_path = str(path) | |
| break | |
| if not model_path or not Path(model_path).exists(): | |
| raise HTTPException(status_code=404, detail="Model file not found. The model may not have been saved properly.") | |
| # Create a zip file with all model files | |
| import zipfile | |
| import tempfile | |
| model_dir = Path(model_path) | |
| if model_dir.is_file(): | |
| # Single file | |
| return FileResponse( | |
| model_path, | |
| media_type="application/octet-stream", | |
| filename=f"distilled_model_{session_id}.safetensors" | |
| ) | |
| else: | |
| # Directory with multiple files | |
| temp_zip = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') | |
| with zipfile.ZipFile(temp_zip.name, 'w') as zipf: | |
| for file_path in model_dir.rglob('*'): | |
| if file_path.is_file(): | |
| zipf.write(file_path, file_path.relative_to(model_dir)) | |
| return FileResponse( | |
| temp_zip.name, | |
| media_type="application/zip", | |
| filename=f"distilled_model_{session_id}.zip" | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error downloading model: {e}") | |
| raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") | |
| async def upload_to_huggingface( | |
| session_id: str, | |
| repo_name: str = Form(...), | |
| description: str = Form(""), | |
| private: bool = Form(False), | |
| hf_token: str = Form(...) | |
| ): | |
| """Upload trained model to Hugging Face Hub""" | |
| try: | |
| if session_id not in training_sessions: | |
| raise HTTPException(status_code=404, detail="Training session not found") | |
| session = training_sessions[session_id] | |
| if session["status"] != "completed": | |
| raise HTTPException(status_code=400, detail="Training not completed") | |
| model_path = session.get("model_path") | |
| if not model_path or not Path(model_path).exists(): | |
| raise HTTPException(status_code=404, detail="Model file not found") | |
| # Import huggingface_hub | |
| try: | |
| from huggingface_hub import HfApi, create_repo | |
| except ImportError: | |
| raise HTTPException(status_code=500, detail="huggingface_hub not installed") | |
| # Initialize HF API | |
| api = HfApi(token=hf_token) | |
| # Validate repository name format | |
| if '/' not in repo_name: | |
| raise HTTPException(status_code=400, detail="Repository name must be in format 'username/model-name'") | |
| username, model_name = repo_name.split('/', 1) | |
| # Create repository with better error handling | |
| try: | |
| repo_url = create_repo( | |
| repo_id=repo_name, | |
| token=hf_token, | |
| private=private, | |
| exist_ok=True | |
| ) | |
| logger.info(f"Created/accessed repository: {repo_url}") | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "403" in error_msg or "Forbidden" in error_msg: | |
| raise HTTPException( | |
| status_code=403, | |
| detail=f"Permission denied. Please check: 1) Your token has 'Write' permissions, 2) You own the namespace '{username}', 3) The repository name is correct. Error: {error_msg}" | |
| ) | |
| elif "401" in error_msg or "Unauthorized" in error_msg: | |
| raise HTTPException( | |
| status_code=401, | |
| detail=f"Invalid token. Please check your Hugging Face token. Error: {error_msg}" | |
| ) | |
| else: | |
| raise HTTPException(status_code=400, detail=f"Failed to create repository: {error_msg}") | |
| # Upload model files | |
| model_path_obj = Path(model_path) | |
| uploaded_files = [] | |
| # Determine the model directory | |
| if model_path_obj.is_file(): | |
| model_dir = model_path_obj.parent | |
| else: | |
| model_dir = model_path_obj | |
| # Upload all files in the model directory | |
| essential_files = [ | |
| 'pytorch_model.safetensors', 'config.json', 'model.py', | |
| 'training_history.json', 'README.md' | |
| ] | |
| # Upload essential files first | |
| for file_name in essential_files: | |
| file_path = model_dir / file_name | |
| if file_path.exists(): | |
| try: | |
| api.upload_file( | |
| path_or_fileobj=str(file_path), | |
| path_in_repo=file_name, | |
| repo_id=repo_name, | |
| token=hf_token | |
| ) | |
| uploaded_files.append(file_name) | |
| logger.info(f"Uploaded {file_name}") | |
| except Exception as e: | |
| logger.warning(f"Failed to upload {file_name}: {e}") | |
| # Upload any additional files | |
| for file_path in model_dir.rglob('*'): | |
| if file_path.is_file() and file_path.name not in essential_files: | |
| try: | |
| relative_path = file_path.relative_to(model_dir) | |
| api.upload_file( | |
| path_or_fileobj=str(file_path), | |
| path_in_repo=str(relative_path), | |
| repo_id=repo_name, | |
| token=hf_token | |
| ) | |
| uploaded_files.append(str(relative_path)) | |
| logger.info(f"Uploaded additional file: {relative_path}") | |
| except Exception as e: | |
| logger.warning(f"Failed to upload {relative_path}: {e}") | |
| # Create README.md | |
| config_info = session.get("config", {}) | |
| teacher_models_raw = config_info.get("teacher_models", []) | |
| # Extract model paths from teacher_models (handle both string and dict formats) | |
| teacher_models = [] | |
| for model in teacher_models_raw: | |
| if isinstance(model, str): | |
| teacher_models.append(model) | |
| elif isinstance(model, dict): | |
| teacher_models.append(model.get('path', str(model))) | |
| else: | |
| teacher_models.append(str(model)) | |
| readme_content = f"""--- | |
| license: apache-2.0 | |
| tags: | |
| - knowledge-distillation | |
| - pytorch | |
| - transformers | |
| base_model: {teacher_models[0] if teacher_models else 'unknown'} | |
| --- | |
| # {repo_name} | |
| This model was created using knowledge distillation from the following teacher model(s): | |
| {chr(10).join([f"- {model}" for model in teacher_models])} | |
| ## Model Description | |
| {description if description else 'A distilled model created using multi-modal knowledge distillation.'} | |
| ## Training Details | |
| - **Teacher Models**: {', '.join(teacher_models)} | |
| - **Distillation Strategy**: {config_info.get('distillation_strategy', 'ensemble')} | |
| - **Training Steps**: {config_info.get('training_params', {}).get('max_steps', 'unknown')} | |
| - **Learning Rate**: {config_info.get('training_params', {}).get('learning_rate', 'unknown')} | |
| ## Usage | |
| ```python | |
| from transformers import AutoModel, AutoTokenizer | |
| model = AutoModel.from_pretrained("{repo_name}") | |
| tokenizer = AutoTokenizer.from_pretrained("{teacher_models[0] if teacher_models else 'bert-base-uncased'}") | |
| ``` | |
| ## Created with | |
| This model was created using the Multi-Modal Knowledge Distillation platform. | |
| """ | |
| # Upload README | |
| api.upload_file( | |
| path_or_fileobj=readme_content.encode(), | |
| path_in_repo="README.md", | |
| repo_id=repo_name, | |
| token=hf_token | |
| ) | |
| uploaded_files.append("README.md") | |
| return { | |
| "success": True, | |
| "repo_url": f"https://huggingface.co/{repo_name}", | |
| "uploaded_files": uploaded_files, | |
| "message": f"Model successfully uploaded to {repo_name}" | |
| } | |
| except Exception as e: | |
| logger.error(f"Error uploading to Hugging Face: {e}") | |
| raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") | |
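| # Alternative sketch: huggingface_hub can push a whole directory in one call, which | |
| # avoids the per-file loop above (exact behaviour depends on the installed version): | |
| #   from huggingface_hub import upload_folder | |
| #   upload_folder(folder_path=str(model_dir), repo_id=repo_name, token=hf_token, | |
| #                 commit_message="Upload distilled model") | |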
| async def validate_repo_name(request: Dict[str, Any]): | |
| """Validate repository name and check permissions""" | |
| try: | |
| repo_name = request.get('repo_name', '').strip() | |
| hf_token = request.get('hf_token', '').strip() | |
| if not repo_name or not hf_token: | |
| return {"valid": False, "error": "Repository name and token are required"} | |
| if '/' not in repo_name: | |
| return {"valid": False, "error": "Repository name must be in format 'username/model-name'"} | |
| username, model_name = repo_name.split('/', 1) | |
| # Check if username matches token owner | |
| try: | |
| from huggingface_hub import HfApi | |
| api = HfApi(token=hf_token) | |
| # Try to get user info | |
| user_info = api.whoami() | |
| token_username = user_info.get('name', '') | |
| if username != token_username: | |
| return { | |
| "valid": False, | |
| "error": f"Username mismatch. Token belongs to '{token_username}' but trying to create repo under '{username}'. Use '{token_username}/{model_name}' instead.", | |
| "suggested_name": f"{token_username}/{model_name}" | |
| } | |
| return { | |
| "valid": True, | |
| "message": f"Repository name '{repo_name}' is valid for your account", | |
| "username": token_username | |
| } | |
| except Exception as e: | |
| return {"valid": False, "error": f"Token validation failed: {str(e)}"} | |
| except Exception as e: | |
| return {"valid": False, "error": f"Validation error: {str(e)}"} | |
| async def test_space(request: Dict[str, Any]): | |
| """Test if a Hugging Face Space exists and has trained models""" | |
| try: | |
| space_name = request.get('space_name', '').strip() | |
| hf_token = request.get('hf_token', '').strip() | |
| if not space_name: | |
| return {"success": False, "error": "Space name is required"} | |
| if '/' not in space_name: | |
| return {"success": False, "error": "Space name must be in format 'username/space-name'"} | |
| try: | |
| from huggingface_hub import HfApi | |
| api = HfApi(token=hf_token if hf_token else None) | |
| # Check if the Space exists | |
| try: | |
| space_info = api.space_info(space_name) | |
| logger.info(f"Found Space: {space_name}") | |
| except Exception as e: | |
| return {"success": False, "error": f"Space not found or not accessible: {str(e)}"} | |
| # Try to list files in the Space to see if it has models | |
| try: | |
| files = api.list_repo_files(space_name, repo_type="space") | |
| model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))] | |
| # Check for models directory | |
| models_dir_files = [f for f in files if f.startswith('models/')] | |
| return { | |
| "success": True, | |
| "space_info": { | |
| "name": space_name, | |
| "model_files": model_files, | |
| "models_directory": len(models_dir_files) > 0, | |
| "total_files": len(files) | |
| }, | |
| "models": model_files, | |
| "message": f"Space {space_name} is accessible" | |
| } | |
| except Exception as e: | |
| # Space exists but we can't list files (might be private or no access) | |
| return { | |
| "success": True, | |
| "space_info": {"name": space_name}, | |
| "models": [], | |
| "message": f"Space {space_name} exists but file listing not available (might be private)" | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": f"Error accessing Hugging Face: {str(e)}"} | |
| except Exception as e: | |
| logger.error(f"Error testing Space: {e}") | |
| return {"success": False, "error": f"Test failed: {str(e)}"} | |
| async def list_trained_students(): | |
| """List available trained student models for retraining""" | |
| try: | |
| models_dir = Path("models") | |
| trained_students = [] | |
| if models_dir.exists(): | |
| for model_dir in models_dir.iterdir(): | |
| if model_dir.is_dir(): | |
| try: | |
| # Check if it's a trained student model | |
| config_files = list(model_dir.glob("*config.json")) | |
| history_files = list(model_dir.glob("*training_history.json")) | |
| if config_files: | |
| with open(config_files[0], 'r') as f: | |
| config = json.load(f) | |
| if config.get('is_student_model', False): | |
| history = {} | |
| if history_files: | |
| with open(history_files[0], 'r') as f: | |
| history = json.load(f) | |
| model_info = { | |
| "id": model_dir.name, | |
| "name": model_dir.name, | |
| "path": str(model_dir), | |
| "type": "trained_student", | |
| "created_at": config.get('created_at', 'unknown'), | |
| "architecture": config.get('architecture', 'unknown'), | |
| "modalities": config.get('modalities', ['text']), | |
| "can_be_retrained": config.get('can_be_retrained', True), | |
| "original_teachers": history.get('retraining_info', {}).get('original_teachers', []), | |
| "training_sessions": len(history.get('training_sessions', [])), | |
| "last_training": history.get('training_sessions', [{}])[-1].get('timestamp', 'unknown') if history.get('training_sessions') else 'unknown' | |
| } | |
| trained_students.append(model_info) | |
| except Exception as e: | |
| logger.warning(f"Error reading model {model_dir}: {e}") | |
| continue | |
| return {"trained_students": trained_students} | |
| except Exception as e: | |
| logger.error(f"Error listing trained students: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
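| # A model directory is listed here when its config.json carries the flag this | |
| # endpoint reads. Minimal example config.json (values are illustrative): | |
| #   { | |
| #     "is_student_model": true, | |
| #     "architecture": "distilled-transformer", | |
| #     "modalities": ["text"], | |
| #     "can_be_retrained": true, | |
| #     "created_at": "2024-01-01T00:00:00" | |
| #   } | |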
| async def list_models(): | |
| """List available models""" | |
| models = [] | |
| # List uploaded models | |
| uploads_dir = Path("uploads") | |
| if uploads_dir.exists(): | |
| for file_path in uploads_dir.iterdir(): | |
| if file_path.is_file(): | |
| try: | |
| info = await model_loader.get_model_info(str(file_path)) | |
| models.append(ModelInfo( | |
| name=file_path.stem, | |
| size=file_path.stat().st_size, | |
| format=file_path.suffix[1:], | |
| modality=info.get("modality", "unknown"), | |
| architecture=info.get("architecture") | |
| )) | |
| except Exception as e: | |
| logger.warning(f"Error getting info for {file_path}: {e}") | |
| return models | |
| async def websocket_endpoint(websocket: WebSocket, session_id: str): | |
| """WebSocket endpoint for real-time training updates""" | |
| await websocket.accept() | |
| active_connections[session_id] = websocket | |
| try: | |
| # Send current status if session exists | |
| if session_id in training_sessions: | |
| clean_session_data = serialize_session_for_websocket(training_sessions[session_id]) | |
| await websocket.send_json({ | |
| "type": "training_update", | |
| "data": clean_session_data | |
| }) | |
| # Keep connection alive | |
| while True: | |
| await websocket.receive_text() | |
| except WebSocketDisconnect: | |
| if session_id in active_connections: | |
| del active_connections[session_id] | |
| except Exception as e: | |
| logger.error(f"WebSocket error for session {session_id}: {e}") | |
| if session_id in active_connections: | |
| del active_connections[session_id] | |
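| # Client sketch for live updates (the "ws://.../ws/{session_id}" URL is an assumption; | |
| # the actual route depends on how this handler is registered): | |
| #   import asyncio, json, websockets | |
| #   async def watch(session_id): | |
| #       async with websockets.connect(f"ws://localhost:8000/ws/{session_id}") as ws: | |
| #           while True: | |
| #               update = json.loads(await ws.recv()) | |
| #               print(update["data"]["status"], update["data"]["progress"]) | |
| #   asyncio.run(watch("demo-session-001")) | |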
| # ==================== NEW ADVANCED ENDPOINTS ==================== | |
| # Session Management Endpoints | |
| async def list_training_sessions(): | |
| """List all training sessions with their status""" | |
| try: | |
| sessions_info = [] | |
| for session_id, session_data in training_sessions.items(): | |
| session_info = { | |
| "session_id": session_id, | |
| "status": session_data.get("status", "unknown"), | |
| "progress": session_data.get("progress", 0.0), | |
| "current_step": session_data.get("current_step", 0), | |
| "total_steps": session_data.get("total_steps", 0), | |
| "start_time": session_data.get("start_time"), | |
| "end_time": session_data.get("end_time"), | |
| "message": session_data.get("message", ""), | |
| "loss": session_data.get("loss"), | |
| "model_path": str(session_data.get("model_path", "")) if session_data.get("model_path") else None | |
| } | |
| sessions_info.append(session_info) | |
| return { | |
| "success": True, | |
| "sessions": sessions_info, | |
| "total_sessions": len(sessions_info), | |
| "active_sessions": len([s for s in sessions_info if s["status"] in ["running", "initializing"]]) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error listing sessions: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def delete_training_session(session_id: str): | |
| """Delete a training session""" | |
| try: | |
| if session_id not in training_sessions: | |
| raise HTTPException(status_code=404, detail="Training session not found") | |
| session = training_sessions[session_id] | |
| status = session.get("status", "unknown") | |
| # Don't allow deletion of running sessions | |
| if status in ["running", "initializing"]: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Cannot delete active session with status: {status}" | |
| ) | |
| # Clean up session data | |
| del training_sessions[session_id] | |
| # Clean up WebSocket connection if exists | |
| if session_id in active_connections: | |
| try: | |
| await active_connections[session_id].close() | |
| except Exception: | |
| pass | |
| del active_connections[session_id] | |
| logger.info(f"Deleted training session: {session_id}") | |
| return { | |
| "success": True, | |
| "message": f"Session {session_id} deleted successfully" | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error deleting session {session_id}: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def cancel_training_session(session_id: str): | |
| """Cancel a running training session""" | |
| try: | |
| if session_id not in training_sessions: | |
| raise HTTPException(status_code=404, detail="Training session not found") | |
| session = training_sessions[session_id] | |
| status = session.get("status", "unknown") | |
| if status not in ["running", "initializing"]: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Cannot cancel session with status: {status}" | |
| ) | |
| # Update session status | |
| session["status"] = "cancelled" | |
| session["message"] = "Training cancelled by user" | |
| session["end_time"] = asyncio.get_event_loop().time() | |
| # Notify WebSocket clients | |
| await update_training_status( | |
| session_id, "cancelled", session.get("progress", 0), | |
| "Training cancelled by user" | |
| ) | |
| logger.info(f"Cancelled training session: {session_id}") | |
| return { | |
| "success": True, | |
| "message": f"Session {session_id} cancelled successfully" | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error cancelling session {session_id}: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def cleanup_completed_sessions(): | |
| """Clean up all completed and failed sessions""" | |
| try: | |
| cleaned_sessions = [] | |
| sessions_to_remove = [] | |
| for session_id, session_data in training_sessions.items(): | |
| status = session_data.get("status", "unknown") | |
| if status in ["completed", "failed", "cancelled"]: | |
| sessions_to_remove.append(session_id) | |
| cleaned_sessions.append({ | |
| "session_id": session_id, | |
| "status": status | |
| }) | |
| # Remove sessions | |
| for session_id in sessions_to_remove: | |
| del training_sessions[session_id] | |
| # Clean up WebSocket connections | |
| if session_id in active_connections: | |
| try: | |
| await active_connections[session_id].close() | |
| except Exception: | |
| pass | |
| del active_connections[session_id] | |
| logger.info(f"Cleaned up {len(cleaned_sessions)} completed sessions") | |
| return { | |
| "success": True, | |
| "message": f"Cleaned up {len(cleaned_sessions)} sessions", | |
| "cleaned_sessions": cleaned_sessions | |
| } | |
| except Exception as e: | |
| logger.error(f"Error cleaning up sessions: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| # Medical Dataset Management Endpoints | |
| async def get_medical_datasets(): | |
| """Get all supported medical datasets""" | |
| try: | |
| from src.medical.medical_config import SUPPORTED_MEDICAL_DATASETS, MEDICAL_SPECIALTIES | |
| return { | |
| "success": True, | |
| "datasets": SUPPORTED_MEDICAL_DATASETS, | |
| "specialties": MEDICAL_SPECIALTIES, | |
| "total_datasets": len(SUPPORTED_MEDICAL_DATASETS) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting medical datasets: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def select_medical_datasets( | |
| user_session: str = Form(...), | |
| selected_datasets: str = Form(...), # JSON string of dataset names | |
| preferences: str = Form(default="{}") # JSON string of user preferences | |
| ): | |
| """Save user's medical dataset selections""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| from src.medical.medical_config import validate_medical_dataset_selection | |
| import json | |
| # Parse input data | |
| dataset_list = json.loads(selected_datasets) | |
| user_preferences = json.loads(preferences) | |
| # Validate selections | |
| validation_result = validate_medical_dataset_selection(dataset_list) | |
| if not validation_result['valid']: | |
| return { | |
| "success": False, | |
| "errors": validation_result['errors'], | |
| "warnings": validation_result['warnings'] | |
| } | |
| # Save selections to database | |
| db = MedicalSelectionsDB() | |
| # Clear previous selections | |
| for dataset_name in dataset_list: | |
| db.remove_dataset_selection(user_session, dataset_name) | |
| # Save new selections | |
| success_count = 0 | |
| for dataset_name in dataset_list: | |
| if db.save_dataset_selection(user_session, dataset_name): | |
| success_count += 1 | |
| # Save user preferences | |
| if user_preferences: | |
| db.save_user_preferences(user_session, user_preferences) | |
| return { | |
| "success": True, | |
| "message": f"تم حفظ {success_count} من قواعد البيانات بنجاح", | |
| "selected_count": success_count, | |
| "validation_result": validation_result | |
| } | |
| except json.JSONDecodeError as e: | |
| raise HTTPException(status_code=400, detail=f"Invalid JSON format: {e}") | |
| except Exception as e: | |
| logger.error(f"Error selecting medical datasets: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_user_medical_selections(user_session: str): | |
| """Get user's medical dataset selections""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| db = MedicalSelectionsDB() | |
| selections = db.get_user_dataset_selections(user_session) | |
| preferences = db.get_user_preferences(user_session) | |
| return { | |
| "success": True, | |
| "selections": selections, | |
| "preferences": preferences, | |
| "total_selected": len(selections) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting user selections: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def remove_medical_dataset_selection(user_session: str, dataset_name: str): | |
| """Remove a specific dataset selection""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| db = MedicalSelectionsDB() | |
| success = db.remove_dataset_selection(user_session, dataset_name) | |
| if success: | |
| return { | |
| "success": True, | |
| "message": f"تم إزالة قاعدة البيانات {dataset_name} بنجاح" | |
| } | |
| else: | |
| raise HTTPException(status_code=400, detail="فشل في إزالة قاعدة البيانات") | |
| except Exception as e: | |
| logger.error(f"Error removing dataset selection: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_dataset_recommendations(user_session: str): | |
| """Get personalized dataset recommendations""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| from src.medical.medical_config import get_dataset_by_specialty, SUPPORTED_MEDICAL_DATASETS | |
| db = MedicalSelectionsDB() | |
| preferences = db.get_user_preferences(user_session) | |
| recommendations = [] | |
| # Get recommendations based on specialties | |
| for specialty in preferences.get('specialties', []): | |
| recommended_datasets = get_dataset_by_specialty(specialty) | |
| for dataset_name in recommended_datasets: | |
| if dataset_name in SUPPORTED_MEDICAL_DATASETS: | |
| dataset_info = SUPPORTED_MEDICAL_DATASETS[dataset_name].copy() | |
| dataset_info['recommended_for_specialty'] = specialty | |
| dataset_info['dataset_key'] = dataset_name | |
| recommendations.append(dataset_info) | |
| # Remove duplicates | |
| seen_datasets = set() | |
| unique_recommendations = [] | |
| for rec in recommendations: | |
| if rec['dataset_key'] not in seen_datasets: | |
| seen_datasets.add(rec['dataset_key']) | |
| unique_recommendations.append(rec) | |
| return { | |
| "success": True, | |
| "recommendations": unique_recommendations, | |
| "user_preferences": preferences, | |
| "total_recommendations": len(unique_recommendations) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting recommendations: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| # Model Management Endpoints | |
| async def get_google_models(): | |
| """Get available Google models for teacher selection""" | |
| try: | |
| # Mock Google models data - in production, this would fetch from Google's API | |
| google_models = { | |
| 'flan_t5_base': { | |
| 'name': 'FLAN-T5 Base', | |
| 'description': 'نموذج نصوص متوسط الحجم مدرب على مهام متنوعة', | |
| 'type': 'text', | |
| 'modalities': ['text'], | |
| 'parameters': '250M', | |
| 'size_category': 'medium', | |
| 'use_cases': ['الإجابة على الأسئلة', 'التلخيص', 'الترجمة'], | |
| 'performance_score': 8.5, | |
| 'repo_id': 'google/flan-t5-base', | |
| 'license': 'Apache 2.0' | |
| }, | |
| 'flan_t5_large': { | |
| 'name': 'FLAN-T5 Large', | |
| 'description': 'نموذج نصوص كبير عالي الأداء', | |
| 'type': 'text', | |
| 'modalities': ['text'], | |
| 'parameters': '780M', | |
| 'size_category': 'large', | |
| 'use_cases': ['المهام المعقدة', 'التحليل المتقدم', 'الكتابة الإبداعية'], | |
| 'performance_score': 9.2, | |
| 'repo_id': 'google/flan-t5-large', | |
| 'license': 'Apache 2.0' | |
| }, | |
| 'vit_base': { | |
| 'name': 'Vision Transformer Base', | |
| 'description': 'نموذج رؤية حاسوبية متقدم', | |
| 'type': 'vision', | |
| 'modalities': ['vision'], | |
| 'parameters': '86M', | |
| 'size_category': 'medium', | |
| 'use_cases': ['تصنيف الصور', 'التعرف على الأشياء', 'تحليل المحتوى البصري'], | |
| 'performance_score': 8.8, | |
| 'repo_id': 'google/vit-base-patch16-224', | |
| 'license': 'Apache 2.0' | |
| }, | |
| 'clip_vit': { | |
| 'name': 'CLIP Vision-Text', | |
| 'description': 'نموذج متعدد الوسائط يربط النصوص والصور', | |
| 'type': 'multimodal', | |
| 'modalities': ['text', 'vision'], | |
| 'parameters': '400M', | |
| 'size_category': 'large', | |
| 'use_cases': ['البحث بالصور', 'وصف الصور', 'التصنيف متعدد الوسائط'], | |
| 'performance_score': 9.0, | |
| 'repo_id': 'openai/clip-vit-base-patch32', | |
| 'license': 'MIT' | |
| }, | |
| 'bert_base': { | |
| 'name': 'BERT Base', | |
| 'description': 'نموذج فهم اللغة الطبيعية الكلاسيكي', | |
| 'type': 'text', | |
| 'modalities': ['text'], | |
| 'parameters': '110M', | |
| 'size_category': 'small', | |
| 'use_cases': ['تحليل المشاعر', 'تصنيف النصوص', 'استخراج المعلومات'], | |
| 'performance_score': 8.0, | |
| 'repo_id': 'bert-base-uncased', | |
| 'license': 'Apache 2.0' | |
| } | |
| } | |
| return { | |
| "success": True, | |
| "models": google_models, | |
| "total_models": len(google_models), | |
| "categories": { | |
| "text": len([m for m in google_models.values() if m['type'] == 'text']), | |
| "vision": len([m for m in google_models.values() if m['type'] == 'vision']), | |
| "multimodal": len([m for m in google_models.values() if m['type'] == 'multimodal']) | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting Google models: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def save_model_configuration(configuration: dict): | |
| """Save user's model configuration""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| import json | |
| user_session = configuration.get('user_session') | |
| teachers = configuration.get('teachers', []) | |
| student = configuration.get('student') | |
| if not user_session: | |
| raise HTTPException(status_code=400, detail="User session is required") | |
| if not teachers: | |
| raise HTTPException(status_code=400, detail="At least one teacher model is required") | |
| # Save to database | |
| db = MedicalSelectionsDB() | |
| # Create configuration record | |
| config_data = { | |
| 'teachers': teachers, | |
| 'student': student, | |
| 'timestamp': configuration.get('timestamp'), | |
| 'total_teachers': len(teachers), | |
| 'student_type': student.get('type') if student else 'new' | |
| } | |
| # Save as user preferences | |
| success = db.save_user_preferences(user_session, { | |
| 'model_configuration': config_data, | |
| 'last_updated': configuration.get('timestamp') | |
| }) | |
| if success: | |
| return { | |
| "success": True, | |
| "message": f"تم حفظ تكوين {len(teachers)} نماذج معلمة بنجاح", | |
| "configuration": config_data | |
| } | |
| else: | |
| raise HTTPException(status_code=500, detail="Failed to save configuration") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error saving model configuration: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_model_configuration(user_session: str): | |
| """Get user's saved model configuration""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| db = MedicalSelectionsDB() | |
| preferences = db.get_user_preferences(user_session) | |
| model_config = preferences.get('model_configuration', {}) | |
| return { | |
| "success": True, | |
| "teachers": model_config.get('teachers', []), | |
| "student": model_config.get('student'), | |
| "last_updated": preferences.get('last_updated'), | |
| "total_teachers": len(model_config.get('teachers', [])) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting model configuration: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def clear_model_configuration(user_session: str): | |
| """Clear user's model configuration""" | |
| try: | |
| from database.medical_selections import MedicalSelectionsDB | |
| db = MedicalSelectionsDB() | |
| # Clear model configuration | |
| success = db.save_user_preferences(user_session, { | |
| 'model_configuration': {}, | |
| 'last_updated': None | |
| }) | |
| if success: | |
| return { | |
| "success": True, | |
| "message": "تم مسح تكوين النماذج بنجاح" | |
| } | |
| else: | |
| raise HTTPException(status_code=500, detail="Failed to clear configuration") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error clearing model configuration: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
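| # Hypothetical wiring sketch for the configuration handlers above. The actual route | |
| # registrations are not shown in this excerpt and the paths here are assumptions; | |
| # FastAPI's add_api_route is used only to illustrate how these coroutines map to HTTP verbs: | |
| # app.add_api_route("/api/models/configuration", save_model_configuration, methods=["POST"]) | |
| # app.add_api_route("/api/models/configuration/{user_session}", get_model_configuration, methods=["GET"]) | |
| # app.add_api_route("/api/models/configuration/{user_session}", clear_model_configuration, methods=["DELETE"]) | |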
| # Token Management Endpoints | |
| async def token_management_page(request: Request): | |
| """Token management page""" | |
| return templates.TemplateResponse("token-management.html", {"request": request}) | |
| async def save_token( | |
| name: str = Form(...), | |
| token: str = Form(...), | |
| token_type: str = Form("read"), | |
| description: str = Form(""), | |
| is_default: bool = Form(False) | |
| ): | |
| """Save HF token""" | |
| try: | |
| success = token_manager.save_token(name, token, token_type, description, is_default) | |
| if success: | |
| return {"success": True, "message": f"Token '{name}' saved successfully"} | |
| else: | |
| raise HTTPException(status_code=400, detail="Failed to save token") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error saving token: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def list_tokens(): | |
| """List all saved tokens""" | |
| try: | |
| tokens = token_manager.list_tokens() | |
| return {"tokens": tokens} | |
| except Exception as e: | |
| logger.error(f"Error listing tokens: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def delete_token(token_name: str): | |
| """Delete a token""" | |
| try: | |
| success = token_manager.delete_token(token_name) | |
| if success: | |
| return {"success": True, "message": f"Token '{token_name}' deleted"} | |
| else: | |
| raise HTTPException(status_code=404, detail="Token not found") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error deleting token: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def set_default_token(token_name: str): | |
| """Set token as default""" | |
| try: | |
| success = token_manager.set_default_token(token_name) | |
| if success: | |
| return {"success": True, "message": f"Token '{token_name}' set as default"} | |
| else: | |
| raise HTTPException(status_code=404, detail="Token not found") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error setting default token: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def validate_token(token: str = Form(...)): | |
| """Validate HF token""" | |
| try: | |
| result = token_manager.validate_token(token) | |
| return result | |
| except Exception as e: | |
| logger.error(f"Error validating token: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_token_for_task(task_type: str): | |
| """Get appropriate token for specific task""" | |
| try: | |
| # Get token for task | |
| token = token_manager.get_token_for_task(task_type) | |
| if not token: | |
| raise HTTPException(status_code=404, detail=f"No suitable token found for task: {task_type}") | |
| # Get token information | |
| tokens = token_manager.list_tokens() | |
| token_info = None | |
| # Find which token was selected | |
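| # (get_token_for_task returns only the raw token string, so resolve each saved token | |
| # by name and compare values to recover the metadata of the one that was chosen) | |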
| for t in tokens: | |
| test_token = token_manager.get_token(t['name']) | |
| if test_token == token: | |
| token_info = t | |
| break | |
| if not token_info: | |
| # Token from environment variable | |
| token_info = { | |
| 'name': f'{task_type}_token', | |
| 'type': task_type, | |
| 'description': f'Token from environment variables for task: {task_type}', | |
| 'last_used': None, | |
| 'usage_count': 0 | |
| } | |
| # Get token type information | |
| type_info = token_manager.token_types.get(token_info['type'], {}) | |
| return { | |
| "success": True, | |
| "task_type": task_type, | |
| "token_info": { | |
| "token_name": token_info['name'], | |
| "type": token_info['type'], | |
| "type_name": type_info.get('name', token_info['type']), | |
| "description": token_info['description'], | |
| "security_level": type_info.get('security_level', 'medium'), | |
| "recommended_for": type_info.get('recommended_for', 'general'), | |
| "last_used": token_info.get('last_used'), | |
| "usage_count": token_info.get('usage_count', 0) | |
| } | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error getting token for task {task_type}: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
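| # Hypothetical client-side sketch for the token endpoints above (base URL and route path | |
| # are assumptions; the handlers accept form fields, hence `data=` rather than `json=`): | |
| # import requests | |
| # resp = requests.post( | |
| #     "http://localhost:7860/api/tokens/save", | |
| #     data={"name": "hf-read", "token": "hf_xxx", "token_type": "read", "is_default": "true"}, | |
| # ) | |
| # print(resp.json())  # {"success": True, "message": "Token 'hf-read' saved successfully"} | |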
| # Medical Dataset Endpoints | |
| async def medical_datasets_page(request: Request): | |
| """Medical datasets management page""" | |
| return templates.TemplateResponse("medical-datasets.html", {"request": request}) | |
| # Google Models Endpoints | |
| async def google_models_page(request: Request): | |
| """Google models selection page""" | |
| return templates.TemplateResponse("google-models.html", {"request": request}) | |
| async def list_medical_datasets(): | |
| """List supported medical datasets""" | |
| try: | |
| datasets = medical_dataset_manager.list_supported_datasets() | |
| return {"datasets": datasets} | |
| except Exception as e: | |
| logger.error(f"Error listing medical datasets: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def load_medical_dataset( | |
| dataset_name: str = Form(...), | |
| streaming: bool = Form(True), | |
| split: str = Form("train") | |
| ): | |
| """Load medical dataset""" | |
| try: | |
| # Get appropriate token for medical datasets (fine-grained preferred) | |
| hf_token = token_manager.get_token_for_task('medical') | |
| if not hf_token: | |
| logger.warning("No suitable token found for medical datasets, trying default") | |
| hf_token = token_manager.get_token() | |
| dataset_info = await medical_dataset_manager.load_dataset( | |
| dataset_name=dataset_name, | |
| streaming=streaming, | |
| split=split, | |
| token=hf_token | |
| ) | |
| return { | |
| "success": True, | |
| "dataset_info": { | |
| "name": dataset_info['config']['name'], | |
| "size_gb": dataset_info['config']['size_gb'], | |
| "num_samples": dataset_info['config']['num_samples'], | |
| "streaming": dataset_info['streaming'] | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"Error loading medical dataset: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
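| # Illustrative call for the dataset loader above (route path and dataset name are | |
| # assumptions; streaming=True keeps large corpora out of memory by iterating lazily): | |
| # import requests | |
| # resp = requests.post( | |
| #     "http://localhost:7860/api/medical/datasets/load", | |
| #     data={"dataset_name": "example_dataset", "streaming": "true", "split": "train"}, | |
| # ) | |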
| # Memory and Performance Endpoints | |
| async def get_memory_info(): | |
| """Get current memory information""" | |
| try: | |
| memory_info = memory_manager.get_memory_info() | |
| return memory_info | |
| except Exception as e: | |
| logger.error(f"Error getting memory info: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_performance_info(): | |
| """Get system performance information""" | |
| try: | |
| memory_info = memory_manager.get_memory_info() | |
| recommendations = memory_manager.get_memory_recommendations() | |
| return { | |
| "memory": memory_info, | |
| "recommendations": recommendations, | |
| "cpu_cores": cpu_optimizer.cpu_count, | |
| "optimizations_applied": cpu_optimizer.optimizations_applied | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting performance info: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def force_memory_cleanup(): | |
| """Force memory cleanup""" | |
| try: | |
| memory_manager.force_cleanup() | |
| return {"success": True, "message": "Memory cleanup completed"} | |
| except Exception as e: | |
| logger.error(f"Error during memory cleanup: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
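| # Shape of the payload returned by get_performance_info above (values are illustrative; | |
| # the contents of "memory" and "recommendations" depend on AdvancedMemoryManager): | |
| # { | |
| #     "memory": {...}, | |
| #     "recommendations": [...], | |
| #     "cpu_cores": 2, | |
| #     "optimizations_applied": [...] | |
| # } | |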
| # Google Models Support | |
| async def list_google_models(): | |
| """List available Google models""" | |
| try: | |
| google_models = [ | |
| { | |
| "name": "google/medsiglip-448", | |
| "description": "Medical SigLIP model for medical image-text understanding", | |
| "type": "vision-language", | |
| "size_gb": 1.1, | |
| "modality": "multimodal", | |
| "medical_specialized": True | |
| }, | |
| { | |
| "name": "google/gemma-3n-E4B-it", | |
| "description": "Gemma 3 model for instruction following", | |
| "type": "language", | |
| "size_gb": 8.5, | |
| "modality": "text", | |
| "medical_specialized": False | |
| } | |
| ] | |
| return {"models": google_models} | |
| except Exception as e: | |
| logger.error(f"Error listing Google models: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| if __name__ == "__main__": | |
| uvicorn.run( | |
| "app:app", | |
| host="0.0.0.0", | |
| port=int(os.getenv("PORT", 7860)), | |
| reload=False, | |
| log_level="info" | |
| ) | |
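| # Equivalent to the block above, the app can also be served from the shell, e.g.: | |
| #   uvicorn app:app --host 0.0.0.0 --port 7860 | |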