Commit 870988c · vosk-model-ru-0.54
1 Parent(s): 0b81d49

Files changed:
- am-onnx/decoder.int8.onnx +2 -2
- am-onnx/decoder.onnx +2 -2
- am-onnx/encoder.int8.onnx +2 -2
- am-onnx/encoder.onnx +2 -2
- am-onnx/joiner.int8.onnx +2 -2
- am-onnx/joiner.onnx +2 -2
- am/jit_script.pt +2 -2
- decode-onnx.py +46 -0
- decode.py +92 -109
am-onnx/decoder.int8.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6380fc4c6dd867b3d263aef71abe8de5a5785b2fcf0e5d619b4ccc2df2119d4f
+size 540689
am-onnx/decoder.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:dcbe1ffa0211e77ca6d3a80164df13fbda3ec00e47d12b9f449f89572df12136
+size 2093080
am-onnx/encoder.int8.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:eb6c12fbad810d5bc3e427802e604604c69b5943a91feebc43424dd09d9ec407
+size 70876638
am-onnx/encoder.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:8bca034acab837e4b30625f4101b27385c8553ea44abfa5bd89c4581667f250c
+size 261058126
am-onnx/joiner.int8.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:93f2e1d12b78d53e7802f1606488c14bb3d764b15fadf5ef6c022f6ba1fa40f7
+size 259417
am-onnx/joiner.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6d94a1c4273ad750d98cbe89320a5b1860143059162fb8407cc22706bcfe5835
+size 1026462
am/jit_script.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:91323267f4a017096429d16783ccfd9366bc005b3447b0a78d4865ded08652fc
+size 265975361
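Note: the seven model files above are Git LFS pointers, not the weights themselves. A pointer is just the three lines shown (version, oid, size); the binary payload is addressed by the sha256 oid. A minimal sketch of reading one such pointer (parse_lfs_pointer is a hypothetical helper, not part of this repo):

# Hypothetical helper: parse a Git LFS pointer file (the three-line
# "version / oid / size" format shown in the diffs above) into a dict.
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# In a checkout made without LFS smudging, the file on disk is the pointer
# itself, e.g. parse_lfs_pointer("am-onnx/decoder.int8.onnx")["size"]
# should come back as "540689"; `git lfs pull` replaces it with the weights.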
decode-onnx.py ADDED
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+import wave
+from pathlib import Path
+from typing import Tuple
+import sys
+import numpy as np
+import sherpa_onnx
+
+import sys
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+def main():
+
+    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+        encoder="am-onnx/encoder.onnx",
+        decoder="am-onnx/decoder.onnx",
+        joiner="am-onnx/joiner.onnx",
+        tokens="lang/tokens.txt",
+        num_threads=0,
+        provider='cpu',
+        sample_rate=16000,
+        dither=3e-5,
+        max_active_paths=10,
+        decoding_method="modified_beam_search")
+
+    samples, sample_rate = read_wave(sys.argv[1])
+    s = recognizer.create_stream()
+    s.accept_waveform(sample_rate, samples)
+    recognizer.decode_stream(s)
+    print("Text:", s.result.text)
+    print("Tokens:", s.result.tokens)
+    print("Timestamps:", s.result.timestamps)
+
+if __name__ == "__main__":
+    main()
+
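decode-onnx.py decodes a single wav via decode_stream(). A batched variant is a small change; this sketch reuses read_wave() and the recognizer constructed above, and assumes your sherpa-onnx version exposes decode_streams() (present in recent releases):

# Sketch: decode all files given on the command line in one batch,
# reusing read_wave() and `recognizer` from decode-onnx.py above.
streams = []
for name in sys.argv[1:]:
    samples, sample_rate = read_wave(name)
    s = recognizer.create_stream()
    s.accept_waveform(sample_rate, samples)
    streams.append(s)

recognizer.decode_streams(streams)  # batched counterpart of decode_stream()

for name, s in zip(sys.argv[1:], streams):
    print(name, s.result.text)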
decode.py CHANGED
@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
 # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
 #
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,11 +14,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer/export.py \
+  --exp-dir ./zipformer/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+Usage of this script:
+
+./zipformer/jit_pretrained.py \
+  --nn-model-filename ./zipformer/exp/cpu_jit.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
 
-import sys
 import argparse
 import logging
 import math
+import random
+import os
+import sys
 import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union
@@ -25,8 +49,8 @@ import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-
 from torch.nn.utils.rnn import pad_sequence
+from timeit import default_timer as timer
 
 from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
@@ -38,6 +62,37 @@ from icefall.lexicon import Lexicon
 
 import k2
 
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        default='am/jit_script.pt',
+        type=str,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        default='lang/bpe.model',
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
 
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float = 16000
@@ -59,6 +114,8 @@ def read_sound_files(
         ans.append(wav)
     return ans
 
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
@@ -299,7 +356,7 @@ def modified_beam_search_LODR(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[
+                ys=([-1] * (context_size - 1) + [blank_id]),
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states, # state of the NN LM
                 lm_score=init_score.reshape(-1),
@@ -501,132 +558,41 @@ def modified_beam_search_LODR(
     return ans
 
 
-def greedy_search(
-    model: torch.jit.ScriptModule,
-    encoder_out: torch.Tensor,
-    encoder_out_lens: torch.Tensor,
-) -> List[List[int]]:
-    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
-    Args:
-      model:
-        The transducer model.
-      encoder_out:
-        A 3-D tensor of shape (N, T, C)
-      encoder_out_lens:
-        A 1-D tensor of shape (N,).
-    Returns:
-      Return the decoded results for each utterance.
-    """
-    assert encoder_out.ndim == 3
-    assert encoder_out.size(0) >= 1, encoder_out.size(0)
-
-    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
-        input=encoder_out,
-        lengths=encoder_out_lens.cpu(),
-        batch_first=True,
-        enforce_sorted=False,
-    )
-
-    device = encoder_out.device
-    blank_id = 0  # hard-code to 0
-
-    batch_size_list = packed_encoder_out.batch_sizes.tolist()
-    N = encoder_out.size(0)
-
-    assert torch.all(encoder_out_lens > 0), encoder_out_lens
-    assert N == batch_size_list[0], (N, batch_size_list)
-
-    context_size = model.decoder.context_size
-    hyps = [[blank_id] * context_size for _ in range(N)]
-
-    decoder_input = torch.tensor(
-        hyps,
-        device=device,
-        dtype=torch.int64,
-    )  # (N, context_size)
-
-    decoder_out = model.decoder(
-        decoder_input,
-        need_pad=torch.tensor([False]),
-    ).squeeze(1)
-
-    offset = 0
-    for batch_size in batch_size_list:
-        start = offset
-        end = offset + batch_size
-        current_encoder_out = packed_encoder_out.data[start:end]
-        current_encoder_out = current_encoder_out
-        # current_encoder_out's shape: (batch_size, encoder_out_dim)
-        offset = end
-
-        decoder_out = decoder_out[:batch_size]
-
-        logits = model.joiner(
-            current_encoder_out,
-            decoder_out,
-        )
-        # logits'shape (batch_size, vocab_size)
-
-        assert logits.ndim == 2, logits.shape
-        y = logits.argmax(dim=1).tolist()
-        emitted = False
-        for i, v in enumerate(y):
-            if v != blank_id:
-                hyps[i].append(v)
-                emitted = True
-        if emitted:
-            # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
-            decoder_input = torch.tensor(
-                decoder_input,
-                device=device,
-                dtype=torch.int64,
-            )
-            decoder_out = model.decoder(
-                decoder_input,
-                need_pad=torch.tensor([False]),
-            )
-            decoder_out = decoder_out.squeeze(1)
-
-    sorted_ans = [h[context_size:] for h in hyps]
-    ans = []
-    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
-    for i in range(N):
-        ans.append(sorted_ans[unsorted_indices[i]])
-
-    return ans
-
-
 @torch.no_grad()
 def main():
-
+
+    torch.set_num_threads(4)
+
+    parser = get_parser()
+    args = parser.parse_args()
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
 
-    model = torch.jit.load(
-
+    model = torch.jit.load(args.nn_model_filename)
     model.eval()
-
     model.to(device)
 
     sp = spm.SentencePieceProcessor()
-    sp.load(
+    sp.load(args.bpe_model)
+
+    random.seed(17)
 
    opts = kaldifeat.FbankOptions()
    opts.device = device
-   opts.frame_opts.dither =
+   opts.frame_opts.dither = 3e-5
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
+   opts.mel_opts.high_freq = -400
 
    fbank = kaldifeat.Fbank(opts)
 
    all_filenames = sys.argv[1:]
 
    params = AttributeDict()
-   params.
+   params.lm_vocab_size = 500
    params.rnn_lm_embedding_dim = 2048
    params.rnn_lm_hidden_dim = 2048
    params.rnn_lm_num_layers = 3
@@ -651,8 +617,15 @@ def main():
     )
     ngram_lm_scale = -0.1
 
-
-
+    start_time = timer()
+    samples = 0
+
+    for f in all_filenames:
+        samples = samples + os.path.getsize(f) / 2
+
+    batch_size = 8
+    for i in range(0, len(all_filenames), batch_size):
+        filenames = all_filenames[i:i+batch_size]
         waves = read_sound_files(
             filenames=filenames,
         )
@@ -684,9 +657,19 @@ def main():
             LM=LM,
         )
 
+
         for f, hyp in zip(filenames, hyps):
             words = sp.decode(hyp)
-            print(f"{f.split('/')[-1][0:-4]} {words}")
+            print(f"{f.split('/')[-1][0:-4]} {words}", flush=True)
+
+    end_time = timer()
+
+    print("Processed %.3f seconds of audio in %.3f seconds (%.3f xRT)"
+        % (samples / 16000.0,
+           end_time - start_time,
+           (end_time - start_time) / (samples / 16000.0)),
+        file=sys.stderr)
+
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"