File size: 1,566 Bytes
69d1fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from itertools import groupby


def decode_phonemes(
    ids: torch.Tensor,
    processor: Wav2Vec2Processor,
    ignore_stress: bool = True,
    ignore_pause: bool = True,
) -> list[str]:
    """CTC-style decoding of predicted token ids into a list of phoneme strings.

    First collapses consecutive duplicate ids (the CTC collapse step), then
    drops special tokens and the word-delimiter token, and optionally strips
    IPA stress marks and pause tokens.

    Args:
        ids: 1-D sequence of predicted token ids (e.g. one row of an argmax
            over CTC logits). Any iterable of ids works.
        processor: Wav2Vec2 processor whose tokenizer maps ids to phoneme
            tokens via ``processor.decode``.
        ignore_stress: If True, remove primary ("ˈ") and secondary ("ˌ")
            IPA stress marks from each phoneme.
        ignore_pause: If True, drop pause/silence tokens ("h#", "pau").

    Returns:
        The decoded phonemes as a list of strings, one token per element.
        (Note: this returns a list, not a joined string.)
    """
    # CTC collapse: merge each run of identical consecutive ids into one id.
    ids = [id_ for id_, _ in groupby(ids)]

    special_token_ids = processor.tokenizer.all_special_ids + [
        processor.tokenizer.word_delimiter_token_id
    ]
    # Convert each remaining id to its phoneme token, skipping special tokens.
    phonemes = [processor.decode(id_)
                for id_ in ids if id_ not in special_token_ids]

    # Optionally strip IPA stress marks.
    if ignore_stress:
        phonemes = [p.replace("ˈ", "").replace("ˌ", "") for p in phonemes]

    # Optionally drop pause/silence tokens.
    if ignore_pause:
        phonemes = [p for p in phonemes if p not in ("h#", "pau")]

    return phonemes


def transcribe_tensor(model: Wav2Vec2ForCTC,
                      processor: Wav2Vec2Processor,
                      waveform: torch.Tensor,
                      sr: int,
                      device=torch.device("cpu"),
                      ignore_stress: bool = True,
                      ignore_pause: bool = True) -> list[str]:
    """Run a Wav2Vec2 CTC model on a waveform and decode it into phonemes.

    Args:
        model: A ``Wav2Vec2ForCTC`` model producing per-frame token logits.
        processor: The matching ``Wav2Vec2Processor`` used for both feature
            extraction and id-to-phoneme decoding.
        waveform: Audio tensor with a leading channel/batch dim of size 1
            (it is squeezed away before feature extraction).
            NOTE(review): assumes the tensor is on CPU and does not require
            grad — ``.numpy()`` would raise otherwise; confirm with callers.
        sr: Sampling rate of ``waveform`` in Hz.
        device: Device the model inputs are moved to (default CPU).
        ignore_stress: Forwarded to ``decode_phonemes``; defaults preserve
            the previous hard-coded behavior.
        ignore_pause: Forwarded to ``decode_phonemes``.

    Returns:
        The decoded phonemes as a list of strings (one token per element),
        as produced by ``decode_phonemes`` — not a joined string.
    """
    # Feature-extract on CPU, then move the batched input tensor to `device`.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        padding=True
    ).input_values.to(device)
    # inference_mode disables autograd tracking for the forward pass.
    with torch.inference_mode():
        logits = model(inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)
    # Decode the single (batch index 0) predicted id sequence.
    return decode_phonemes(pred_ids[0], processor,
                           ignore_stress=ignore_stress,
                           ignore_pause=ignore_pause)