Spaces:
Sleeping
Sleeping
File size: 1,566 Bytes
69d1fbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from itertools import groupby
def decode_phonemes(
    ids: torch.Tensor,
    processor: Wav2Vec2Processor,
    ignore_stress: bool = True,
    ignore_pause: bool = True,
) -> list[str]:
    """CTC-style decode of predicted token ids into phoneme strings.

    First collapses runs of consecutive duplicate ids (the CTC collapse
    step), then drops special tokens and the word-delimiter token, and
    finally converts each surviving id to its phoneme token.

    Args:
        ids: 1-D sequence of predicted token ids (e.g. the ``argmax`` over
            logits for one utterance). Any iterable of ids works.
        processor: processor whose tokenizer maps ids to phoneme tokens.
        ignore_stress: if True, strip IPA primary ("ˈ") and secondary ("ˌ")
            stress marks from each phoneme.
        ignore_pause: if True, drop pause/silence tokens ("h#", "pau").

    Returns:
        List of phoneme strings — note: a list, not a joined string.
    """
    # CTC collapse: keep one id per run of consecutive duplicates
    collapsed = [id_ for id_, _ in groupby(ids)]
    special_token_ids = processor.tokenizer.all_special_ids + [
        processor.tokenizer.word_delimiter_token_id
    ]
    # converts id to token, skipping special tokens
    phonemes = [
        processor.decode(id_) for id_ in collapsed if id_ not in special_token_ids
    ]
    if ignore_stress:
        phonemes = [p.replace("ˈ", "").replace("ˌ", "") for p in phonemes]
    if ignore_pause:
        phonemes = [p for p in phonemes if p not in ("h#", "pau")]
    return phonemes
def transcribe_tensor(model: Wav2Vec2ForCTC,
                      processor: Wav2Vec2Processor,
                      waveform: torch.Tensor,
                      sr: int,
                      device=torch.device("cpu")) -> list[str]:
    """Run CTC phoneme recognition on a single waveform tensor.

    Args:
        model: a ``Wav2Vec2ForCTC`` model (already on ``device``).
        processor: the matching ``Wav2Vec2Processor``.
        waveform: audio tensor; assumed shape (1, samples) since the
            leading dim is squeezed — TODO confirm with callers. Must be a
            CPU tensor (``.numpy()`` is called on it).
        sr: sample rate of ``waveform``; NOTE(review): presumably the
            model's expected rate (16 kHz for wav2vec2) — no resampling
            is done here.
        device: device the input values are moved to before the forward
            pass.

    Returns:
        List of decoded phoneme strings (see ``decode_phonemes``) —
        a list, not a joined string.
    """
    # Feature-extract and batch the raw audio, then move it to the model's device.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    ).input_values.to(device)
    # Forward pass without autograd bookkeeping.
    with torch.inference_mode():
        logits = model(inputs).logits
    # Greedy decode: most likely token id per frame, first (only) batch item.
    pred_ids = torch.argmax(logits, dim=-1)
    return decode_phonemes(pred_ids[0], processor, ignore_stress=True)
|