train-modle / app.py
fokan's picture
Force Space rebuild v2.1.0 with incremental training
aec0216
raw
history blame
77.5 kB
"""
Multi-Modal Knowledge Distillation Web Application
A FastAPI-based web application for creating new AI models through knowledge distillation
from multiple pre-trained models across different modalities.
"""
import os
import asyncio
import logging
import uuid
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
import json
import shutil
from datetime import datetime
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from src.model_loader import ModelLoader
from src.distillation import KnowledgeDistillationTrainer
from src.utils import setup_logging, validate_file, cleanup_temp_files, get_system_info
# Import new core components
from src.core.memory_manager import AdvancedMemoryManager
from src.core.chunk_loader import AdvancedChunkLoader
from src.core.cpu_optimizer import CPUOptimizer
from src.core.token_manager import TokenManager
# Import medical components
from src.medical.medical_datasets import MedicalDatasetManager
from src.medical.dicom_handler import DicomHandler
from src.medical.medical_preprocessing import MedicalPreprocessor
# Import database components
from database.database import DatabaseManager
# Logging: prefer the project's configuration, fall back to a plain INFO setup.
try:
    setup_logging()
except Exception as setup_error:
    # The advanced setup failed; install a basic handler so warnings remain visible.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger(__name__).warning(f"Failed to setup advanced logging: {setup_error}")
# Looked up only after setup_logging() has run, exactly as before.
logger = logging.getLogger(__name__)
# Initialize FastAPI app
# (title/description/version are what FastAPI shows in the generated docs at /docs and /redoc)
app = FastAPI(
    title="Multi-Modal Knowledge Distillation",
    description="Create new AI models through knowledge distillation from multiple pre-trained models",
    version="2.1.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Add CORS middleware
# NOTE(review): every origin, method and header is allowed together with credentials;
# confirm this permissive policy is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files and templates
# Both directories are resolved relative to the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Global variables for tracking training sessions
# Plain in-process dicts: lost on restart and not shared between worker processes.
# session_id -> session dict (status, progress, config, logs, ...).
training_sessions: Dict[str, Dict[str, Any]] = {}
# session_id -> the WebSocket currently subscribed to that session's updates.
active_connections: Dict[str, WebSocket] = {}
def serialize_session_for_websocket(session_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Clean session data for WebSocket JSON serialization.

    Returns a new dict (the input is not modified) whose values are all
    JSON-compatible:

    * ``str``/``int``/``float``/``bool``/``None`` are kept as-is;
    * ``Path`` objects become strings;
    * nested dicts are cleaned recursively with this same function;
    * lists, tuples and sets become lists whose items are cleaned
      recursively, at any nesting depth;
    * anything else (model objects, datetimes, numpy scalars, ...) becomes
      ``str(value)``.

    If cleaning one key raises, that key falls back to ``str(value)`` (or
    ``None``) and a warning is logged, so one bad value never drops the whole
    update.
    """
    cleaned_data = {}
    for key, value in session_data.items():
        try:
            cleaned_data[key] = _to_json_compatible(value)
        except Exception as e:
            # If anything fails, convert to string
            logger.warning(f"Error serializing session key '{key}': {e}")
            cleaned_data[key] = str(value) if value is not None else None
    return cleaned_data


def _to_json_compatible(value: Any) -> Any:
    """Return a JSON-serializable equivalent of ``value`` (helper of serialize_session_for_websocket)."""
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        # Delegate so nested keys get the same per-key error fallback.
        return serialize_session_for_websocket(value)
    if isinstance(value, (list, tuple, set, frozenset)):
        return [_to_json_compatible(item) for item in value]
    # Unknown objects cannot be JSON-encoded as-is; fall back to their text form.
    return str(value)
# Pydantic models for API
class TrainingConfig(BaseModel):
    # Request body accepted by POST /start-training and handed to run_training.
    session_id: str = Field(..., description="Unique session identifier")
    teacher_models: List[Union[str, Dict[str, Any]]] = Field(..., description="List of teacher model paths/URLs or model configs")
    student_config: Dict[str, Any] = Field(default_factory=dict, description="Student model configuration")
    training_params: Dict[str, Any] = Field(default_factory=dict, description="Training parameters")
    distillation_strategy: str = Field(default="ensemble", description="Distillation strategy")
    hf_token: Optional[str] = Field(default=None, description="Hugging Face token")
    trust_remote_code: bool = Field(default=False, description="Trust remote code execution")
    existing_student_model: Optional[str] = Field(default=None, description="Path to existing trained student model for retraining")
    incremental_training: bool = Field(default=False, description="Whether this is incremental training")
    # Where `existing_student_model` lives. run_training reads it with
    # getattr(config, 'student_source', 'local'); it must be a declared field,
    # because pydantic ignores undeclared input keys and the value would
    # otherwise always be 'local'. The default keeps existing clients unchanged.
    student_source: str = Field(default="local", description="Source of the existing student model: 'local', 'huggingface' or 'space'")
class TrainingStatus(BaseModel):
    # Response model for GET /progress/{session_id}, built from the in-memory session dict.
    session_id: str
    status: str  # e.g. "initializing", "loading_models", "training", "saving", "completed", "failed"
    progress: float  # fraction of the whole job, 0.0 .. 1.0
    current_step: int
    total_steps: int  # training_params["max_steps"] (default 1000) captured when the session starts
    loss: Optional[float] = None  # latest loss reported by the trainer, if any
    eta: Optional[str] = None  # "<minutes>m <seconds>s" estimate, when one has been computed
    message: str = ""
class ModelInfo(BaseModel):
    # Descriptive metadata for a model.
    # NOTE(review): not referenced by any endpoint in the visible part of this file — confirm it is still used.
    name: str
    size: int  # presumably a size in bytes — TODO confirm
    format: str
    modality: str
    architecture: Optional[str] = None
# Initialize components
# Module-level singletons shared by every request handler; all are constructed at import time.
model_loader = ModelLoader()
distillation_trainer = KnowledgeDistillationTrainer()
# Initialize new advanced components
memory_manager = AdvancedMemoryManager(max_memory_gb=14.0) # 14GB for 16GB systems
# The chunk loader, CPU optimizer and medical dataset manager all share this one memory manager.
chunk_loader = AdvancedChunkLoader(memory_manager)
cpu_optimizer = CPUOptimizer(memory_manager)
token_manager = TokenManager()
database_manager = DatabaseManager()
# Initialize medical components
medical_dataset_manager = MedicalDatasetManager(memory_manager)
dicom_handler = DicomHandler(memory_limit_mb=1000.0)
medical_preprocessor = MedicalPreprocessor()
@app.on_event("startup")
async def startup_event():
    """Prepare working directories and log host details when the app starts.

    Every failure is downgraded to a warning; startup itself never aborts here.
    """
    logger.info("Starting Multi-Modal Knowledge Distillation application")

    # A failure on one directory does not stop the remaining ones.
    for dir_name in ("uploads", "models", "temp", "logs"):
        target = Path(dir_name)
        try:
            target.mkdir(exist_ok=True)
            logger.info(f"Created/verified directory: {dir_name}")
        except PermissionError:
            # NOTE(review): the message mentions a temp directory, but no fallback is configured here.
            logger.warning(f"Cannot create directory {dir_name}, using temp directory")
        except Exception as mkdir_error:
            logger.warning(f"Error creating directory {dir_name}: {mkdir_error}")

    # Host details are informational only.
    try:
        logger.info(f"System info: {get_system_info()}")
    except Exception as info_error:
        logger.warning(f"Could not get system info: {info_error}")
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up temporary files on application shutdown.

    Cleanup is best-effort: an error from ``cleanup_temp_files`` is logged as a
    warning instead of being raised out of the shutdown hook, matching the
    warn-and-continue handling used at startup.
    """
    logger.info("Shutting down application")
    try:
        cleanup_temp_files()
    except Exception as e:
        logger.warning(f"Temporary file cleanup failed during shutdown: {e}")
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the main web interface (``templates/index.html``).

    The incoming ``Request`` (injected by FastAPI) is placed in the template
    context. Starlette's ``TemplateResponse`` expects it there, and template
    helpers such as ``url_for`` resolve URLs through it. A placeholder ``{}``
    makes any such helper fail at render time.
    """
    return templates.TemplateResponse("index.html", {"request": request})
@app.get("/health")
async def health_check():
    """Health check endpoint for Docker and monitoring.

    Returns ``status: "healthy"`` with memory, token-availability and feature
    information. If any probe raises, it returns ``status: "unhealthy"`` with
    the error text; errors are reported in the body rather than raised.
    ``version`` is read from ``app.version`` so it always matches the version
    declared on the FastAPI app.
    """
    try:
        memory_info = memory_manager.get_memory_info()
        # Only whether a default token exists is reported, never the token itself.
        default_token = token_manager.get_token()
        return {
            "status": "healthy",
            "version": app.version,
            "timestamp": datetime.now().isoformat(),
            "memory": {
                "usage_percent": memory_info.get("process_memory_percent", 0),
                "available_gb": memory_info.get("system_memory_available_gb", 0),
                "status": memory_manager.check_memory_status()
            },
            "tokens": {
                "default_available": bool(default_token),
                "total_tokens": len(token_manager.list_tokens())
            },
            "features": {
                "memory_management": True,
                "chunk_loading": True,
                "cpu_optimization": True,
                "medical_datasets": True,
                "token_management": True
            },
            "system_info": get_system_info()
        }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
            "version": app.version
        }
@app.get("/test-token")
async def test_token():
    """Test if HF token is working.

    Uses the first non-empty of ``HF_TOKEN``, ``HUGGINGFACE_TOKEN`` and
    ``HUGGINGFACE_HUB_TOKEN``. If one is found, the config of the gated
    ``google/gemma-2b`` repo is fetched with it. That fetch is blocking network
    I/O, so it runs in the default executor instead of stalling the event loop
    (and with it every other request and WebSocket).
    """
    hf_token = (
        os.getenv('HF_TOKEN') or
        os.getenv('HUGGINGFACE_TOKEN') or
        os.getenv('HUGGINGFACE_HUB_TOKEN')
    )
    if not hf_token:
        return {
            "token_available": False,
            "message": "No HF token found in environment variables"
        }
    try:
        # Test token by trying to access a gated model's config
        from transformers import AutoConfig
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: AutoConfig.from_pretrained("google/gemma-2b", token=hf_token),
        )
        return {
            "token_available": True,
            "token_valid": True,
            "message": "Token is working correctly"
        }
    except Exception as e:
        return {
            "token_available": True,
            "token_valid": False,
            "message": f"Token validation failed: {str(e)}"
        }
@app.post("/test-model")
async def test_model_loading(request: Dict[str, Any]):
    """Check that a model can be resolved and report its info.

    Body keys:
        model_path: required.
        trust_remote_code: optional.
        access_type: defaults to 'read'.
        token: missing or 'auto' selects one automatically.

    Always answers with a dict: ``success`` plus either ``model_info`` or an
    ``error`` together with Arabic-language ``suggestions`` for the UI.
    """
    try:
        model_path = request.get('model_path')
        # NOTE(review): read but never forwarded to get_model_info below.
        trust_remote_code = request.get('trust_remote_code', False)
        if not model_path:
            return {"success": False, "error": "model_path is required"}

        access_type = request.get('access_type', 'read')
        hf_token = request.get('token')
        needs_lookup = (not hf_token) or hf_token == 'auto'
        if needs_lookup:
            hf_token = token_manager.get_token_for_task(access_type)
            if not hf_token:
                logger.warning(f"No suitable token found for {access_type} access")
                # Last resort: the process environment.
                hf_token = (
                    os.getenv('HF_TOKEN')
                    or os.getenv('HUGGINGFACE_TOKEN')
                    or os.getenv('HUGGINGFACE_HUB_TOKEN')
                )
            else:
                logger.info(f"Using {access_type} token for model testing")

        # NOTE(review): hf_token is resolved above but not passed to get_model_info,
        # so gated/private repos are probed without it — confirm the loader's signature.
        info = await model_loader.get_model_info(model_path)
        return {
            "success": True,
            "model_info": info,
            "message": f"Model {model_path} can be loaded"
        }
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        # First matching rule decides the hints shown to the user.
        if 'trust_remote_code' in lowered:
            suggestions = ["فعّل 'Trust Remote Code' للنماذج التي تتطلب كود مخصص"]
        elif 'gated' in lowered:
            suggestions = ["النموذج يتطلب إذن وصول خاص - استخدم رمز مخصص"]
        elif 'siglip' in lowered:
            suggestions = ["جرب تفعيل 'Trust Remote Code' لنماذج SigLIP"]
        elif '401' in error_msg or 'authentication' in lowered:
            suggestions = [
                "تحقق من رمز Hugging Face الخاص بك",
                "تأكد من أن الرمز له صلاحية الوصول لهذا النموذج",
            ]
        elif '404' in error_msg or 'not found' in lowered:
            suggestions = [
                "تحقق من اسم مستودع النموذج",
                "تأكد من وجود النموذج على Hugging Face",
            ]
        else:
            suggestions = []
        return {
            "success": False,
            "error": error_msg,
            "suggestions": suggestions
        }
@app.post("/upload", response_model=Dict[str, Any])
async def upload_model(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    model_names: List[str] = Form(...)
):
    """Upload model files.

    Each file is paired with the name at the same position in ``model_names``;
    both lists must have the same length. Each file is validated, streamed to
    ``uploads/<uuid><original suffix>`` in 1 MiB chunks (a model is never held
    fully in memory) and inspected with ``model_loader.get_model_info``. A
    cleanup of files older than 24 hours is scheduled after the response.

    Raises:
        HTTPException 400: file/name count mismatch, or a file failing validation.
        HTTPException 500: any other failure.
    """
    chunk_size = 1024 * 1024
    try:
        if len(files) != len(model_names):
            raise HTTPException(
                status_code=400,
                detail=f"Got {len(files)} file(s) but {len(model_names)} model name(s); they must match one-to-one"
            )
        uploaded_models = []
        for file, name in zip(files, model_names):
            # Validate file
            validation_result = validate_file(file)
            if not validation_result["valid"]:
                raise HTTPException(status_code=400, detail=validation_result["error"])
            # Random stored name: only the original suffix is kept, so the
            # client-supplied filename cannot choose the destination directory.
            file_id = str(uuid.uuid4())
            file_extension = Path(file.filename or "").suffix
            file_path = Path("uploads") / f"{file_id}{file_extension}"
            # Stream to disk chunk by chunk instead of reading the whole upload at once.
            size = 0
            with open(file_path, "wb") as buffer:
                while True:
                    chunk = await file.read(chunk_size)
                    if not chunk:
                        break
                    buffer.write(chunk)
                    size += len(chunk)
            # Get model info
            model_info = await model_loader.get_model_info(str(file_path))
            uploaded_models.append({
                "id": file_id,
                "name": name,
                "filename": file.filename,
                "path": str(file_path),
                "size": size,
                "info": model_info
            })
            logger.info(f"Uploaded model: {name} ({file.filename})")
        # Schedule cleanup of old files
        background_tasks.add_task(cleanup_temp_files, max_age_hours=24)
        return {
            "success": True,
            "models": uploaded_models,
            "message": f"Successfully uploaded {len(uploaded_models)} model(s)"
        }
    except HTTPException:
        # Keep intended 4xx responses instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/start-training", response_model=Dict[str, Any])
async def start_training(
    background_tasks: BackgroundTasks,
    config: TrainingConfig
):
    """Start knowledge distillation training.

    Registers a new entry in ``training_sessions`` and schedules
    ``run_training`` as a background task. A finished (``completed`` or
    ``failed``) session with the same id is discarded first, together with its
    WebSocket.

    Raises:
        HTTPException 400: the session id is empty or contains a path
            separator (it becomes part of the output directory name), or a
            session with this id is still in progress.
        HTTPException 500: any other failure.
    """
    try:
        session_id = config.session_id
        # run_training builds models/distilled_model_<session_id>; refuse ids that could
        # step outside that directory.
        if not session_id or any(ch in session_id for ch in ("/", "\\", "\x00")):
            raise HTTPException(
                status_code=400,
                detail="Invalid session ID: it must be non-empty and must not contain path separators"
            )

        # Handle existing sessions intelligently
        if session_id in training_sessions:
            status = training_sessions[session_id].get("status", "unknown")
            if status not in ("completed", "failed"):
                # Session is still active
                raise HTTPException(
                    status_code=400,
                    detail=f"Training session already exists with status: {status}. Please wait for completion or use a different session ID."
                )
            # A finished session may be reused: drop it and its WebSocket first.
            logger.info(f"Cleaning up previous session {session_id} with status: {status}")
            del training_sessions[session_id]
            stale_socket = active_connections.pop(session_id, None)
            if stale_socket is not None:
                try:
                    await stale_socket.close()
                except Exception as close_error:
                    # Best effort: the client has usually disconnected already.
                    logger.debug(f"Ignoring error while closing WebSocket for {session_id}: {close_error}")

        # Set HF token from environment if available
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN')
        if hf_token:
            os.environ['HF_TOKEN'] = hf_token
            logger.info("Using Hugging Face token from environment")

        # Flag (without refusing) teachers whose names suggest very large checkpoints.
        large_models = []
        for model_info in config.teacher_models:
            model_path = model_info if isinstance(model_info, str) else str(model_info.get('path') or '')
            if any(size_indicator in model_path.lower() for size_indicator in ['27b', '70b', '13b']):
                large_models.append(model_path)

        # Initialize training session
        training_sessions[session_id] = {
            "status": "initializing",
            "progress": 0.0,
            "current_step": 0,
            "total_steps": config.training_params.get("max_steps", 1000),
            "config": config.dict(),
            "start_time": None,
            "end_time": None,
            "model_path": None,
            "logs": [],
            "large_models": large_models,
            "message": "Initializing training session..." + (
                f" (Large models detected: {', '.join(large_models)})" if large_models else ""
            )
        }
        # Start training in background
        background_tasks.add_task(run_training, session_id, config)
        logger.info(f"Started training session: {session_id}")
        return {
            "success": True,
            "session_id": session_id,
            "message": "Training started successfully"
        }
    except HTTPException:
        # Keep intended 4xx responses instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"Error starting training: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def run_training(session_id: str, config: TrainingConfig):
"""Run knowledge distillation training in background"""
try:
session = training_sessions[session_id]
session["status"] = "running"
session["start_time"] = asyncio.get_event_loop().time()
# Set timeout for the entire operation (30 minutes)
timeout_seconds = 30 * 60
# Set HF token for this session - prioritize config token
config_token = getattr(config, 'hf_token', None)
env_token = (
os.getenv('HF_TOKEN') or
os.getenv('HUGGINGFACE_TOKEN') or
os.getenv('HUGGINGFACE_HUB_TOKEN')
)
hf_token = config_token or env_token
if hf_token:
logger.info(f"Using Hugging Face token from {'config' if config_token else 'environment'}")
# Set token in environment for this session
os.environ['HF_TOKEN'] = hf_token
else:
logger.warning("No Hugging Face token found - private models may fail")
# Handle existing student model for incremental training
existing_student = None
if config.existing_student_model and config.incremental_training:
try:
await update_training_status(session_id, "loading_student", 0.05, "Loading existing student model...")
# Determine student source and load accordingly
student_source = getattr(config, 'student_source', 'local')
student_path = config.existing_student_model
if student_source == 'huggingface' or ('/' in student_path and not Path(student_path).exists()):
logger.info(f"Loading student model from Hugging Face: {student_path}")
existing_student = await model_loader.load_trained_student(student_path)
elif student_source == 'space':
logger.info(f"Loading student model from Hugging Face Space: {student_path}")
# For spaces, we'll try to load from the space's models directory
space_model_path = f"spaces/{student_path}/models"
existing_student = await model_loader.load_trained_student_from_space(student_path)
else:
logger.info(f"Loading student model from local path: {student_path}")
existing_student = await model_loader.load_trained_student(student_path)
logger.info(f"Successfully loaded existing student model: {existing_student.get('type', 'unknown')}")
# Merge original teachers with new teachers
original_teachers = existing_student.get('original_teachers', [])
new_teachers = [
model_info if isinstance(model_info, str) else model_info.get('path', '')
for model_info in config.teacher_models
]
# Combine teachers (avoid duplicates)
all_teachers = original_teachers.copy()
for teacher in new_teachers:
if teacher not in all_teachers:
all_teachers.append(teacher)
logger.info(f"Incremental training: Original teachers: {original_teachers}")
logger.info(f"Incremental training: New teachers: {new_teachers}")
logger.info(f"Incremental training: All teachers: {all_teachers}")
# Update config with all teachers
config.teacher_models = all_teachers
except Exception as e:
logger.error(f"Error loading existing student model: {e}")
await update_training_status(session_id, "failed", session.get("progress", 0), f"Failed to load existing student: {str(e)}")
return
# Load teacher models
await update_training_status(session_id, "loading_models", 0.1, "Loading teacher models...")
teacher_models = []
trust_remote_code = config.training_params.get('trust_remote_code', False)
total_models = len(config.teacher_models)
for i, model_info in enumerate(config.teacher_models):
try:
# Handle both old format (string) and new format (dict)
if isinstance(model_info, str):
model_path = model_info
model_token = hf_token
model_trust_code = trust_remote_code
else:
model_path = model_info.get('path', model_info)
model_token = model_info.get('token') or hf_token
model_trust_code = model_info.get('trust_remote_code', trust_remote_code)
# Update progress
progress = 0.1 + (i * 0.3 / total_models) # 0.1 to 0.4
await update_training_status(
session_id,
"loading_models",
progress,
f"Loading model {i+1}/{total_models}: {model_path}..."
)
logger.info(f"Loading model {model_path} with trust_remote_code={model_trust_code}")
# Special handling for known problematic models
if model_path == 'Wan-AI/Wan2.2-TI2V-5B':
logger.info(f"Detected ti2v model {model_path}, forcing trust_remote_code=True")
model_trust_code = True
elif model_path == 'deepseek-ai/DeepSeek-V3.1-Base':
logger.warning(f"Skipping {model_path}: Requires GPU with FP8 quantization support")
await update_training_status(
session_id,
"loading_models",
progress,
f"Skipping {model_path}: Requires GPU with FP8 quantization"
)
continue
model = await model_loader.load_model(
model_path,
token=model_token,
trust_remote_code=model_trust_code
)
teacher_models.append(model)
logger.info(f"Successfully loaded model: {model_path}")
# Update progress after successful load
progress = 0.1 + ((i + 1) * 0.3 / total_models)
await update_training_status(
session_id,
"loading_models",
progress,
f"Loaded {i+1}/{total_models} models successfully"
)
except Exception as e:
error_msg = f"Failed to load model {model_path}: {str(e)}"
logger.error(error_msg)
# Provide helpful suggestions based on the error
suggestions = []
error_str = str(e).lower()
# Check if we should retry with trust_remote_code=True
if not model_trust_code and ('ti2v' in error_str or 'does not recognize this architecture' in error_str):
try:
logger.info(f"Retrying {model_path} with trust_remote_code=True")
await update_training_status(
session_id,
"loading_models",
progress,
f"Retrying {model_path} with trust_remote_code=True..."
)
model = await model_loader.load_model(
model_path,
token=model_token,
trust_remote_code=True
)
teacher_models.append(model)
logger.info(f"Successfully loaded model on retry: {model_path}")
# Update progress after successful retry
progress = 0.1 + ((i + 1) * 0.3 / total_models)
await update_training_status(
session_id,
"loading_models",
progress,
f"Loaded {i+1}/{total_models} models successfully (retry)"
)
continue
except Exception as retry_e:
logger.error(f"Retry also failed for {model_path}: {str(retry_e)}")
error_msg = f"Failed even with trust_remote_code=True: {str(retry_e)}"
if 'trust_remote_code' in error_str:
suggestions.append("Try enabling 'Trust Remote Code' option")
elif 'gated' in error_str or 'access' in error_str:
suggestions.append("This model requires access permission and a valid HF token")
elif 'siglip' in error_str or 'unknown' in error_str:
suggestions.append("This model may require special loading. Try enabling 'Trust Remote Code'")
elif 'connection' in error_str or 'network' in error_str:
suggestions.append("Check your internet connection")
elif 'ti2v' in error_str:
suggestions.append("This ti2v model requires trust_remote_code=True")
if suggestions:
error_msg += f". Suggestions: {'; '.join(suggestions)}"
await update_training_status(session_id, "failed", session.get("progress", 0), error_msg)
return
# Initialize student model
await update_training_status(session_id, "initializing_student", 0.2, "Initializing student model...")
student_model = await distillation_trainer.create_student_model(
teacher_models, config.student_config
)
# Run distillation training
await update_training_status(session_id, "training", 0.3, "Starting knowledge distillation...")
async def progress_callback(step: int, total_steps: int, loss: float, metrics: Dict[str, Any]):
progress = 0.3 + (step / total_steps) * 0.6 # 30% to 90%
await update_training_status(
session_id, "training", progress,
f"Training step {step}/{total_steps}, Loss: {loss:.4f}",
current_step=step, loss=loss
)
trained_model = await distillation_trainer.train(
student_model, teacher_models, config.training_params, progress_callback
)
# Save trained model with metadata
await update_training_status(session_id, "saving", 0.9, "Saving trained model...")
# Create model directory with proper structure
model_dir = Path("models") / f"distilled_model_{session_id}"
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / "pytorch_model.safetensors"
# Prepare training metadata for saving
training_metadata = {
'session_id': session_id,
'teacher_models': [
model_info if isinstance(model_info, str) else model_info.get('path', '')
for model_info in config.teacher_models
],
'strategy': config.distillation_strategy,
'training_params': config.training_params,
'incremental_training': config.incremental_training,
'existing_student_model': config.existing_student_model
}
await distillation_trainer.save_model(trained_model, str(model_path), training_metadata)
# Complete training
session["status"] = "completed"
session["progress"] = 1.0
session["end_time"] = asyncio.get_event_loop().time()
session["model_path"] = model_path
session["training_metadata"] = training_metadata
await update_training_status(session_id, "completed", 1.0, "Training completed successfully!")
logger.info(f"Training session {session_id} completed successfully")
except Exception as e:
logger.error(f"Training session {session_id} failed: {str(e)}")
session = training_sessions.get(session_id, {})
session["status"] = "failed"
session["error"] = str(e)
await update_training_status(session_id, "failed", session.get("progress", 0), f"Training failed: {str(e)}")
async def update_training_status(
    session_id: str,
    status: str,
    progress: float,
    message: str,
    current_step: Optional[int] = None,
    loss: Optional[float] = None
):
    """Update training status and notify connected clients.

    Records ``status``, ``progress`` and ``message`` (plus ``current_step`` /
    ``loss`` when given) on ``training_sessions[session_id]`` and refreshes the
    ``eta`` estimate. The estimate is cleared once the session fails or reaches
    ``progress >= 1.0``. A ``{"type": "training_update", "data": ...}``
    message carrying a JSON-cleaned copy of the session is then sent to the
    session's WebSocket, if one is registered. Unknown session ids are
    ignored; a WebSocket whose send fails is removed from
    ``active_connections``.
    """
    session = training_sessions.get(session_id)
    if session is None:
        return

    session["status"] = status
    session["progress"] = progress
    session["message"] = message
    if current_step is not None:
        session["current_step"] = current_step
    if loss is not None:
        session["loss"] = loss

    if status == "failed" or progress >= 1.0:
        # Terminal state: drop the estimate instead of reporting a stale one.
        session["eta"] = None
    elif session.get("start_time") and progress > 0:
        # Linear extrapolation: remaining = elapsed * (remaining fraction / completed fraction).
        elapsed = asyncio.get_event_loop().time() - session["start_time"]
        eta_seconds = (elapsed / progress) * (1.0 - progress)
        session["eta"] = f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s"

    websocket = active_connections.get(session_id)
    if websocket is None:
        return
    try:
        # Clean session data for JSON serialization
        clean_session_data = serialize_session_for_websocket(session)
        await websocket.send_json({
            "type": "training_update",
            "data": clean_session_data
        })
    except Exception as ws_error:
        logger.warning(f"WebSocket error for session {session_id}: {ws_error}")
        # Remove the disconnected client, unless a new socket replaced it meanwhile.
        if active_connections.get(session_id) is websocket:
            del active_connections[session_id]
@app.get("/progress/{session_id}", response_model=TrainingStatus)
async def get_training_progress(session_id: str):
    """Return a TrainingStatus snapshot of the in-memory session, or 404 if the id is unknown."""
    session = training_sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Training session not found")
    # Always present since start_training creates the session.
    required = {key: session[key] for key in ("status", "progress", "current_step", "total_steps")}
    # Filled in later (or never) by update_training_status.
    optional = {key: session.get(key) for key in ("loss", "eta")}
    return TrainingStatus(
        session_id=session_id,
        message=session.get("message", ""),
        **required,
        **optional,
    )
@app.get("/download/{session_id}")
async def download_model(session_id: str):
    """Download trained model.

    A single model file is returned as-is. A model directory is packed into a
    temporary ZIP, built off the event loop and deleted once the response has
    been sent. If the session has no recorded ``model_path``, a few
    conventional locations under ``models/`` are tried.

    Raises:
        HTTPException 404: unknown session, or no model file on disk.
        HTTPException 400: training has not completed.
        HTTPException 500: any other failure (e.g. while zipping).
    """
    import zipfile
    import tempfile
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")

        model_path = session.get("model_path")
        if not model_path:
            # Try to find model in models directory
            models_dir = Path("models")
            candidates = [
                models_dir / f"distilled_model_{session_id}",
                models_dir / f"distilled_model_{session_id}.safetensors",
                models_dir / f"model_{session_id}",
                models_dir / f"student_model_{session_id}"
            ]
            model_path = next((str(path) for path in candidates if path.exists()), None)
        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found. The model may not have been saved properly.")

        model_dir = Path(model_path)
        if model_dir.is_file():
            # NOTE(review): run_training records .../pytorch_model.safetensors here, so only that
            # file is served, not sibling files in its directory (e.g. a config, if one is saved).
            return FileResponse(
                model_path,
                media_type="application/octet-stream",
                filename=f"distilled_model_{session_id}.safetensors"
            )

        # Directory: zip it. mkstemp + os.close avoids leaking an open file handle.
        fd, zip_path = tempfile.mkstemp(suffix='.zip')
        os.close(fd)

        def _build_zip():
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                for file_path in model_dir.rglob('*'):
                    if file_path.is_file():
                        zipf.write(file_path, file_path.relative_to(model_dir))

        try:
            # Zipping large checkpoints is slow, blocking work; keep it off the event loop.
            await asyncio.get_running_loop().run_in_executor(None, _build_zip)
        except Exception:
            Path(zip_path).unlink(missing_ok=True)
            raise
        # Delete the temporary archive after the response body has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(Path(zip_path).unlink, missing_ok=True)
        return FileResponse(
            zip_path,
            media_type="application/zip",
            filename=f"distilled_model_{session_id}.zip",
            background=cleanup
        )
    except HTTPException:
        # Keep intended 400/404 responses instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"Error downloading model: {e}")
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
def _hf_teacher_names(config_info: Dict[str, Any]) -> List[str]:
    """List the teacher model names recorded in a session's stored config.

    Entries may be plain strings or config dicts. A dict without a ``path`` is
    shown as 'unknown' rather than ``str(dict)``: such dicts can carry a
    per-model ``token``, which must never end up in a public model card.
    """
    names = []
    for model in config_info.get("teacher_models", []):
        if isinstance(model, str):
            names.append(model)
        elif isinstance(model, dict):
            names.append(model.get('path') or 'unknown')
        else:
            names.append(str(model))
    return names


def _hf_model_card(repo_name: str, description: str, config_info: Dict[str, Any], teacher_models: List[str]) -> str:
    """Render the README.md (YAML front matter + Markdown) uploaded with the model."""
    return f"""---
license: apache-2.0
tags:
- knowledge-distillation
- pytorch
- transformers
base_model: {teacher_models[0] if teacher_models else 'unknown'}
---
# {repo_name}
This model was created using knowledge distillation from the following teacher model(s):
{chr(10).join([f"- {model}" for model in teacher_models])}
## Model Description
{description if description else 'A distilled model created using multi-modal knowledge distillation.'}
## Training Details
- **Teacher Models**: {', '.join(teacher_models)}
- **Distillation Strategy**: {config_info.get('distillation_strategy', 'ensemble')}
- **Training Steps**: {config_info.get('training_params', {}).get('max_steps', 'unknown')}
- **Learning Rate**: {config_info.get('training_params', {}).get('learning_rate', 'unknown')}
## Usage
```python
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("{repo_name}")
tokenizer = AutoTokenizer.from_pretrained("{teacher_models[0] if teacher_models else 'bert-base-uncased'}")
```
## Created with
This model was created using the Multi-Modal Knowledge Distillation platform.
"""


@app.post("/upload-to-hf/{session_id}")
async def upload_to_huggingface(
    session_id: str,
    repo_name: str = Form(...),
    description: str = Form(""),
    private: bool = Form(False),
    hf_token: str = Form(...)
):
    """Upload trained model to Hugging Face Hub.

    Creates (or reuses) ``repo_name``. Known essential files are uploaded
    first, then every other file in the model directory under its relative
    path, and finally a generated README.md that replaces any uploaded one.
    Per-file upload failures are logged and skipped.

    Raises:
        HTTPException 404: unknown session, or model file missing.
        HTTPException 400: training not completed, bad repo name, or repo creation failed.
        HTTPException 401/403: the Hub rejected the token or the namespace.
        HTTPException 500: missing huggingface_hub or any other failure.
    """
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")
        model_path = session.get("model_path")
        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found")

        try:
            from huggingface_hub import HfApi, create_repo
        except ImportError:
            raise HTTPException(status_code=500, detail="huggingface_hub not installed")
        api = HfApi(token=hf_token)

        # Validate repository name format
        if '/' not in repo_name:
            raise HTTPException(status_code=400, detail="Repository name must be in format 'username/model-name'")
        username, _model_name = repo_name.split('/', 1)

        # NOTE(review): the Hub calls below are blocking network I/O executed on the event loop.
        try:
            repo_url = create_repo(
                repo_id=repo_name,
                token=hf_token,
                private=private,
                exist_ok=True
            )
            logger.info(f"Created/accessed repository: {repo_url}")
        except Exception as e:
            error_msg = str(e)
            if "403" in error_msg or "Forbidden" in error_msg:
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied. Please check: 1) Your token has 'Write' permissions, 2) You own the namespace '{username}', 3) The repository name is correct. Error: {error_msg}"
                )
            elif "401" in error_msg or "Unauthorized" in error_msg:
                raise HTTPException(
                    status_code=401,
                    detail=f"Invalid token. Please check your Hugging Face token. Error: {error_msg}"
                )
            else:
                raise HTTPException(status_code=400, detail=f"Failed to create repository: {error_msg}")

        # A file path means "upload the directory it lives in".
        model_path_obj = Path(model_path)
        model_dir = model_path_obj.parent if model_path_obj.is_file() else model_path_obj
        uploaded_files = []
        essential_files = [
            'pytorch_model.safetensors', 'config.json', 'model.py',
            'training_history.json', 'README.md'
        ]

        # Upload essential files first (top level of the model directory only).
        for file_name in essential_files:
            file_path = model_dir / file_name
            if file_path.exists():
                try:
                    api.upload_file(
                        path_or_fileobj=str(file_path),
                        path_in_repo=file_name,
                        repo_id=repo_name,
                        token=hf_token
                    )
                    uploaded_files.append(file_name)
                    logger.info(f"Uploaded {file_name}")
                except Exception as e:
                    logger.warning(f"Failed to upload {file_name}: {e}")

        # Upload any additional files. Compare full relative paths so that e.g.
        # "sub/config.json" is not mistaken for the top-level essential file.
        for file_path in model_dir.rglob('*'):
            if not file_path.is_file():
                continue
            relative_path = file_path.relative_to(model_dir).as_posix()
            if relative_path in essential_files:
                continue
            try:
                api.upload_file(
                    path_or_fileobj=str(file_path),
                    path_in_repo=relative_path,
                    repo_id=repo_name,
                    token=hf_token
                )
                uploaded_files.append(relative_path)
                logger.info(f"Uploaded additional file: {relative_path}")
            except Exception as e:
                logger.warning(f"Failed to upload {relative_path}: {e}")

        # The generated model card replaces any README.md uploaded above.
        config_info = session.get("config", {})
        teacher_models = _hf_teacher_names(config_info)
        readme_content = _hf_model_card(repo_name, description, config_info, teacher_models)
        api.upload_file(
            path_or_fileobj=readme_content.encode(),
            path_in_repo="README.md",
            repo_id=repo_name,
            token=hf_token
        )
        if "README.md" not in uploaded_files:
            uploaded_files.append("README.md")

        return {
            "success": True,
            "repo_url": f"https://huggingface.co/{repo_name}",
            "uploaded_files": uploaded_files,
            "message": f"Model successfully uploaded to {repo_name}"
        }
    except HTTPException:
        # Keep intended 4xx responses (and their messages) instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading to Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
@app.post("/validate-repo-name")
async def validate_repo_name(request: Dict[str, Any]):
    """Validate repository name and check permissions.

    Expects ``repo_name`` ("namespace/model-name") and ``hf_token`` in the
    body. The namespace is accepted when it is the token owner's username or
    one of the organizations listed for that user by ``HfApi.whoami()``.
    Always returns a dict with ``valid`` and either ``message``/``username``
    or ``error`` (plus ``suggested_name`` on a namespace mismatch).
    """
    try:
        # `or ''` also covers keys sent with a null value.
        repo_name = str(request.get('repo_name') or '').strip()
        hf_token = str(request.get('hf_token') or '').strip()
        if not repo_name or not hf_token:
            return {"valid": False, "error": "Repository name and token are required"}
        if '/' not in repo_name:
            return {"valid": False, "error": "Repository name must be in format 'username/model-name'"}
        username, model_name = repo_name.split('/', 1)
        # A repo id is exactly one non-empty namespace and one non-empty name.
        if not username or not model_name or '/' in model_name:
            return {"valid": False, "error": "Repository name must be in format 'username/model-name'"}

        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            user_info = api.whoami()
            token_username = user_info.get('name', '')
            # Organizations the token's user belongs to are valid namespaces as well.
            org_names = {
                org.get('name')
                for org in (user_info.get('orgs') or [])
                if isinstance(org, dict)
            }
            if username != token_username and username not in org_names:
                return {
                    "valid": False,
                    "error": f"Username mismatch. Token belongs to '{token_username}' but trying to create repo under '{username}'. Use '{token_username}/{model_name}' instead.",
                    "suggested_name": f"{token_username}/{model_name}"
                }
            return {
                "valid": True,
                "message": f"Repository name '{repo_name}' is valid for your account",
                "username": token_username
            }
        except Exception as e:
            return {"valid": False, "error": f"Token validation failed: {str(e)}"}
    except Exception as e:
        return {"valid": False, "error": f"Validation error: {str(e)}"}
@app.post("/test-space")
async def test_space(request: Dict[str, Any]):
    """Check that a Hugging Face Space is reachable and report its model files.

    ``request`` is expected to carry ``space_name`` (``"username/space-name"``)
    and optionally ``hf_token``. Existence is checked with
    ``HfApi.space_info``. If that succeeds, the Space's files are listed and
    any ``.safetensors``/``.bin``/``.pt`` files are returned as ``models``.

    Returns:
        A dict with ``success`` (bool) and either ``error``, or
        ``space_info``/``models``/``message``. If the Space exists but its
        files cannot be listed, ``success`` is still True and ``models`` is
        empty. This endpoint never raises.
    """
    try:
        # ``or ''`` guards against explicit JSON nulls before ``.strip()``.
        space_name = (request.get('space_name') or '').strip()
        hf_token = (request.get('hf_token') or '').strip()
        if not space_name:
            return {"success": False, "error": "Space name is required"}
        if '/' not in space_name:
            return {"success": False, "error": "Space name must be in format 'username/space-name'"}

        try:
            from huggingface_hub import HfApi
            # An empty token means anonymous access.
            api = HfApi(token=hf_token or None)

            # Existence / accessibility check only; the returned info is unused.
            try:
                api.space_info(space_name)
                logger.info(f"Found Space: {space_name}")
            except Exception as e:
                return {"success": False, "error": f"Space not found or not accessible: {str(e)}"}

            try:
                files = api.list_repo_files(space_name, repo_type="space")
            except Exception as e:
                # The Space exists but its files can't be listed (possibly
                # private or insufficient access). This is still reported as
                # success, but the reason is logged instead of discarded.
                logger.warning(f"Could not list files for Space {space_name}: {e}")
                return {
                    "success": True,
                    "space_info": {"name": space_name},
                    "models": [],
                    "message": f"Space {space_name} exists but file listing not available (might be private)"
                }

            model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))]
            has_models_dir = any(f.startswith('models/') for f in files)
            return {
                "success": True,
                "space_info": {
                    "name": space_name,
                    "model_files": model_files,
                    "models_directory": has_models_dir,
                    "total_files": len(files)
                },
                "models": model_files,
                "message": f"Space {space_name} is accessible"
            }
        except Exception as e:
            return {"success": False, "error": f"Error accessing Hugging Face: {str(e)}"}
    except Exception as e:
        logger.error(f"Error testing Space: {e}")
        return {"success": False, "error": f"Test failed: {str(e)}"}
@app.get("/trained-students")
async def list_trained_students():
    """List trained student models under ``models/`` that can be retrained.

    A subdirectory of ``models/`` is reported when one of its
    ``*config.json`` files is a JSON object with a truthy
    ``is_student_model`` flag. If the directory also has a
    ``*training_history.json`` file, retraining and session details are read
    from it. Directories that fail to read are logged and skipped.

    Returns:
        ``{"trained_students": [...]}`` with one summary dict per model.

    Raises:
        HTTPException: 500 if enumerating ``models/`` itself fails.
    """
    try:
        models_dir = Path("models")
        trained_students = []
        if models_dir.exists():
            # Sorted so the response order does not depend on the filesystem.
            for model_dir in sorted(models_dir.iterdir()):
                if not model_dir.is_dir():
                    continue
                try:
                    # Several files can match "*config.json" (e.g. config.json
                    # alongside tokenizer_config.json), and glob order is
                    # arbitrary. Scan them for the one that carries the
                    # student marker instead of trusting whichever comes first.
                    config = None
                    for config_file in sorted(model_dir.glob("*config.json")):
                        try:
                            # JSON is UTF-8; don't depend on the platform's
                            # default locale encoding.
                            with open(config_file, 'r', encoding='utf-8') as f:
                                candidate = json.load(f)
                        except (OSError, ValueError) as e:
                            logger.warning(f"Skipping unreadable config {config_file}: {e}")
                            continue
                        if isinstance(candidate, dict) and candidate.get('is_student_model', False):
                            config = candidate
                            break
                    if config is None:
                        continue

                    history = {}
                    history_files = sorted(model_dir.glob("*training_history.json"))
                    if history_files:
                        with open(history_files[0], 'r', encoding='utf-8') as f:
                            history = json.load(f)
                    sessions = history.get('training_sessions') or []

                    trained_students.append({
                        "id": model_dir.name,
                        "name": model_dir.name,
                        "path": str(model_dir),
                        "type": "trained_student",
                        "created_at": config.get('created_at', 'unknown'),
                        "architecture": config.get('architecture', 'unknown'),
                        "modalities": config.get('modalities', ['text']),
                        "can_be_retrained": config.get('can_be_retrained', True),
                        "original_teachers": history.get('retraining_info', {}).get('original_teachers', []),
                        "training_sessions": len(sessions),
                        "last_training": sessions[-1].get('timestamp', 'unknown') if sessions else 'unknown'
                    })
                except Exception as e:
                    logger.warning(f"Error reading model {model_dir}: {e}")
        return {"trained_students": trained_students}
    except Exception as e:
        logger.error(f"Error listing trained students: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """Describe every regular file directly inside ``uploads/``.

    Each file is inspected with ``model_loader.get_model_info``. Files whose
    inspection (or ``ModelInfo`` construction) raises are logged and left out.
    Subdirectories are ignored, and a missing ``uploads/`` yields ``[]``.
    """
    found = []
    uploads_dir = Path("uploads")
    if not uploads_dir.exists():
        return found

    for entry in uploads_dir.iterdir():
        if not entry.is_file():
            continue
        try:
            details = await model_loader.get_model_info(str(entry))
            found.append(
                ModelInfo(
                    name=entry.stem,
                    size=entry.stat().st_size,
                    format=entry.suffix[1:],  # extension without the leading dot
                    modality=details.get("modality", "unknown"),
                    architecture=details.get("architecture"),
                )
            )
        except Exception as e:
            logger.warning(f"Error getting info for {entry}: {e}")
    return found
@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve real-time training updates for ``session_id`` over a WebSocket.

    The socket is registered in ``active_connections`` under the session id so
    other code can push updates to it. If ``session_id`` is already in
    ``training_sessions``, its current state is sent immediately. Incoming
    text is read and discarded only to keep the connection open and to notice
    disconnects.
    """
    await websocket.accept()
    active_connections[session_id] = websocket
    try:
        # Send current status if session exists
        if session_id in training_sessions:
            clean_session_data = serialize_session_for_websocket(training_sessions[session_id])
            await websocket.send_json({
                "type": "training_update",
                "data": clean_session_data
            })
        # Keep connection alive
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error for session {session_id}: {e}")
    finally:
        # Unregister only *this* socket. If the client reconnected with the
        # same session id, a newer socket already owns the slot and must not
        # be dropped. ``finally`` also covers task cancellation.
        if active_connections.get(session_id) is websocket:
            del active_connections[session_id]
# ==================== NEW ADVANCED ENDPOINTS ====================
# Session Management Endpoints
@app.get("/api/sessions")
async def list_training_sessions():
    """Summarize every entry in ``training_sessions``.

    Returns:
        ``success``, the per-session summaries under ``sessions``, the total
        count, and how many are ``running``/``initializing``.

    Raises:
        HTTPException: 500 if building the summary fails.
    """
    def _summarize(sid, data):
        # ``model_path`` may be a Path or similar; stringify it, but report a
        # missing or empty value as None.
        model_path = data.get("model_path")
        return {
            "session_id": sid,
            "status": data.get("status", "unknown"),
            "progress": data.get("progress", 0.0),
            "current_step": data.get("current_step", 0),
            "total_steps": data.get("total_steps", 0),
            "start_time": data.get("start_time"),
            "end_time": data.get("end_time"),
            "message": data.get("message", ""),
            "loss": data.get("loss"),
            "model_path": str(model_path) if model_path else None,
        }

    try:
        summaries = [_summarize(sid, data) for sid, data in training_sessions.items()]
        live_count = sum(1 for item in summaries if item["status"] in ("running", "initializing"))
        return {
            "success": True,
            "sessions": summaries,
            "total_sessions": len(summaries),
            "active_sessions": live_count,
        }
    except Exception as e:
        logger.error(f"Error listing sessions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/sessions/{session_id}")
async def delete_training_session(session_id: str):
    """Delete a non-active training session and close its WebSocket, if any.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if its status is
            ``running`` or ``initializing``, 500 on any other failure.
    """
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        status = training_sessions[session_id].get("status", "unknown")
        # Don't allow deletion of running sessions
        if status in ["running", "initializing"]:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot delete active session with status: {status}"
            )

        del training_sessions[session_id]

        # Unregister the socket *before* awaiting close(). While close() is
        # pending, the socket's own handler may run and remove the entry, and
        # a later ``del`` would then raise KeyError and turn an otherwise
        # successful delete into a 500.
        connection = active_connections.pop(session_id, None)
        if connection is not None:
            try:
                await connection.close()
            except Exception as e:
                # Best effort: the peer may already be gone. Catch Exception
                # rather than a bare except so cancellation still propagates.
                logger.debug(f"Ignoring error closing WebSocket for session {session_id}: {e}")

        logger.info(f"Deleted training session: {session_id}")
        return {
            "success": True,
            "message": f"Session {session_id} deleted successfully"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting session {session_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/sessions/{session_id}/cancel")
async def cancel_training_session(session_id: str):
    """Mark a running/initializing session as cancelled and broadcast it.

    Sets ``status``, ``message`` and ``end_time`` (event-loop clock) on the
    session record, then calls ``update_training_status``.

    NOTE(review): nothing in this function stops the underlying training
    work. Confirm that the trainer checks the session status for "cancelled".

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it is not
            running/initializing, 500 on any other failure.
    """
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")

        session = training_sessions[session_id]
        current_status = session.get("status", "unknown")
        if current_status not in ("running", "initializing"):
            raise HTTPException(
                status_code=400,
                detail=f"Cannot cancel session with status: {current_status}"
            )

        reason = "Training cancelled by user"
        session["status"] = "cancelled"
        session["message"] = reason
        session["end_time"] = asyncio.get_event_loop().time()

        # Push the new state to connected WebSocket clients.
        await update_training_status(session_id, "cancelled", session.get("progress", 0), reason)

        logger.info(f"Cancelled training session: {session_id}")
        return {"success": True, "message": f"Session {session_id} cancelled successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling session {session_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/sessions/cleanup")
async def cleanup_completed_sessions():
    """Remove every completed, failed or cancelled session.

    Any WebSocket registered for a removed session is unregistered and closed
    (best effort).

    Returns:
        ``success``, a summary ``message``, and the removed sessions'
        ids/statuses under ``cleaned_sessions``.

    Raises:
        HTTPException: 500 on unexpected failure.
    """
    try:
        finished_states = ("completed", "failed", "cancelled")
        # Snapshot the ids first, because the dict is mutated below.
        cleaned_sessions = []
        for session_id, session_data in training_sessions.items():
            status = session_data.get("status", "unknown")
            if status in finished_states:
                cleaned_sessions.append({"session_id": session_id, "status": status})

        for entry in cleaned_sessions:
            session_id = entry["session_id"]
            # pop() rather than del: the awaits below let other coroutines
            # (e.g. the delete endpoint or the socket's own handler) remove
            # these entries first, which previously raised KeyError mid-cleanup.
            training_sessions.pop(session_id, None)
            connection = active_connections.pop(session_id, None)
            if connection is not None:
                try:
                    await connection.close()
                except Exception as e:
                    # Best effort. Don't use a bare except, which would also
                    # swallow task cancellation.
                    logger.debug(f"Ignoring error closing WebSocket for session {session_id}: {e}")

        logger.info(f"Cleaned up {len(cleaned_sessions)} completed sessions")
        return {
            "success": True,
            "message": f"Cleaned up {len(cleaned_sessions)} sessions",
            "cleaned_sessions": cleaned_sessions
        }
    except Exception as e:
        logger.error(f"Error cleaning up sessions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Medical Dataset Management Endpoints
@app.get("/api/medical-datasets")
async def get_medical_datasets():
    """Return the medical dataset catalogue and specialty list.

    Both are read from ``src.medical.medical_config``. The import is inside
    the ``try``, so an import failure surfaces as a 500 from this endpoint.
    """
    try:
        from src.medical.medical_config import SUPPORTED_MEDICAL_DATASETS, MEDICAL_SPECIALTIES
        catalogue = SUPPORTED_MEDICAL_DATASETS
        payload = {
            "success": True,
            "datasets": catalogue,
            "specialties": MEDICAL_SPECIALTIES,
            "total_datasets": len(catalogue),
        }
        return payload
    except Exception as e:
        logger.error(f"Error getting medical datasets: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/medical-datasets/select")
async def select_medical_datasets(
    user_session: str = Form(...),
    selected_datasets: str = Form(...),  # JSON array of dataset names
    preferences: str = Form(default="{}")  # JSON object of user preferences
):
    """Validate and persist a user's medical dataset selection.

    ``selected_datasets`` is validated with
    ``validate_medical_dataset_selection``. If it is invalid, the validation
    errors/warnings are returned with ``success: False``. Otherwise each
    dataset is re-saved for ``user_session``, and non-empty ``preferences``
    are stored too.

    Raises:
        HTTPException: 400 for malformed JSON or JSON of the wrong type,
            500 for any other failure.
    """
    # ``json`` here is the module-level import. A function-local
    # ``import json`` would make ``json`` a local name, so any exception
    # raised before that line (e.g. a failing ``database`` import) would turn
    # ``except json.JSONDecodeError`` into an UnboundLocalError.
    try:
        from database.medical_selections import MedicalSelectionsDB
        from src.medical.medical_config import validate_medical_dataset_selection

        dataset_list = json.loads(selected_datasets)
        user_preferences = json.loads(preferences)
        if not isinstance(dataset_list, list):
            raise HTTPException(status_code=400, detail="selected_datasets must be a JSON array")
        if not isinstance(user_preferences, dict):
            raise HTTPException(status_code=400, detail="preferences must be a JSON object")

        validation_result = validate_medical_dataset_selection(dataset_list)
        if not validation_result['valid']:
            return {
                "success": False,
                "errors": validation_result['errors'],
                "warnings": validation_result['warnings']
            }

        db = MedicalSelectionsDB()
        # Remove any existing rows for these datasets before re-saving them.
        # NOTE(review): selections for datasets *not* in this request are left
        # in place. Confirm whether the intent was to replace the whole set.
        for dataset_name in dataset_list:
            db.remove_dataset_selection(user_session, dataset_name)

        success_count = 0
        for dataset_name in dataset_list:
            if db.save_dataset_selection(user_session, dataset_name):
                success_count += 1

        if user_preferences:
            db.save_user_preferences(user_session, user_preferences)

        # Message (Arabic): "Saved {success_count} of the datasets successfully".
        return {
            "success": True,
            "message": f"تم حفظ {success_count} من قواعد البيانات بنجاح",
            "selected_count": success_count,
            "validation_result": validation_result
        }
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON format: {e}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error selecting medical datasets: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/medical-datasets/selections/{user_session}")
async def get_user_medical_selections(user_session: str):
    """Return the stored dataset selections and preferences for a user session.

    Raises:
        HTTPException: 500 if the selections store fails.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB

        store = MedicalSelectionsDB()
        chosen = store.get_user_dataset_selections(user_session)
        prefs = store.get_user_preferences(user_session)
        result = {"success": True, "selections": chosen, "preferences": prefs}
        result["total_selected"] = len(chosen)
        return result
    except Exception as e:
        logger.error(f"Error getting user selections: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/medical-datasets/selections/{user_session}/{dataset_name}")
async def remove_medical_dataset_selection(user_session: str, dataset_name: str):
    """Remove one dataset from a user's stored selections.

    Raises:
        HTTPException: 400 if the store reports the removal failed, 500 on
            any other error.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB
        db = MedicalSelectionsDB()
        success = db.remove_dataset_selection(user_session, dataset_name)
        if success:
            # Message (Arabic): "Dataset {dataset_name} removed successfully".
            return {
                "success": True,
                "message": f"تم إزالة قاعدة البيانات {dataset_name} بنجاح"
            }
        # Detail (Arabic): "Failed to remove the dataset".
        raise HTTPException(status_code=400, detail="فشل في إزالة قاعدة البيانات")
    except HTTPException:
        # Let the intended 400 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error removing dataset selection: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/medical-datasets/recommendations/{user_session}")
async def get_dataset_recommendations(user_session: str):
    """Recommend supported datasets based on the user's saved specialties.

    For each specialty in the stored ``specialties`` preference, the datasets
    from ``get_dataset_by_specialty`` that appear in
    ``SUPPORTED_MEDICAL_DATASETS`` are returned. Each dataset is included
    once, tagged with the first specialty that recommended it.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB
        from src.medical.medical_config import get_dataset_by_specialty, SUPPORTED_MEDICAL_DATASETS

        preferences = MedicalSelectionsDB().get_user_preferences(user_session)

        unique_recommendations = []
        already_added = set()
        for specialty in preferences.get('specialties', []):
            for dataset_key in get_dataset_by_specialty(specialty):
                if dataset_key not in SUPPORTED_MEDICAL_DATASETS or dataset_key in already_added:
                    continue
                already_added.add(dataset_key)
                # Copy so the shared catalogue entry is not mutated.
                entry = SUPPORTED_MEDICAL_DATASETS[dataset_key].copy()
                entry['recommended_for_specialty'] = specialty
                entry['dataset_key'] = dataset_key
                unique_recommendations.append(entry)

        return {
            "success": True,
            "recommendations": unique_recommendations,
            "user_preferences": preferences,
            "total_recommendations": len(unique_recommendations)
        }
    except Exception as e:
        logger.error(f"Error getting recommendations: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Model Management Endpoints
@app.get("/api/google-models")
async def get_google_models():
    """Return a hard-coded catalogue of teacher models, with per-type counts.

    The catalogue is static mock data (not fetched from any API). The
    descriptions and use cases are Arabic user-facing text.

    Raises:
        HTTPException: 500 on unexpected failure.
    """
    try:
        google_models = {}

        # Arabic: "Medium-sized text model trained on diverse tasks".
        google_models['flan_t5_base'] = {
            'name': 'FLAN-T5 Base',
            'description': 'نموذج نصوص متوسط الحجم مدرب على مهام متنوعة',
            'type': 'text',
            'modalities': ['text'],
            'parameters': '250M',
            'size_category': 'medium',
            'use_cases': ['الإجابة على الأسئلة', 'التلخيص', 'الترجمة'],
            'performance_score': 8.5,
            'repo_id': 'google/flan-t5-base',
            'license': 'Apache 2.0'
        }
        # Arabic: "Large high-performance text model".
        google_models['flan_t5_large'] = {
            'name': 'FLAN-T5 Large',
            'description': 'نموذج نصوص كبير عالي الأداء',
            'type': 'text',
            'modalities': ['text'],
            'parameters': '780M',
            'size_category': 'large',
            'use_cases': ['المهام المعقدة', 'التحليل المتقدم', 'الكتابة الإبداعية'],
            'performance_score': 9.2,
            'repo_id': 'google/flan-t5-large',
            'license': 'Apache 2.0'
        }
        # Arabic: "Advanced computer-vision model".
        google_models['vit_base'] = {
            'name': 'Vision Transformer Base',
            'description': 'نموذج رؤية حاسوبية متقدم',
            'type': 'vision',
            'modalities': ['vision'],
            'parameters': '86M',
            'size_category': 'medium',
            'use_cases': ['تصنيف الصور', 'التعرف على الأشياء', 'تحليل المحتوى البصري'],
            'performance_score': 8.8,
            'repo_id': 'google/vit-base-patch16-224',
            'license': 'Apache 2.0'
        }
        # Arabic: "Multimodal model linking text and images".
        google_models['clip_vit'] = {
            'name': 'CLIP Vision-Text',
            'description': 'نموذج متعدد الوسائط يربط النصوص والصور',
            'type': 'multimodal',
            'modalities': ['text', 'vision'],
            'parameters': '400M',
            'size_category': 'large',
            'use_cases': ['البحث بالصور', 'وصف الصور', 'التصنيف متعدد الوسائط'],
            'performance_score': 9.0,
            'repo_id': 'openai/clip-vit-base-patch32',
            'license': 'MIT'
        }
        # Arabic: "Classic natural-language-understanding model".
        google_models['bert_base'] = {
            'name': 'BERT Base',
            'description': 'نموذج فهم اللغة الطبيعية الكلاسيكي',
            'type': 'text',
            'modalities': ['text'],
            'parameters': '110M',
            'size_category': 'small',
            'use_cases': ['تحليل المشاعر', 'تصنيف النصوص', 'استخراج المعلومات'],
            'performance_score': 8.0,
            'repo_id': 'bert-base-uncased',
            'license': 'Apache 2.0'
        }

        # Count models per type, in a fixed key order.
        category_counts = {
            kind: sum(1 for spec in google_models.values() if spec['type'] == kind)
            for kind in ('text', 'vision', 'multimodal')
        }

        return {
            "success": True,
            "models": google_models,
            "total_models": len(google_models),
            "categories": category_counts
        }
    except Exception as e:
        logger.error(f"Error getting Google models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/model-configuration/save")
async def save_model_configuration(configuration: dict):
    """Persist a user's teacher/student model configuration.

    ``configuration`` must contain ``user_session`` and a non-empty
    ``teachers`` list. ``student`` is optional; when given, it must be an
    object (its ``type`` is recorded, otherwise ``'new'``). The data is stored
    through ``save_user_preferences`` under ``model_configuration``.

    NOTE(review): whether ``save_user_preferences`` merges or replaces the
    stored preferences is not visible here. If it replaces them, this call
    drops other preference keys (e.g. specialties) — confirm.

    Raises:
        HTTPException: 400 for missing or ill-typed fields, 500 if saving fails.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB

        user_session = configuration.get('user_session')
        teachers = configuration.get('teachers', [])
        student = configuration.get('student')

        if not user_session:
            raise HTTPException(status_code=400, detail="User session is required")
        if not teachers:
            raise HTTPException(status_code=400, detail="At least one teacher model is required")
        if not isinstance(teachers, list):
            raise HTTPException(status_code=400, detail="Teachers must be a list")
        # ``student.get`` below requires a mapping. A falsy student still
        # means "create a new student".
        if student and not isinstance(student, dict):
            raise HTTPException(status_code=400, detail="Student must be an object")

        db = MedicalSelectionsDB()
        config_data = {
            'teachers': teachers,
            'student': student,
            'timestamp': configuration.get('timestamp'),
            'total_teachers': len(teachers),
            'student_type': student.get('type') if student else 'new'
        }

        success = db.save_user_preferences(user_session, {
            'model_configuration': config_data,
            'last_updated': configuration.get('timestamp')
        })
        if not success:
            raise HTTPException(status_code=500, detail="Failed to save configuration")

        # Message (Arabic): "Saved configuration of {n} teacher models successfully".
        return {
            "success": True,
            "message": f"تم حفظ تكوين {len(teachers)} نماذج معلمة بنجاح",
            "configuration": config_data
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error saving model configuration: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/model-configuration/{user_session}")
async def get_model_configuration(user_session: str):
    """Return the saved teacher/student configuration for a user session.

    Reads the ``model_configuration`` and ``last_updated`` preference keys.
    Missing keys yield empty teachers and a ``None`` student.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB

        prefs = MedicalSelectionsDB().get_user_preferences(user_session)
        saved = prefs.get('model_configuration', {})
        teacher_list = saved.get('teachers', [])
        return {
            "success": True,
            "teachers": teacher_list,
            "student": saved.get('student'),
            "last_updated": prefs.get('last_updated'),
            "total_teachers": len(teacher_list)
        }
    except Exception as e:
        logger.error(f"Error getting model configuration: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/model-configuration/{user_session}")
async def clear_model_configuration(user_session: str):
    """Reset a user's saved model configuration to empty.

    Raises:
        HTTPException: 500 if the store reports failure or raises.
    """
    try:
        from database.medical_selections import MedicalSelectionsDB
        db = MedicalSelectionsDB()
        success = db.save_user_preferences(user_session, {
            'model_configuration': {},
            'last_updated': None
        })
        if not success:
            raise HTTPException(status_code=500, detail="Failed to clear configuration")
        # Message (Arabic): "Model configuration cleared successfully".
        return {
            "success": True,
            "message": "تم مسح تكوين النماذج بنجاح"
        }
    except HTTPException:
        # Preserve the intended detail instead of re-wrapping str(exc).
        raise
    except Exception as e:
        logger.error(f"Error clearing model configuration: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Token Management Endpoints
@app.get("/tokens")
async def token_management_page(request: Request):
    """Render the token management HTML page."""
    context = {"request": request}
    return templates.TemplateResponse("token-management.html", context)
@app.post("/api/tokens")
async def save_token(
    name: str = Form(...),
    token: str = Form(...),
    token_type: str = Form("read"),
    description: str = Form(""),
    is_default: bool = Form(False)
):
    """Save a Hugging Face token via ``token_manager``.

    Raises:
        HTTPException: 400 if the token manager reports failure, 500 if it
            raises.
    """
    try:
        success = token_manager.save_token(name, token, token_type, description, is_default)
        if not success:
            raise HTTPException(status_code=400, detail="Failed to save token")
        return {"success": True, "message": f"Token '{name}' saved successfully"}
    except HTTPException:
        # Without this, the 400 above was caught below and turned into a 500.
        raise
    except Exception as e:
        logger.error(f"Error saving token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/tokens")
async def list_tokens():
    """Return ``{"tokens": ...}`` from ``token_manager.list_tokens()``.

    Raises:
        HTTPException: 500 if the token manager raises.
    """
    try:
        return {"tokens": token_manager.list_tokens()}
    except Exception as e:
        logger.error(f"Error listing tokens: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/tokens/{token_name}")
async def delete_token(token_name: str):
    """Delete a saved token by name.

    Raises:
        HTTPException: 404 if no such token exists, 500 if the token manager
            raises.
    """
    try:
        if not token_manager.delete_token(token_name):
            raise HTTPException(status_code=404, detail="Token not found")
        return {"success": True, "message": f"Token '{token_name}' deleted"}
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error deleting token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tokens/{token_name}/set-default")
async def set_default_token(token_name: str):
    """Mark a saved token as the default.

    Raises:
        HTTPException: 404 if no such token exists, 500 if the token manager
            raises.
    """
    try:
        if not token_manager.set_default_token(token_name):
            raise HTTPException(status_code=404, detail="Token not found")
        return {"success": True, "message": f"Token '{token_name}' set as default"}
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error setting default token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/tokens/validate")
async def validate_token(token: str = Form(...)):
    """Return whatever ``token_manager.validate_token`` reports for ``token``.

    Raises:
        HTTPException: 500 if validation raises.
    """
    try:
        return token_manager.validate_token(token)
    except Exception as e:
        logger.error(f"Error validating token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/tokens/for-task/{task_type}")
async def get_token_for_task(task_type: str):
    """Describe the token that ``token_manager`` would use for ``task_type``.

    The resolved secret is matched against each saved token to find its
    metadata. If none matches, placeholder metadata is synthesized for a token
    that came from an environment variable. The secret itself is never
    returned.

    Raises:
        HTTPException: 404 if no token is available for the task, 500 on any
            other failure.
    """
    try:
        token = token_manager.get_token_for_task(task_type)
        if not token:
            raise HTTPException(status_code=404, detail=f"No suitable token found for task: {task_type}")

        # First saved token whose stored secret equals the resolved one.
        token_info = next(
            (entry for entry in token_manager.list_tokens()
             if token_manager.get_token(entry['name']) == token),
            None,
        )
        if not token_info:
            # Description (Arabic): "Token from environment variables for task: ...".
            token_info = {
                'name': f'{task_type}_token',
                'type': task_type,
                'description': f'رمز من متغيرات البيئة للمهمة: {task_type}',
                'last_used': None,
                'usage_count': 0
            }

        kind = token_info['type']
        type_info = token_manager.token_types.get(kind, {})
        summary = {
            "token_name": token_info['name'],
            "type": kind,
            "type_name": type_info.get('name', kind),
            "description": token_info['description'],
            "security_level": type_info.get('security_level', 'medium'),
            "recommended_for": type_info.get('recommended_for', 'general'),
            "last_used": token_info.get('last_used'),
            "usage_count": token_info.get('usage_count', 0)
        }
        return {"success": True, "task_type": task_type, "token_info": summary}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting token for task {task_type}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Medical Dataset Endpoints
@app.get("/medical-datasets")
async def medical_datasets_page(request: Request):
    """Render the medical datasets management HTML page."""
    context = {"request": request}
    return templates.TemplateResponse("medical-datasets.html", context)
# Google Models Endpoints
@app.get("/google-models")
async def google_models_page(request: Request):
    """Render the Google models selection HTML page."""
    context = {"request": request}
    return templates.TemplateResponse("google-models.html", context)
# NOTE(review): GET /api/medical-datasets is also registered earlier in this
# file by ``get_medical_datasets``. Starlette matches routes in registration
# order, so this handler appears to be unreachable. Confirm, then remove it or
# give it a distinct path.
@app.get("/api/medical-datasets")
async def list_medical_datasets():
    """Return ``medical_dataset_manager.list_supported_datasets()``.

    Raises:
        HTTPException: 500 if the manager raises.
    """
    try:
        return {"datasets": medical_dataset_manager.list_supported_datasets()}
    except Exception as e:
        logger.error(f"Error listing medical datasets: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/medical-datasets/load")
async def load_medical_dataset(
    dataset_name: str = Form(...),
    streaming: bool = Form(True),
    split: str = Form("train")
):
    """Load a medical dataset and return a short summary of it.

    The token configured for the ``'medical'`` task is preferred. If none is
    available, the default token from ``token_manager`` is used.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        hf_token = token_manager.get_token_for_task('medical')
        if not hf_token:
            logger.warning("No suitable token found for medical datasets, trying default")
            hf_token = token_manager.get_token()

        loaded = await medical_dataset_manager.load_dataset(
            dataset_name=dataset_name,
            streaming=streaming,
            split=split,
            token=hf_token
        )
        config = loaded['config']
        summary = {
            "name": config['name'],
            "size_gb": config['size_gb'],
            "num_samples": config['num_samples'],
            "streaming": loaded['streaming']
        }
        return {"success": True, "dataset_info": summary}
    except Exception as e:
        logger.error(f"Error loading medical dataset: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Memory and Performance Endpoints
@app.get("/api/system/memory")
async def get_memory_info():
    """Return ``memory_manager.get_memory_info()`` unchanged.

    Raises:
        HTTPException: 500 if the memory manager raises.
    """
    try:
        return memory_manager.get_memory_info()
    except Exception as e:
        logger.error(f"Error getting memory info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/system/performance")
async def get_performance_info():
    """Combine memory stats and recommendations with CPU optimizer state.

    Raises:
        HTTPException: 500 if any of the underlying calls raise.
    """
    try:
        snapshot = memory_manager.get_memory_info()
        advice = memory_manager.get_memory_recommendations()
        payload = {"memory": snapshot, "recommendations": advice}
        payload["cpu_cores"] = cpu_optimizer.cpu_count
        payload["optimizations_applied"] = cpu_optimizer.optimizations_applied
        return payload
    except Exception as e:
        logger.error(f"Error getting performance info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/system/cleanup")
async def force_memory_cleanup():
    """Trigger ``memory_manager.force_cleanup()``.

    Raises:
        HTTPException: 500 if the cleanup raises.
    """
    try:
        memory_manager.force_cleanup()
    except Exception as e:
        logger.error(f"Error during memory cleanup: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "message": "Memory cleanup completed"}
# Google Models Support
@app.get("/api/models/google")
async def list_google_models():
    """Return a static list of two Google models (MedSigLIP and Gemma).

    Raises:
        HTTPException: 500 on unexpected failure.
    """
    try:
        medsiglip = {
            "name": "google/medsiglip-448",
            "description": "Medical SigLIP model for medical image-text understanding",
            "type": "vision-language",
            "size_gb": 1.1,
            "modality": "multimodal",
            "medical_specialized": True
        }
        gemma = {
            "name": "google/gemma-3n-E4B-it",
            "description": "Gemma 3 model for instruction following",
            "type": "language",
            "size_gb": 8.5,
            "modality": "text",
            "medical_specialized": False
        }
        return {"models": [medsiglip, gemma]}
    except Exception as e:
        logger.error(f"Error listing Google models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT, defaulting to 7860.
    server_port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=server_port,
        reload=False,
        log_level="info",
    )