# thanhhungtakeshi's picture
# fix
# f31ad7c
from grammar import analyse_grammar_gf
from phoneme import transcribe_tensor
from audio import decode_audio_bytes, preprocess_audio
from utils import (
arpabet_to_ipa_seq,
levenshtein_similarity_score as similarity_score
)
from model import (
AlignmentRequest,
CorrectionRequest
)
from gramformer import Gramformer
from minineedle import needle, core
from g2p_en import G2p
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
from fastapi.responses import JSONResponse
from fastapi import FastAPI, UploadFile, File
import os
# Runtime environment defaults; values already present in the environment win.
# NOTE(review): these are applied after `import torch` has already run above;
# confirm OMP_NUM_THREADS still takes effect at this point.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TARGET_SR = 16000  # sample rate handed to preprocessing and transcription
MODEL_ID = "vitouphy/wav2vec2-xls-r-300m-phoneme"

# Shared, load-once resources used by the request handlers.
g2p = G2p()
gf = Gramformer(models=1, use_gpu=torch.cuda.is_available())
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
model = model.eval()  # inference mode

# The FastAPI application object that every route below is registered on.
app = FastAPI(title="Audio Phoneme Transcription API")
# on app shutdown
@app.on_event("shutdown")
def shutdown_event():
    """Close the optional ``language_tool`` client when the app shuts down.

    The visible part of this module never defines or imports a
    ``language_tool`` name, so referencing it directly raised ``NameError``
    during shutdown. The name is therefore looked up defensively in the
    module globals: if a truthy ``language_tool`` object exists it is
    closed, otherwise this hook does nothing.
    """
    tool = globals().get("language_tool")
    if tool:
        tool.close()
# region FastAPI app
@app.post("/transcribe_file/")
async def transcribe_file(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file into phonemes.

    The upload is read fully, decoded with ``decode_audio_bytes`` using the
    upload's declared content type, run through ``preprocess_audio`` with
    ``TARGET_SR``, and handed to ``transcribe_tensor`` together with the
    module-level model and processor.

    Args:
        audio_file: The uploaded audio file (multipart form field).

    Returns:
        ``{"transcription": phonemes}`` on success. A ``JSONResponse`` with
        status 400 when the upload contains no bytes, or status 500 carrying
        the exception message when any processing step raises.
    """
    try:
        audio_bytes = await audio_file.read()
        # Check the bytes that were read, not the UploadFile object itself:
        # the old `not audio_file` test never rejected an empty upload.
        if not audio_bytes:
            return JSONResponse(status_code=400, content={"error": "Empty audio data"})

        content_type = audio_file.content_type
        audio, sr = decode_audio_bytes(audio_bytes, content_type)
        waveform = preprocess_audio(audio, sr, TARGET_SR)
        phonemes = transcribe_tensor(
            model=model,
            processor=processor,
            waveform=waveform,
            sr=TARGET_SR,
            device=DEVICE,
        )
        # Log the duration (samples / sample rate) of the processed audio.
        print(
            f"[OK] {waveform.shape[-1]/TARGET_SR:.2f}s audio")
        return {"transcription": phonemes}
    except Exception as e:
        # Top-level boundary: report any failure to the client as a 500.
        print(f"[ERR] {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/align_phoneme/")
async def align_phoneme(request: AlignmentRequest):
    """Align predicted phonemes with the expected phonemes of each word.

    The predicted phonemes from the request are converted from ARPAbet to
    IPA. Each word is converted to its expected phonemes via ``g2p`` (also
    mapped to IPA). Both sequences are globally aligned with
    Needleman-Wunsch, and the aligned columns are then split back into
    per-word groups that are scored with ``similarity_score``.

    Args:
        request: Carries ``phonemes`` (the predicted sequence) and ``words``.

    Returns:
        ``{"alignment": [...]}`` with one entry per word holding ``word``,
        ``correct_ipa``, ``user_ipa`` and ``score``; or a 400
        ``JSONResponse`` when either the phonemes or the words are empty.
    """
    predicted = request.phonemes
    words = request.words
    if not (predicted and words):
        return JSONResponse(status_code=400, content={"error": "Empty phonemes or words"})

    # Predicted phonemes: ARPAbet -> IPA.
    predicted = arpabet_to_ipa_seq(predicted)

    # Ground-truth phonemes for every word, remembering how many phonemes
    # each word contributed so the alignment can be split per word later.
    expected = []
    spans = []
    for word in words:
        word_ipa = arpabet_to_ipa_seq([sym for sym in g2p(word) if sym != " "])
        spans.append((word, len(word_ipa)))
        expected.extend(word_ipa)

    # Global alignment; gap objects are rendered as "-" in both sequences.
    aligner = needle.NeedlemanWunsch(expected, predicted)
    aligner.align()
    aligned_expected, aligned_predicted = (
        ["-" if isinstance(item, core.Gap) else item for item in seq]
        for seq in aligner.get_aligned_sequences()
    )

    # Walk the aligned columns once, cutting a slice per word. A word's slice
    # ends right after its last ground-truth (non-gap) phoneme is consumed.
    total = len(aligned_expected)
    results = []
    cursor = 0
    for word, expected_len in spans:
        start = cursor
        consumed = 0
        while consumed != expected_len and cursor < total:
            if aligned_expected[cursor] != "-":
                consumed += 1
            cursor += 1
        word_expected = aligned_expected[start:cursor]
        word_predicted = aligned_predicted[start:cursor]
        results.append({
            "word": word,
            "correct_ipa": word_expected,
            "user_ipa": word_predicted,
            "score": similarity_score(word_expected, word_predicted),
        })

    return {"alignment": results}
@app.post("/correct_grammar_gf/")
async def correct_grammar_gf(request: CorrectionRequest):
    """Analyse the grammar of a transcript with the shared Gramformer instance.

    Args:
        request: Carries the ``transcript`` text to analyse.

    Returns:
        ``{"corrections": ...}`` with the result of ``analyse_grammar_gf``,
        or a 400 ``JSONResponse`` when the transcript is empty.
    """
    text = request.transcript
    if text:
        return {"corrections": analyse_grammar_gf(text, gf)}
    return JSONResponse(status_code=400, content={"error": "Empty transcript"})
@app.get("/")
def health():
    """Report that the service is up, plus the model id and device in use."""
    payload = dict(status="ok", model=MODEL_ID, device=str(DEVICE))
    return payload