"""Get exact NeMo streaming inference output for comparison with Swift.""" |
|
|
|
|
|
import os |
|
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
|
|
|
import json

import librosa
import numpy as np
import torch

from nemo.collections.asr.models import SortformerEncLabelModel


def main():
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

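    # Disable dither so the mel features are deterministic across runs and
    # directly comparable with the Swift implementation.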
    if hasattr(model.preprocessor, 'featurizer'):
        if hasattr(model.preprocessor.featurizer, 'dither'):
            model.preprocessor.featurizer.dither = 0.0

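    # Streaming configuration for the comparison run. All lengths below are
    # in encoder output frames; assuming the model's 8x subsampling of the
    # 10 ms mel hop (one output frame = 80 ms), chunk_len=6 means each step
    # consumes 480 ms of new audio, with 80 ms of left context and 560 ms of
    # lookahead (chunk_right_context=7).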
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31
|
print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}") |
|
|
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}") |
|
|
|
|
|
|
|
|
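    # Load the test clip at 16 kHz mono, the sample rate the model expects.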
    audio_path = "../audio.wav"
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()

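    # Run the mel-spectrogram frontend over the whole file once; streaming is
    # then simulated below by slicing the precomputed features chunk by chunk.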
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")

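    # Derived sizes: the encoder subsamples mel frames by subsampling_factor
    # (8 for this model), so one chunk core covers chunk_len * 8 mel frames.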
    subsampling = modules.subsampling_factor
    chunk_len = modules.chunk_len
    left_context = modules.chunk_left_context
    right_context = modules.chunk_right_context
    core_frames = chunk_len * subsampling

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

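    # The streaming state is the model's memory between chunks; in NeMo's
    # SortformerModules it holds the speaker cache (long-term speaker
    # embeddings) and a FIFO queue of recent frame embeddings.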
    streaming_state = modules.init_streaming_state(device=features.device)

    # Accumulated predictions: (batch=1, frames so far, 4 speakers), grown by
    # forward_streaming_step on every chunk.
    total_preds = torch.zeros((1, 0, 4), device=features.device)

    chunk_idx = 0

    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)

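        # Clamp the context windows at the start and end of the recording.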
        left_offset = min(left_context * subsampling, stt_feat)
        right_offset = min(right_context * subsampling, total_mel_frames - end_feat)

        chunk_start = stt_feat - left_offset
        chunk_end = end_feat + right_offset

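        # forward_streaming_step expects features as (batch, time, mel).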
        chunk = features[:, :, chunk_start:chunk_end]
        chunk_t = chunk.transpose(1, 2)
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

        with torch.no_grad():
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

        chunk_idx += 1
        stt_feat = end_feat

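    # Each step appends only the core chunk's predictions (the context frames
    # are trimmed via the offsets), so total_preds now spans the recording.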
    all_preds = total_preds[0].numpy()
    print(f"\nProcessed {chunk_idx} chunks")
    print(f"Total output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===") |
|
|
print("Frame Time Spk0 Spk1 Spk2 Spk3 | Visual") |
|
|
print("-" * 60) |
|
|
|
|
|
for frame in range(all_preds.shape[0]): |
|
|
time_sec = frame * 0.08 |
|
|
probs = all_preds[frame] |
|
|
visual = ['■' if p > 0.55 else '·' for p in probs] |
|
|
print(f"{frame:5d} {time_sec:5.2f}s {probs[0]:.3f} {probs[1]:.3f} {probs[2]:.3f} {probs[3]:.3f} | [{visual[0]}{visual[1]}{visual[2]}{visual[3]}]") |
|
|
|
|
|
print("-" * 60) |
|
|
|
|
|
|
|
|
print("\n=== Speaker Activity Summary ===") |
|
|
threshold = 0.55 |
|
|
for spk in range(4): |
|
|
active_frames = np.sum(all_preds[:, spk] > threshold) |
|
|
active_time = active_frames * 0.08 |
|
|
percent = active_time / (all_preds.shape[0] * 0.08) * 100 |
|
|
print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)") |
|
|
|
|
|
|
|
|
    # Reference payload for the Swift comparison; probabilities are flattened
    # row-major, i.e. 4 consecutive values per 80 ms frame.
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": 0.08,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
            "spkcache_update_period": modules.spkcache_update_period,
        },
    }

with open("/tmp/nemo_streaming_reference.json", "w") as f: |
|
|
json.dump(output, f, indent=2) |
|
|
print("\nSaved to /tmp/nemo_streaming_reference.json") |
|
|
|
|
|
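# A minimal comparison sketch, not part of the original capture flow: it
# assumes the Swift side writes a JSON file with the same flattened
# "probabilities" layout as the reference above. The Swift output path and
# the tolerance are hypothetical placeholders.
def compare_with_swift(swift_json_path="/tmp/swift_streaming_output.json",
                       reference_path="/tmp/nemo_streaming_reference.json",
                       atol=1e-3):
    with open(reference_path) as f:
        ref = json.load(f)
    with open(swift_json_path) as f:
        swift = json.load(f)
    ref_probs = np.asarray(ref["probabilities"], dtype=np.float64)
    swift_probs = np.asarray(swift["probabilities"], dtype=np.float64)
    if ref_probs.shape != swift_probs.shape:
        print(f"Length mismatch: {ref_probs.size} vs {swift_probs.size}")
        return False
    max_diff = float(np.max(np.abs(ref_probs - swift_probs)))
    print(f"Max abs difference: {max_diff:.6f} (atol={atol})")
    return max_diff <= atol

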
if __name__ == "__main__":
    main()