#!/usr/bin/env python3
"""Get exact NeMo streaming inference output for comparison with Swift.

The script loads a streaming Sortformer diarization checkpoint, runs it over
one audio file chunk by chunk via ``forward_streaming_step``, prints a
per-frame timeline plus a per-speaker activity summary, and writes the raw
probabilities to a JSON file that another implementation can be diffed
against.
"""

import json
import os

# Set before torch/numpy are imported, as in the original script.
# NOTE(review): presumably works around duplicate OpenMP runtime load errors
# on some platforms - confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import librosa
from nemo.collections.asr.models import SortformerEncLabelModel

# Input / output locations.
MODEL_PATH = 'diar_streaming_sortformer_4spk-v2.nemo'
AUDIO_PATH = "../audio.wav"
OUTPUT_PATH = "/tmp/nemo_streaming_reference.json"

# Audio is resampled to this rate on load.
SAMPLE_RATE = 16000
# Width of the prediction tensor's last axis (one column per speaker).
NUM_SPEAKERS = 4
# Duration of one output frame, in seconds (used for the timeline and JSON).
FRAME_DURATION_SEC = 0.08
# A speaker counts as active in a frame when its probability exceeds this.
ACTIVITY_THRESHOLD = 0.55


def _load_model():
    """Restore the Sortformer checkpoint on CPU in eval mode with dither off."""
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        MODEL_PATH,
        map_location='cpu',
    )
    model.eval()

    # Disable dither for deterministic output. Only touched when the
    # preprocessor actually exposes a featurizer with a dither attribute.
    featurizer = getattr(model.preprocessor, 'featurizer', None)
    if featurizer is not None and hasattr(featurizer, 'dither'):
        featurizer.dither = 0.0
    return model


def _configure_streaming(modules) -> None:
    """Apply the streaming configuration (same as Swift) and print it."""
    # Configure for Gradient Descent's streaming config (same as Swift).
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")


def _compute_features(model) -> torch.Tensor:
    """Load the audio file and return its features from the model's preprocessor.

    The returned tensor is trimmed along its last axis to the valid feature
    length reported by ``process_signal``; the layout is (batch, mel, time)
    per the original comments.
    """
    audio, _ = librosa.load(AUDIO_PATH, sr=SAMPLE_RATE, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()

    # Get mel features using the model's preprocessor.
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform,
            audio_signal_length=audio_len,
        )

    # features is [batch, mel, time]; the streaming step needs
    # [batch, time, mel], which is handled per chunk later on.
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")
    return features


def _run_streaming(model, features: torch.Tensor) -> np.ndarray:
    """Feed ``features`` through ``forward_streaming_step`` chunk by chunk.

    Each chunk consists of ``chunk_len * subsampling_factor`` core mel frames
    plus as much left/right context as is available at that position.

    Returns:
        The accumulated predictions of the first batch item as a NumPy array
        of shape (total_frames, NUM_SPEAKERS).
    """
    modules = model.sortformer_modules
    subsampling = modules.subsampling_factor  # 8 per the original comment
    # Context sizes are configured in subsampled frames; convert to mel frames.
    left_context_frames = modules.chunk_left_context * subsampling
    right_context_frames = modules.chunk_right_context * subsampling
    core_frames = modules.chunk_len * subsampling  # 48 mel frames if 6 * 8

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    device = features.device
    streaming_state = modules.init_streaming_state(device=device)
    # Starts with zero frames; forward_streaming_step returns the updated
    # tensor, which is carried into the next call.
    total_preds = torch.zeros((1, 0, NUM_SPEAKERS), device=device)

    # Process chunks like streaming_feat_loader: the core window advances by
    # core_frames each step and the last window is clipped to the input end.
    with torch.no_grad():
        for stt_feat in range(0, total_mel_frames, core_frames):
            end_feat = min(stt_feat + core_frames, total_mel_frames)

            # Context (in mel frames), clipped at both ends of the input.
            left_offset = min(left_context_frames, stt_feat)
            right_offset = min(right_context_frames, total_mel_frames - end_feat)

            # Extract chunk: [batch, mel, time] -> [batch, time, mel].
            chunk = features[:, :, stt_feat - left_offset:end_feat + right_offset]
            chunk_t = chunk.transpose(1, 2)
            chunk_len_tensor = torch.tensor(
                [chunk_t.shape[1]],
                dtype=torch.long,
                device=device,
            )

            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

    # total_preds now contains all predictions.
    return total_preds[0].numpy()  # [total_frames, NUM_SPEAKERS]


def _print_timeline(all_preds: np.ndarray) -> None:
    """Print one line per output frame with probabilities and an activity bar."""
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame Time Spk0 Spk1 Spk2 Spk3 | Visual")
    print("-" * 60)

    for frame, probs in enumerate(all_preds):
        time_sec = frame * FRAME_DURATION_SEC
        probs_text = " ".join(f"{p:.3f}" for p in probs)
        visual = "".join('■' if p > ACTIVITY_THRESHOLD else '·' for p in probs)
        print(f"{frame:5d} {time_sec:5.2f}s {probs_text} | [{visual}]")

    print("-" * 60)


def _print_activity_summary(all_preds: np.ndarray) -> None:
    """Print, per speaker, the active time and its share of the total duration."""
    print("\n=== Speaker Activity Summary ===")
    total_frames = all_preds.shape[0]
    for spk in range(all_preds.shape[1]):
        active_frames = np.sum(all_preds[:, spk] > ACTIVITY_THRESHOLD)
        active_time = active_frames * FRAME_DURATION_SEC
        # With zero output frames the ratio is undefined; report 0% rather
        # than dividing by zero.
        if total_frames:
            percent = active_time / (total_frames * FRAME_DURATION_SEC) * 100
        else:
            percent = 0.0
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")


def _save_reference(all_preds: np.ndarray, modules) -> None:
    """Write the flattened probabilities and the streaming config to JSON."""
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": FRAME_DURATION_SEC,
        # Flattened in row-major order: all speakers of frame 0, then frame 1...
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": modules.chunk_len,
            "chunk_left_context": modules.chunk_left_context,
            "chunk_right_context": modules.chunk_right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        },
    }

    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved to {OUTPUT_PATH}")


def main():
    """Run streaming inference on the audio file and dump the reference output."""
    model = _load_model()
    modules = model.sortformer_modules
    _configure_streaming(modules)

    features = _compute_features(model)
    all_preds = _run_streaming(model, features)

    print(f"\nTotal output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

    _print_timeline(all_preds)
    _print_activity_summary(all_preds)
    _save_reference(all_preds, modules)


if __name__ == "__main__":
    main()