import os

import torch
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from g2p_en import G2p
from gramformer import Gramformer
from minineedle import needle, core

from audio import decode_audio_bytes, preprocess_audio
from grammar import analyse_grammar_gf
from model import AlignmentRequest, CorrectionRequest
from phoneme import transcribe_tensor
from utils import (
    arpabet_to_ipa_seq,
    levenshtein_similarity_score as similarity_score,
)
# Configure environment
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TARGET_SR = 16000
MODEL_ID = "vitouphy/wav2vec2-xls-r-300m-phoneme"
# Load model and processor
g2p = G2p()
gf = Gramformer(models=1, use_gpu=torch.cuda.is_available())
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

# App instance
app = FastAPI(title="Audio Phoneme Transcription API")
# On app shutdown: close the optional LanguageTool client if one was created.
language_tool = None  # not initialised in this file; guarded below to avoid a NameError


@app.on_event("shutdown")  # decorator absent in the original; added so the handler is registered
def shutdown_event():
    if language_tool:
        language_tool.close()
# region FastAPI app
# NOTE: the route decorators below were absent in the original; the paths are assumed.
@app.post("/transcribe")
async def transcribe_file(audio_file: UploadFile = File(...)):
    try:
        audio_bytes = await audio_file.read()
        if not audio_bytes:
            return JSONResponse(status_code=400, content={"error": "Empty audio data"})
        content_type = audio_file.content_type
        audio, sr = decode_audio_bytes(audio_bytes, content_type)
        waveform = preprocess_audio(audio, sr, TARGET_SR)
        phonemes = transcribe_tensor(
            model=model,
            processor=processor,
            waveform=waveform,
            sr=TARGET_SR,
            device=DEVICE
        )
        print(f"[OK] {waveform.shape[-1] / TARGET_SR:.2f}s audio")
        return {"transcription": phonemes}
    except Exception as e:
        print(f"[ERR] {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
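
# Example request (illustrative; uses the path assumed above and the multipart
# field name "audio_file" from the endpoint signature):
#   curl -X POST http://localhost:8000/transcribe -F "audio_file=@sample.wav"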

@app.post("/align")  # path assumed
async def align_phoneme(request: AlignmentRequest):
    # Extract phonemes and words
    phonemes_pred = request.phonemes
    words = request.words
    if not phonemes_pred or not words:
        return JSONResponse(status_code=400, content={"error": "Empty phonemes or words"})
    # Convert predicted phonemes from ARPAbet to IPA
    phonemes_pred = arpabet_to_ipa_seq(phonemes_pred)
    # Convert words to ground-truth phonemes
    phonemes_gt = []
    word_boundaries = []
    for word in words:
        phs = [p for p in g2p(word) if p != ' ']
        phs = arpabet_to_ipa_seq(phs)
        word_boundaries.append((word, len(phs)))
        phonemes_gt.extend(phs)
    # Perform alignment
    alignment = needle.NeedlemanWunsch(phonemes_gt, phonemes_pred)
    alignment.align()
    al_gt, al_pred = alignment.get_aligned_sequences()
    al_gt = ["-" if isinstance(a, core.Gap) else a for a in al_gt]
    al_pred = ["-" if isinstance(a, core.Gap) else a for a in al_pred]
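
    # Illustrative example (not from the original source): for the target word
    # "cat" (g2p -> K AE1 T -> IPA k æ t) and a predicted sequence k t, the
    # aligned lists come out as
    #     al_gt   = ["k", "æ", "t"]
    #     al_pred = ["k", "-", "t"]
    # i.e. a "-" on the prediction side marks a dropped phoneme, and a "-" on
    # the ground-truth side marks an inserted one.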
    # Map the aligned phonemes back to words and score each word
    word_to_pred = []
    current_idx = 0
    for word, word_len in word_boundaries:
        _phonemes_gt_len = 0
        _phonemes_gt = []
        _phonemes_pred = []
        # Consume aligned positions until this word's ground-truth phonemes are used up;
        # gap positions ("-") do not count towards the word length but are kept in the output.
        while _phonemes_gt_len != word_len and current_idx < len(al_gt):
            if al_gt[current_idx] != "-":
                _phonemes_gt_len += 1
            _phonemes_gt.append(al_gt[current_idx])
            _phonemes_pred.append(al_pred[current_idx])
            current_idx += 1
        score = similarity_score(_phonemes_gt, _phonemes_pred)
        word_to_pred.append({
            "word": word,
            "correct_ipa": _phonemes_gt,
            "user_ipa": _phonemes_pred,
            "score": score
        })
    return {"alignment": word_to_pred}

@app.post("/correct")  # path assumed
async def correct_grammar_gf(request: CorrectionRequest):
    # Extract the transcript to be checked
    transcript = request.transcript
    if not transcript:
        return JSONResponse(status_code=400, content={"error": "Empty transcript"})
    corrections = analyse_grammar_gf(transcript, gf)
    return {"corrections": corrections}

@app.get("/health")  # path assumed
def health():
    return {"status": "ok", "model": MODEL_ID, "device": str(DEVICE)}