import os
import tempfile
import logging
import json
import numpy as np
import torch
import torch.nn as nn
import whisper
import librosa
import asyncio
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from speechbrain.inference.classifiers import EncoderClassifier
import torchaudio

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

class ModelManager:
    """Centralized model management for all ML models."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self.emotion_model = None
        self.whisper_model = None
        self.text_tokenizer = None
        self.text_model = None
        self.speechbrain_model = None

        # Model paths
        self.MODEL_PATHS = {
            'whisper_model': 'base',
            'text_model': 'emotion-distilbert-model',
            'speechbrain_model': 'speechbrain/emotion-recognition-wav2vec2-IEMOCAP'
        }

        # Constants
        self.EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
        self.SAMPLE_RATE = 16000
        self.TEXT_EMOTIONS = ["sadness", "joy", "love", "anger", "fear", "surprise"]

        # SpeechBrain emotion mapping
        self.SPEECHBRAIN_EMOTION_MAP = {
            'neu': 'Neutral',
            'hap': 'Happy',
            'sad': 'Sad',
            'ang': 'Angry',
            'fea': 'Fear',
            'dis': 'Disgust',
            'sur': 'Surprise'
        }
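        # Note: to our knowledge the IEMOCAP checkpoint referenced above is a
        # four-class model (neu/hap/sad/ang); the fea/dis/sur keys are kept
        # defensively in case a broader checkpoint is substituted.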

    def load_all_models(self):
        """Load all required models."""
        try:
            logger.info("Starting to load all models...")
            self._load_emotion_model()
            self._load_whisper_model()
            self._load_text_models()
            self._load_speechbrain_model()
            logger.info("All models loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    def _load_emotion_model(self):
        """Use DeepFace for emotion recognition."""
        try:
            logger.info("Loading DeepFace for emotion recognition...")
            from deepface import DeepFace
            self.emotion_model = DeepFace
            logger.info("DeepFace loaded successfully")
        except Exception as e:
            logger.error(f"Failed to initialize DeepFace: {str(e)}")
            raise

    def _load_whisper_model(self):
        """Load the Whisper speech-to-text model."""
        try:
            logger.info("Loading Whisper model...")
            self.whisper_model = whisper.load_model(self.MODEL_PATHS['whisper_model'])
            logger.info("Whisper model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise

    def _load_text_models(self):
        """Load the text emotion classification model and tokenizer."""
        try:
            logger.info("Loading text emotion model...")
            model_path = self.MODEL_PATHS['text_model']
            # Try to load from local path first, then from HuggingFace Hub
            if os.path.exists(model_path):
                self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
                self.text_model = DistilBertForSequenceClassification.from_pretrained(model_path)
            else:
                # Use a public emotion model from HuggingFace
                logger.info("Local model not found, using HuggingFace model...")
                self.text_tokenizer = DistilBertTokenizerFast.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
                self.text_model = DistilBertForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
            self.text_model.eval()
            logger.info("Text models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load text models: {str(e)}")
            raise

    def _load_speechbrain_model(self):
        """Load SpeechBrain emotion recognition model."""
        try:
            logger.info("Loading SpeechBrain emotion recognition model...")
            self.speechbrain_model = EncoderClassifier.from_hparams(
                source=self.MODEL_PATHS['speechbrain_model'],
                savedir="pretrained_models/emotion-recognition-wav2vec2-IEMOCAP",
                run_opts={"device": "cpu"}
            )
            logger.info("SpeechBrain emotion recognition model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load SpeechBrain model: {str(e)}")
            raise

    def get_emotion_model(self):
        if self.emotion_model is None:
            self._load_emotion_model()
        return self.emotion_model

    def get_whisper_model(self):
        if self.whisper_model is None:
            self._load_whisper_model()
        return self.whisper_model

    def get_text_models(self):
        if self.text_model is None or self.text_tokenizer is None:
            self._load_text_models()
        return self.text_tokenizer, self.text_model

    def get_speechbrain_model(self):
        if self.speechbrain_model is None:
            self._load_speechbrain_model()
        return self.speechbrain_model
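
# ModelManager behaves as a process-wide singleton: __new__ always returns the
# same instance, and the get_* accessors lazily load any model that was not
# loaded (or failed to load) at startup.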

# Initialize FastAPI app
app = FastAPI(title="Manan ML API - Emotion Recognition")

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)

# Initialize model manager
model_manager = ModelManager()

# Image transformation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
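# Note: these are the standard ImageNet normalization statistics. No endpoint
# below references this transform (DeepFace does its own preprocessing), so as
# written it is effectively unused.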

@app.on_event("startup")
async def startup_event():
    """Initialize all models when the application starts."""
    try:
        logger.info("Starting model initialization...")
        model_manager.load_all_models()
        logger.info("All models initialized successfully!")
    except Exception as e:
        logger.error(f"Failed to initialize models: {str(e)}")
        # Don't raise - let the app start and load models on demand

@app.get("/")
async def root():
    """Root endpoint listing the available routes."""
    return {
        "status": "running",
        "message": "Manan ML API is running!",
        "endpoints": [
            "/pred_face - Face emotion prediction",
            "/predict_audio_batch - Voice emotion prediction",
            "/predict_text/ - Text emotion prediction"
        ]
    }

@app.get("/health")  # route path assumed; it is not listed in the root index
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "device": str(model_manager.device)}

# Helper function for SpeechBrain prediction
def predict_emotion_speechbrain(audio_path: str) -> Dict[str, Any]:
    """Predict emotion from audio using SpeechBrain."""
    try:
        speechbrain_model = model_manager.get_speechbrain_model()

        signal, sr = torchaudio.load(audio_path)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            signal = resampler(signal)

        # Normalize shape to (batch, samples) and downmix multi-channel audio
        # to mono so a stereo recording is not misread as a batch of two signals.
        if signal.dim() == 1:
            signal = signal.unsqueeze(0)
        elif signal.dim() == 3:
            signal = signal.squeeze(1)
        if signal.size(0) > 1:
            signal = signal.mean(dim=0, keepdim=True)

        device = next(speechbrain_model.mods.wav2vec2.parameters()).device
        signal = signal.to(device)

        # Run the wav2vec2 encoder, pool over time, and classify manually so
        # the full probability distribution can be returned.
        with torch.no_grad():
            feats = speechbrain_model.mods.wav2vec2(signal)
            pooled = speechbrain_model.mods.avg_pool(feats)
            out = speechbrain_model.mods.output_mlp(pooled)
            out_prob = speechbrain_model.hparams.softmax(out)
        score, index = torch.max(out_prob, dim=-1)
        predicted_emotion = speechbrain_model.hparams.label_encoder.decode_ndim(index.cpu())

        # decode_ndim may return a nested list; unwrap it to a plain label and
        # key on the first three characters (e.g. 'neu', 'hap').
        if isinstance(predicted_emotion, list):
            if isinstance(predicted_emotion[0], list):
                emotion_key = str(predicted_emotion[0][0]).lower()[:3]
            else:
                emotion_key = str(predicted_emotion[0]).lower()[:3]
        else:
            emotion_key = str(predicted_emotion).lower()[:3]
        emotion = model_manager.SPEECHBRAIN_EMOTION_MAP.get(emotion_key, 'Neutral')

        probs = out_prob[0].detach().cpu().numpy()
        if probs.ndim > 1:
            probs = probs.flatten()
        all_emotions = speechbrain_model.hparams.label_encoder.decode_ndim(
            torch.arange(len(probs))
        )
        prob_dict = {}
        for i in range(len(probs)):
            if i < len(all_emotions):
                if isinstance(all_emotions[i], list):
                    key = str(all_emotions[i][0]).lower()[:3]
                else:
                    key = str(all_emotions[i]).lower()[:3]
                emotion_name = model_manager.SPEECHBRAIN_EMOTION_MAP.get(key, f'emotion_{i}')
                prob_dict[emotion_name] = float(probs[i])

        confidence = float(score[0])
        return {
            'emotion': emotion,
            'confidence': confidence,
            'probabilities': prob_dict
        }
    except Exception as e:
        logger.error(f"Error predicting emotion with SpeechBrain: {str(e)}")
        raise

def transcribe_audio(audio_path: str) -> str:
    """Transcribe audio to text using Whisper."""
    try:
        # Use the accessor so the model is lazily loaded if startup failed.
        result = model_manager.get_whisper_model().transcribe(audio_path)
        return result["text"].strip()
    except Exception as e:
        logger.error(f"Error in audio transcription: {str(e)}")
        return ""

# ============== API ENDPOINTS ==============

@app.post("/pred_face")
async def predict_face_emotion(
    files: List[UploadFile] = File(...),
    questions: str = Form(None)
):
    """Predict emotions from face images using DeepFace."""
    from deepface import DeepFace

    logger.info(f"Received {len(files)} files for face prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    temp_files = []
    try:
        questions_data = {}
        question_count = 0
        if questions:
            try:
                questions_data = json.loads(questions)
                question_count = len(questions_data)
            except json.JSONDecodeError:
                raise HTTPException(status_code=400, detail="Invalid questions JSON format.")
        else:
            question_count = 3
            questions_data = {str(i): {"text": f"Question {i+1}", "imageCount": 1} for i in range(question_count)}

        # Group uploads by question index, expecting filenames like "q0_<name>.jpg"
        question_files = {str(i): [] for i in range(question_count)}
        for file in files:
            if '_' in file.filename and file.filename.startswith('q'):
                try:
                    q_idx = file.filename.split('_')[0][1:]
                    if q_idx in question_files:
                        question_files[q_idx].append(file)
                except Exception as e:
                    logger.warning(f"Skipping file {file.filename}: {e}")

        results = []
        for q_idx, q_files in question_files.items():
            if not q_files:
                results.append({
                    "emotion": "Unknown",
                    "probabilities": {e: 0.0 for e in model_manager.EMOTIONS}
                })
                continue

            probs_list = []
            for file in q_files:
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
                        content = await file.read()
                        tmp.write(content)
                        temp_path = tmp.name
                        temp_files.append(temp_path)

                    analysis = DeepFace.analyze(
                        img_path=temp_path,
                        actions=['emotion'],
                        enforce_detection=False,
                        silent=True
                    )
                    if isinstance(analysis, list):
                        analysis = analysis[0]

                    # DeepFace reports per-emotion scores as percentages; rescale to [0, 1]
                    emotion_scores = analysis.get('emotion', {})
                    normalized_probs = {}
                    for emo in model_manager.EMOTIONS:
                        key = emo.lower()
                        normalized_probs[emo] = emotion_scores.get(key, 0.0) / 100.0
                    probs_list.append(normalized_probs)
                except Exception as e:
                    logger.error(f"Error processing {file.filename}: {e}")

            if probs_list:
                # Average probabilities across all frames for this question
                avg_probs = {}
                for emo in model_manager.EMOTIONS:
                    avg_probs[emo] = sum(p.get(emo, 0) for p in probs_list) / len(probs_list)
                dominant_emotion = max(avg_probs, key=avg_probs.get)
                results.append({
                    "emotion": dominant_emotion,
                    "probabilities": avg_probs
                })
            else:
                results.append({
                    "emotion": "Unknown",
                    "probabilities": {e: 0.0 for e in model_manager.EMOTIONS}
                })
        return results
    except HTTPException:
        # Re-raise client errors (e.g. the 400 for malformed JSON) unchanged
        # rather than converting them to 500s below.
        raise
    except Exception as e:
        logger.error(f"Error in face emotion prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

@app.post("/predict_audio_batch")
async def predict_audio_batch(files: List[UploadFile] = File(...)):
    """Predict emotions from multiple audio files using SpeechBrain."""
    logger.info(f"Received {len(files)} audio files for prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No audio files provided")

    temp_files = []
    results = []
    try:
        for file in files:
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    content = await file.read()
                    tmp.write(content)
                    temp_path = tmp.name
                    temp_files.append(temp_path)

                prediction = predict_emotion_speechbrain(temp_path)
                results.append(prediction)
                logger.info(f"Predicted emotion for {file.filename}: {prediction['emotion']}")
            except Exception as e:
                logger.error(f"Error processing {file.filename}: {e}")
                results.append({
                    'emotion': 'Unknown',
                    'confidence': 0.0,
                    'probabilities': {},
                    'error': str(e)
                })
        return {'status': 'success', 'results': results}
    except Exception as e:
        logger.error(f"Error in audio batch prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

@app.post("/predict_text/")
async def predict_text_emotion(files: List[UploadFile] = File(...)):
    """Transcribe audio with Whisper and predict emotion from the transcript."""
    logger.info(f"Received {len(files)} audio files for text prediction")
    if not files:
        raise HTTPException(status_code=400, detail="No audio files provided")

    temp_files = []
    results = []
    try:
        tokenizer, text_model = model_manager.get_text_models()
        whisper_model = model_manager.get_whisper_model()

        for file in files:
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    content = await file.read()
                    tmp.write(content)
                    temp_path = tmp.name
                    temp_files.append(temp_path)

                # Transcribe
                transcription = whisper_model.transcribe(temp_path)
                transcript = transcription["text"].strip()
                logger.info(f"Transcribed: {transcript}")

                if not transcript:
                    results.append({
                        'transcript': '',
                        'emotion': 'neutral',
                        'confidence': 0.0,
                        'probabilities': {}
                    })
                    continue

                # Predict emotion from text
                inputs = tokenizer(
                    transcript,
                    return_tensors="pt",
                    truncation=True,
                    max_length=128,
                    padding=True
                )
                with torch.no_grad():
                    outputs = text_model(**inputs)
                    probs = torch.softmax(outputs.logits, dim=1)[0]

                # Get emotion labels, preferring the model's own id2label mapping
                emotion_labels = model_manager.TEXT_EMOTIONS
                if hasattr(text_model.config, 'id2label'):
                    emotion_labels = [text_model.config.id2label[i] for i in range(len(probs))]

                prob_dict = {emotion_labels[i]: float(probs[i]) for i in range(len(probs))}
                predicted_idx = torch.argmax(probs).item()
                predicted_emotion = emotion_labels[predicted_idx]
                confidence = float(probs[predicted_idx])

                results.append({
                    'transcript': transcript,
                    'emotion': predicted_emotion,
                    'confidence': confidence,
                    'probabilities': prob_dict
                })
            except Exception as e:
                logger.error(f"Error processing {file.filename}: {e}")
                results.append({
                    'transcript': '',
                    'emotion': 'unknown',
                    'confidence': 0.0,
                    'error': str(e)
                })
        return results
    except Exception as e:
        logger.error(f"Error in text prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        for file_path in temp_files:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {file_path}: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
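
# Example usage once the server is running (a sketch; the example filenames
# are assumptions, and uploads for /pred_face must follow the
# "q<index>_<name>.jpg" naming convention the handler expects):
#
#   curl -X POST http://localhost:7860/pred_face \
#        -F "files=@q0_frame1.jpg" -F "files=@q1_frame1.jpg"
#   curl -X POST http://localhost:7860/predict_audio_batch -F "files=@answer.wav"
#   curl -X POST http://localhost:7860/predict_text/ -F "files=@answer.wav"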