#!/usr/bin/env python3
"""Pytest-based minimal sanity tests for `perform_speaker_diarization_on_utterances`.

These tests avoid heavy dependencies (sherpa_onnx/faiss/sklearn) by using a
mock extractor and rely on the lightweight paths & heuristics implemented in
`src.diarization`.

Run:
    pytest -q tests/test_diarization_minimal.py

Or standalone (still works):
    python3 tests/test_diarization_minimal.py
"""
from __future__ import annotations

import sys
from pathlib import Path
from typing import Iterable, List, Tuple

import numpy as np
import pytest

# Make the repository root importable so `src.diarization` resolves when the
# tests are launched from any working directory.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.diarization import perform_speaker_diarization_on_utterances  # type: ignore  # noqa: E402

EMB_DIM = 192


def _emb(seed: int, delta: float | None = None) -> np.ndarray:
    """Return a deterministic float32 vector of length ``EMB_DIM``.

    Args:
        seed: Seed for the NumPy random generator; equal seeds give equal vectors.
        delta: Optional constant added to every component of the vector.
    """
    rng = np.random.default_rng(seed)
    v = rng.normal(size=EMB_DIM).astype(np.float32)
    if delta is not None:
        v = (v + delta).astype(np.float32)
    return v


class MockStream:
    """No-op stand-in for the stream object handed out by the extractor."""

    def __init__(self, sample_rate: int, segment: np.ndarray | None):
        self.sample_rate = sample_rate
        self.segment = segment

    def accept_waveform(self, sr, seg):  # pragma: no cover - no-op
        pass

    def input_finished(self):  # pragma: no cover - no-op
        pass


class MockExtractor:
    """Mimics the subset of sherpa_onnx SpeakerEmbeddingExtractor we use."""

    def __init__(self, embeddings_sequence: List[np.ndarray]):
        self._embs = embeddings_sequence
        self._i = 0  # index of the next embedding to hand out

    def create_stream(self):
        """Return a fresh ``MockStream`` (16 kHz, no attached segment)."""
        return MockStream(16000, None)

    def compute(self, _stream):
        """Return the next canned embedding, repeating the last one when exhausted."""
        if self._i >= len(self._embs):
            return self._embs[-1]
        emb = self._embs[self._i]
        self._i += 1
        return emb


def _collect(gen) -> List[Tuple[float, float, int]]:
    """Drive *gen* and return the final list of segments it produces.

    The result is accepted either as a yielded ``list`` (the first one seen
    wins and iteration stops there) or as the generator's ``return`` value.

    A plain ``for`` loop swallows ``StopIteration`` and with it the return
    value, and calling ``next()`` again on an exhausted generator only yields
    a ``StopIteration`` whose value is ``None``.  The iterator is therefore
    advanced manually so the *first* ``StopIteration`` can be inspected.

    Raises:
        AssertionError: if neither a yielded list nor a return value is found.
    """
    result: List[Tuple[float, float, int]] | None = None
    iterator = iter(gen)
    try:
        while True:
            item = next(iterator)
            if isinstance(item, list):
                result = item  # final segments emitted
                break
    except StopIteration as stop:
        result = stop.value  # type: ignore[assignment]
    assert result is not None, "Generator produced no result list"
    return result


def _run_case(embeddings: List[np.ndarray], utterances: List[Tuple[float, float, str]]):
    """Run diarization with canned embeddings and validate the segment shape.

    Args:
        embeddings: Vectors returned, in order, by the mock extractor.
        utterances: ``(start, end, text)`` tuples passed to the function under test.

    Returns:
        The segments produced, each a ``(start, end, speaker_id)`` tuple.
    """
    extractor = MockExtractor(embeddings)
    # NOTE(review): some cases pass utterances ending after 3 s (up to 4.0 s);
    # confirm the function under test tolerates spans beyond the audio length.
    audio = np.zeros(int(16000 * 3), dtype=np.float32)  # 3s silence is enough
    gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=16000,
        utterances=utterances,
        embedding_extractor=extractor,
        config_dict={"cluster_threshold": 0.5, "num_speakers": -1},
        progress_callback=None,
    )
    segments = _collect(gen)

    # Basic validation
    for seg in segments:
        assert isinstance(seg, tuple) and len(seg) == 3
        s, e, spk = seg
        assert 0 <= s < e, "Invalid time bounds"
        assert isinstance(spk, int)
    return segments


def test_single_segment():
    utts = [(0.0, 2.0, "Hello world")]
    segs = _run_case([_emb(1)], utts)
    assert len(segs) == 1
    assert segs[0][2] == 0


def test_two_similar_segments_same_speaker():
    base = _emb(2)
    almost_same = (base + 0.001).astype(np.float32)
    utts = [(0.0, 2.0, "Bonjour"), (2.1, 4.0, "Bonjour encore")]
    segs = _run_case([base, almost_same], utts)
    assert len(segs) == 2
    assert len({spk for *_rest, spk in segs}) == 1, "Should have merged speaker IDs"


def test_two_different_segments_distinct_speakers():
    utts = [(0.0, 1.5, "Hola"), (1.6, 3.2, "Adios")]
    segs = _run_case([_emb(10), _emb(200)], utts)
    assert len(segs) == 2
    # Can be 1 or 2 depending on heuristic similarity, but expecting at least
    # one speaker id present. Check the speaker ids themselves rather than
    # re-checking the segment count.
    speakers = {spk for *_rest, spk in segs}
    assert 1 <= len(speakers) <= 2


def test_three_segments_enhanced_or_fallback():
    utts = [(0.0, 1.0, "A"), (1.1, 2.2, "B"), (2.3, 3.4, "C")]
    segs = _run_case([_emb(11), _emb(12), _emb(13)], utts)
    assert len(segs) == 3, "Granularity should be preserved for small n"


# Allow running directly without pytest invocation
if __name__ == "__main__":  # pragma: no cover
    raise SystemExit(pytest.main([__file__]))