File size: 4,485 Bytes
f31ad7c
17ffacb
 
69d1fbe
 
 
 
17ffacb
 
 
 
38cee9b
17ffacb
 
 
 
 
 
 
69d1fbe
64d8bb5
 
c80fc63
1f5c9b2
c80fc63
64d8bb5
 
1f5c9b2
7f675eb
64d8bb5
59a8ebd
38cee9b
69d1fbe
1f5c9b2
 
64d8bb5
 
26b54c7
65e53c1
17ffacb
 
 
 
 
 
 
64d8bb5
1fbfc02
 
 
 
 
 
 
 
 
3a15b02
69d1fbe
 
 
 
 
 
 
1fbfc02
64d8bb5
 
 
1fbfc02
 
 
 
1f5c9b2
64d8bb5
3a15b02
59a8ebd
 
 
 
bb2c162
 
 
 
 
 
59a8ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb33b00
59a8ebd
 
286596c
59a8ebd
 
286596c
59a8ebd
286596c
 
59a8ebd
 
fb33b00
286596c
fb33b00
 
2364325
 
286596c
fb33b00
59a8ebd
 
1f5c9b2
4a55f3c
38cee9b
4d6a27e
38cee9b
 
 
 
 
 
 
 
 
4a55f3c
c80fc63
64d8bb5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from grammar import analyse_grammar_gf
from phoneme import transcribe_tensor
from audio import decode_audio_bytes, preprocess_audio
from utils import (
    arpabet_to_ipa_seq,
    levenshtein_similarity_score as similarity_score
)
from model import (
    AlignmentRequest,
    CorrectionRequest
)
from gramformer import Gramformer
from minineedle import needle, core
from g2p_en import G2p
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
from fastapi.responses import JSONResponse
from fastapi import FastAPI, UploadFile, File
import os


# Configure environment
# Limit intra-op threading and cap CUDA allocator split size to keep the
# service's memory/CPU footprint small; setdefault lets deployment env
# variables override these values.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

# Prefer GPU when available; all inference tensors are moved to this device.
DEVICE = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")
# Sample rate (Hz) the wav2vec2 phoneme model expects; uploads are resampled to it.
TARGET_SR = 16000
# Hugging Face model id for phoneme-level CTC transcription.
MODEL_ID = "vitouphy/wav2vec2-xls-r-300m-phoneme"
# Load model and processor
g2p = G2p()  # grapheme-to-phoneme converter (ARPAbet output)
gf = Gramformer(models=1, use_gpu=torch.cuda.is_available())  # grammar corrector

processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
# eval() disables dropout etc.; the model is used for inference only.
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()
# App instance
app = FastAPI(title="Audio Phoneme Transcription API")


# on app shutdown
@app.on_event("shutdown")
def shutdown_event():
    """Release external resources when the app shuts down.

    BUG fix: ``language_tool`` is never defined in this module, so the
    original unconditional reference raised ``NameError`` on every shutdown.
    Look it up defensively so the hook is a no-op when the tool was never
    created, but still closes it if some deployment defines it globally.
    """
    tool = globals().get("language_tool")
    if tool:
        tool.close()


# region FastAPI app
@app.post("/transcribe_file/")
async def transcribe_file(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file into a phoneme sequence.

    Reads the upload, decodes/resamples it to TARGET_SR, and runs the
    wav2vec2 CTC model. Returns ``{"transcription": [...]}`` on success,
    a 400 JSON error for an empty upload, or a 500 JSON error on failure.
    """
    try:
        audio_bytes = await audio_file.read()
        # BUG fix: the original tested `audio_file` (an UploadFile object,
        # always truthy), so the empty-data branch was unreachable. Test the
        # bytes actually read instead.
        if not audio_bytes:
            return JSONResponse(status_code=400, content={"error": "Empty audio data"})

        # Content type guides the decoder (e.g. wav vs. webm payloads).
        content_type = audio_file.content_type
        audio, sr = decode_audio_bytes(audio_bytes, content_type)
        waveform = preprocess_audio(audio, sr, TARGET_SR)
        phonemes = transcribe_tensor(
            model=model,
            processor=processor,
            waveform=waveform,
            sr=TARGET_SR,
            device=DEVICE
        )

        print(
            f"[OK] {waveform.shape[-1]/TARGET_SR:.2f}s audio")
        return {"transcription": phonemes}
    except Exception as e:
        # Top-level boundary: surface any decode/inference failure as a 500
        # with the message, rather than letting FastAPI return a bare error.
        print(f"[ERR] {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/align_phoneme/")
async def align_phoneme(request: AlignmentRequest):
    """Align predicted phonemes against ground-truth phonemes per word.

    Converts the predicted ARPAbet phonemes and the reference words to IPA,
    aligns the two sequences with Needleman-Wunsch, then slices the aligned
    sequences back into per-word segments and scores each word's similarity.
    Returns ``{"alignment": [{word, correct_ipa, user_ipa, score}, ...]}``.
    """
    # Extract phonemes and words
    phonemes_pred = request.phonemes
    words = request.words

    if not phonemes_pred or not words:
        return JSONResponse(status_code=400, content={"error": "Empty phonemes or words"})

    # Convert predicted phonemes from ARPAbet to IPA
    phonemes_pred = arpabet_to_ipa_seq(phonemes_pred)

    # Convert words to ground-truth phonemes via g2p, recording each word's
    # phoneme count so the aligned sequence can be split back per word.
    phonemes_gt = []
    word_boundaries = []
    for word in words:
        # g2p emits spaces between words; drop them.
        phs = [p for p in g2p(word) if p != ' ']
        phs = arpabet_to_ipa_seq(phs)
        word_boundaries.append(
            (
                word,
                len(phs)
            )
        )
        phonemes_gt.extend(phs)

    # Perform alignment (global alignment of GT vs. predicted IPA sequences)
    alignment = needle.NeedlemanWunsch(phonemes_gt, phonemes_pred)
    alignment.align()
    al_gt, al_pred = alignment.get_aligned_sequences()
    # Render gap objects as "-" so downstream code/JSON sees plain strings.
    al_gt = ["-" if isinstance(a, core.Gap) else a for a in al_gt]
    al_pred = ["-" if isinstance(a, core.Gap) else a for a in al_pred]

    # Map back phonemes to words: consume aligned positions until each word's
    # count of non-gap GT phonemes is reached. Gap positions ("-") in the GT
    # stream are attributed to the word currently being consumed.
    word_to_pred = []
    current_idx = 0
    for word, word_len in word_boundaries:
        _phonemes_gt_len = 0  # non-gap GT phonemes consumed for this word
        _phonemes_gt = []
        _phonemes_pred = []
        while _phonemes_gt_len != word_len and current_idx < len(al_gt):
            if al_gt[current_idx] != "-":
                _phonemes_gt_len += 1
            _phonemes_gt.append(al_gt[current_idx])
            _phonemes_pred.append(al_pred[current_idx])
            current_idx += 1

        # Levenshtein-based similarity between the aligned GT and predicted
        # segments for this word (see `similarity_score` import alias).
        score = similarity_score(_phonemes_gt, _phonemes_pred)
        word_to_pred.append({
            "word": word,
            "correct_ipa": _phonemes_gt,
            "user_ipa": _phonemes_pred,
            "score": score
        })

    return {"alignment": word_to_pred}


@app.post("/correct_grammar_gf/")
async def correct_grammar_gf(request: CorrectionRequest):
    """Run Gramformer grammar analysis over the request transcript.

    Returns ``{"corrections": ...}`` on success, or a 400 JSON error when
    the transcript is empty.
    """
    text = request.transcript

    # Guard clause: nothing to analyse.
    if not text:
        return JSONResponse(status_code=400, content={"error": "Empty transcript"})

    return {"corrections": analyse_grammar_gf(text, gf)}


@app.get("/")
def health():
    """Liveness probe: report service status, model id, and inference device."""
    payload = {
        "status": "ok",
        "model": MODEL_ID,
        "device": str(DEVICE),
    }
    return payload