Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| from itertools import groupby | |
def decode_phonemes(
    ids: torch.Tensor,
    processor: "Wav2Vec2Processor",
    ignore_stress: bool = True,
    ignore_pause: bool = True,
) -> "list[str]":
    """Greedy CTC-style decoding of predicted token ids into phoneme strings.

    Consecutive duplicate ids are collapsed first. Then the tokenizer's special
    ids and its word-delimiter id are dropped. Each remaining id is decoded on
    its own with ``processor.decode``.

    Args:
        ids: Sequence of predicted token ids for one utterance. Accepts a 1-D
            tensor, anything else exposing ``tolist()`` (e.g. a numpy array),
            or a plain iterable of ints.
        processor: Processor whose ``tokenizer`` supplies ``all_special_ids``
            and ``word_delimiter_token_id``. Its ``decode`` maps a single id to
            its token string.
        ignore_stress: If true, strip the IPA stress marks ``ˈ`` and ``ˌ``
            from every phoneme, and drop phonemes that become empty as a result.
        ignore_pause: If true, drop the pause tokens ``"h#"`` and ``"pau"``.

    Returns:
        The decoded phonemes, one string per kept token, in order.
    """
    # Convert to plain Python ints once, so that grouping, set lookups and
    # decode() operate on ints rather than on individual 0-d tensors.
    if hasattr(ids, "tolist"):
        ids = ids.tolist()

    # Collapse runs of repeated ids. This deliberately happens BEFORE special
    # tokens are removed: a token repeated on both sides of a special token
    # (such as the CTC blank, presumably among the special ids) is kept twice.
    collapsed = [id_ for id_, _ in groupby(ids)]

    special_ids = set(processor.tokenizer.all_special_ids)
    special_ids.add(processor.tokenizer.word_delimiter_token_id)

    phonemes = [processor.decode(id_) for id_ in collapsed if id_ not in special_ids]

    if ignore_stress:
        stress_table = str.maketrans("", "", "ˈˌ")
        phonemes = [p.translate(stress_table) for p in phonemes]
        # A token consisting only of a stress mark is now "", which is not a
        # phoneme, so drop it.
        phonemes = [p for p in phonemes if p]

    if ignore_pause:
        phonemes = [p for p in phonemes if p not in ("h#", "pau")]

    return phonemes
def transcribe_tensor(model: "Wav2Vec2ForCTC",
                      processor: "Wav2Vec2Processor",
                      waveform: torch.Tensor,
                      sr: int,
                      device=torch.device("cpu"),
                      ignore_stress: bool = True,
                      ignore_pause: bool = True) -> "list[str]":
    """Run ``model`` on one waveform and return its greedily decoded phonemes.

    Args:
        model: CTC model. It is called as-is and is not moved or switched to
            eval mode here. NOTE(review): this assumes callers have already
            placed it on ``device`` and called ``model.eval()``; confirm.
        processor: Processor that turns raw samples into ``input_values`` and
            decodes token ids (see ``decode_phonemes``).
        waveform: Audio samples for a single utterance. A leading dimension of
            size 1 (e.g. shape ``(1, num_samples)``) is squeezed away. The
            tensor may live on any device.
        sr: Sampling rate of ``waveform``, forwarded to the processor.
        device: Device the model inputs are moved to.
        ignore_stress: Forwarded to ``decode_phonemes``.
        ignore_pause: Forwarded to ``decode_phonemes``.

    Returns:
        The list of decoded phoneme strings for the (single) batch item.
    """
    # .numpy() refuses CUDA tensors and tensors that require grad, so detach
    # and move to CPU first.
    samples = waveform.detach().cpu().squeeze(0).numpy()
    inputs = processor(
        samples,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    ).input_values.to(device)

    with torch.inference_mode():
        logits = model(inputs).logits

    # Greedy decoding: take the highest-scoring token for every frame.
    pred_ids = torch.argmax(logits, dim=-1)
    return decode_phonemes(
        pred_ids[0],
        processor,
        ignore_stress=ignore_stress,
        ignore_pause=ignore_pause,
    )