#!/usr/bin/env python3
"""Get exact NeMo streaming inference output for comparison with Swift."""
import os
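# Allow duplicate OpenMP runtimes (a common workaround when PyTorch and other
# libraries each bundle libomp, e.g. on macOS/conda setups).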
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import numpy as np
import librosa
import json
from nemo.collections.asr.models import SortformerEncLabelModel
def main():
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

    # Disable dither for deterministic output
    if hasattr(model.preprocessor, 'featurizer'):
        if hasattr(model.preprocessor.featurizer, 'dither'):
            model.preprocessor.featurizer.dither = 0.0

    # Configure for Gradient Descent's streaming config (same as Swift)
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31
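    # Rough timing for this config, assuming the model's 80 ms output frames
    # (10 ms mel hop x 8 subsampling, matching the 0.08 s/frame used below):
    # chunk_len=6 -> ~0.48 s of new audio per step, right_context=7 -> ~0.56 s
    # lookahead, fifo_len=40 -> ~3.2 s FIFO, spkcache_len=188 -> ~15 s speaker cache.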
print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")
# Load audio
audio_path = "../audio.wav"
audio, sr = librosa.load(audio_path, sr=16000, mono=True)
print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")
waveform = torch.from_numpy(audio).unsqueeze(0).float()
# Get mel features using model's preprocessor
with torch.no_grad():
audio_len = torch.tensor([waveform.shape[1]])
features, feat_len = model.process_signal(
audio_signal=waveform, audio_signal_length=audio_len
)
# features is [batch, mel, time], need [batch, time, mel] for streaming
features = features[:, :, :feat_len.max()]
print(f"Features: {features.shape} (batch, mel, time)")
# Streaming inference using forward_streaming_step
subsampling = modules.subsampling_factor # 8
chunk_len = modules.chunk_len # 6
left_context = modules.chunk_left_context # 1
right_context = modules.chunk_right_context # 7
core_frames = chunk_len * subsampling # 48 mel frames
total_mel_frames = features.shape[2]
print(f"Total mel frames: {total_mel_frames}")
print(f"Core frames per chunk: {core_frames}")
    # Initialize streaming state
    streaming_state = modules.init_streaming_state(device=features.device)
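    # The returned state object carries the model's speaker-cache and FIFO
    # buffers (bounded by the spkcache_len / fifo_len values set above); it is
    # threaded through every forward_streaming_step call below.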
    # Initialize total_preds tensor: [batch, frames, num_speakers]
    total_preds = torch.zeros((1, 0, 4), device=features.device)
    chunk_idx = 0

    # Process chunks like streaming_feat_loader
    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)

        # Calculate context (in mel frames)
        left_offset = min(left_context * subsampling, stt_feat)
        right_offset = min(right_context * subsampling, total_mel_frames - end_feat)
        chunk_start = stt_feat - left_offset
        chunk_end = end_feat + right_offset
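        # Worked example for this config: left/right context = 8/56 mel frames,
        # so a steady-state chunk spans 8 + 48 + 56 = 112 mel frames; the first
        # chunk has left_offset = 0 and covers mel frames 0..103.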
        # Extract chunk - [batch, mel, time] -> [batch, time, mel]
        chunk = features[:, :, chunk_start:chunk_end]  # [1, 128, T]
        chunk_t = chunk.transpose(1, 2)                # [1, T, 128]
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

        with torch.no_grad():
            # Use forward_streaming_step
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )
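        # Each call should append only the chunk's core output frames to
        # total_preds (context frames are not emitted), i.e. chunk_len = 6 new
        # frames per full chunk.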
        chunk_idx += 1
        stt_feat = end_feat

    # total_preds now contains all predictions
    all_preds = total_preds[0].numpy()  # [total_frames, 4]
    print(f"\nTotal output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

    # Print timeline
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame Time Spk0 Spk1 Spk2 Spk3 | Visual")
    print("-" * 60)
    for frame in range(all_preds.shape[0]):
        time_sec = frame * 0.08
        probs = all_preds[frame]
        visual = ['■' if p > 0.55 else '·' for p in probs]
        print(f"{frame:5d} {time_sec:5.2f}s {probs[0]:.3f} {probs[1]:.3f} {probs[2]:.3f} {probs[3]:.3f} | [{visual[0]}{visual[1]}{visual[2]}{visual[3]}]")
    print("-" * 60)

    # Speaker activity summary
    print("\n=== Speaker Activity Summary ===")
    threshold = 0.55
    for spk in range(4):
        active_frames = np.sum(all_preds[:, spk] > threshold)
        active_time = active_frames * 0.08
        percent = active_time / (all_preds.shape[0] * 0.08) * 100
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    # Save to JSON for comparison
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": 0.08,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        },
    }
    with open("/tmp/nemo_streaming_reference.json", "w") as f:
        json.dump(output, f, indent=2)
    print("\nSaved to /tmp/nemo_streaming_reference.json")
if __name__ == "__main__":
    main()