phoneme_transciptor / phoneme.py
thanhhungtakeshi's picture
add grammar correction endpoint
69d1fbe
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from itertools import groupby
def decode_phonemes(
    ids: torch.Tensor,
    processor: Wav2Vec2Processor,
    ignore_stress: bool = True,
    ignore_pause: bool = True,
) -> list[str]:
    """Greedy CTC-style decoding of predicted token ids into phoneme strings.

    Consecutive duplicate ids are collapsed first, then special tokens
    (including the word delimiter) are dropped, and every remaining id is
    decoded on its own with ``processor.decode``.

    Args:
        ids: Predicted token ids for a single utterance. Assumed to be
            one-dimensional (one id per frame); a tensor or any iterable of
            integer-like values is accepted.
        processor: Processor whose ``tokenizer`` supplies ``all_special_ids``
            and ``word_delimiter_token_id`` and whose ``decode`` maps one id
            to its string.
        ignore_stress: When truthy, strip the IPA primary ("ˈ") and
            secondary ("ˌ") stress marks from every decoded phoneme.
        ignore_pause: When truthy, drop the pause tokens "h#" and "pau".

    Returns:
        A list with one decoded string per surviving token. The phonemes are
        NOT joined into a single string.
    """
    # Convert to plain Python ints once. Iterating a tensor yields 0-dim
    # tensors, which makes every comparison below a tensor op and prevents
    # hashing by value (needed for the set lookup).
    if isinstance(ids, torch.Tensor):
        id_list = ids.tolist()
    else:
        id_list = [int(id_) for id_ in ids]

    # CTC collapse: keep one id per run of consecutive duplicates.
    collapsed = [id_ for id_, _ in groupby(id_list)]

    tokenizer = processor.tokenizer
    special_token_ids = set(tokenizer.all_special_ids)
    special_token_ids.add(tokenizer.word_delimiter_token_id)

    # Decode id by id so that each phoneme stays a separate list entry.
    phonemes = [
        processor.decode(id_)
        for id_ in collapsed
        if id_ not in special_token_ids
    ]

    if ignore_stress:
        phonemes = [p.replace("ˈ", "").replace("ˌ", "") for p in phonemes]
    if ignore_pause:
        phonemes = [p for p in phonemes if p not in ("h#", "pau")]
    return phonemes
def transcribe_tensor(model: Wav2Vec2ForCTC,
                      processor: Wav2Vec2Processor,
                      waveform: torch.Tensor,
                      sr: int,
                      device=torch.device("cpu"),
                      ignore_stress: bool = True,
                      ignore_pause: bool = True) -> list[str]:
    """Transcribe a waveform tensor into a list of phoneme strings.

    The waveform is run through ``processor`` to build model inputs, the
    model's logits are arg-maxed over the last dimension, and the ids of the
    first batch item are decoded with ``decode_phonemes``.

    Args:
        model: CTC model called on the processed inputs. It is NOT moved to
            ``device`` here; the caller must already have placed it there.
        processor: Processor used both for feature extraction and decoding.
        waveform: Audio samples. A leading dimension of size 1 is squeezed
            away; presumably shaped ``(1, num_samples)`` or
            ``(num_samples,)`` -- TODO confirm against callers.
        sr: Sampling rate of ``waveform``, forwarded to the processor.
        device: Device the processed inputs are moved to before inference.
        ignore_stress: Forwarded to ``decode_phonemes``.
        ignore_pause: Forwarded to ``decode_phonemes``.

    Returns:
        The decoded phonemes of the first batch item, as returned by
        ``decode_phonemes``.
    """
    # detach() + cpu() make the NumPy conversion safe for tensors that live
    # on an accelerator or that track gradients; both are no-ops otherwise.
    audio = waveform.detach().cpu().squeeze(0).numpy()

    features = processor(
        audio,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    )
    inputs = features.input_values.to(device)

    with torch.inference_mode():
        logits = model(inputs).logits

    pred_ids = torch.argmax(logits, dim=-1)

    # Only the first batch item is decoded. Moving the ids to the CPU first
    # avoids per-element device reads while decoding.
    return decode_phonemes(
        pred_ids[0].cpu(),
        processor,
        ignore_stress=ignore_stress,
        ignore_pause=ignore_pause,
    )