File size: 2,008 Bytes
3a15b02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69d1fbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import io
import numpy as np
import torch
import torchaudio
import soundfile as sf
import av


def decode_wav_flac(audio_bytes: bytes):
    """Decode WAV or FLAC bytes via torchaudio's native backends.

    Returns a ``(waveform, sample_rate)`` pair as produced by
    ``torchaudio.load`` (waveform shaped ``(channels, samples)``).
    """
    buffer = io.BytesIO(audio_bytes)
    waveform, sample_rate = torchaudio.load(buffer)
    return waveform, sample_rate


def decode_mp3(audio_bytes: bytes):
    """Decode MP3 bytes through soundfile (libsndfile backend).

    Returns ``(waveform, sample_rate)`` with the waveform as a float32
    torch tensor shaped ``(channels, samples)``.
    """
    with io.BytesIO(audio_bytes) as stream:
        samples, sample_rate = sf.read(stream, dtype="float32")
    if samples.ndim == 1:
        # Mono: add a leading channel axis -> (1, samples).
        waveform = torch.from_numpy(samples).unsqueeze(0)
    else:
        # soundfile yields (samples, channels); torch convention is
        # (channels, samples), so transpose.
        waveform = torch.from_numpy(samples.T)
    return waveform, sample_rate


def decode_webm(audio_bytes: bytes):
    """Decode WEBM audio using PyAV (FFmpeg libs, no CLI).

    Returns ``(waveform, sample_rate)``: a mono float32 torch tensor of
    shape ``(1, samples)`` (or the raw 1-D tensor for packed mono data)
    and the stream's sample rate.

    Raises:
        ValueError: if the container yields no audio frames.
    """
    with io.BytesIO(audio_bytes) as bio:
        container = av.open(bio, format="webm")
        try:
            frames = [f.to_ndarray() for f in container.decode(audio=0)]
            # Read the rate while the container/buffer are still open.
            sr = container.streams.audio[0].rate
        finally:
            # av.open does not close itself; avoid leaking the demuxer.
            container.close()
    if not frames:
        raise ValueError("No audio frames decoded from WEBM.")
    data = np.concatenate(frames, axis=1).astype(np.float32)
    # Only integer PCM frames (e.g. s16) need scaling into [-1, 1].
    # WebM audio is typically Opus/Vorbis, which FFmpeg decodes to float
    # planar frames that are already normalized — dividing those by
    # 32768 would silently crush the signal to near-silence.
    if np.issubdtype(frames[0].dtype, np.integer):
        data /= float(np.iinfo(frames[0].dtype).max + 1)
    audio = torch.from_numpy(data)
    if audio.ndim > 1:
        # Downmix planar multi-channel data to mono by averaging channels.
        audio = torch.mean(audio, dim=0, keepdim=True)
    return audio, sr


def decode_audio_bytes(audio_bytes: bytes, content_type: str):
    """Route the byte payload to the decoder matching its Content-Type.

    WAV/FLAC (and any unrecognized type) fall through to torchaudio.
    """
    # Drop any parameters (e.g. "; codecs=opus") and normalize case.
    fmt = (content_type or "").split(";")[0].strip().lower()
    print(f"Content type: {fmt}")
    if fmt in ("audio/mp3", "audio/mpeg"):
        return decode_mp3(audio_bytes)
    if fmt in ("audio/webm", "video/webm"):
        return decode_webm(audio_bytes)
    return decode_wav_flac(audio_bytes)


def preprocess_audio(audio: torch.Tensor,
                     sr: int,
                     target_sr: int = 16000) -> torch.Tensor:
    """Downmix to a single channel and resample to ``target_sr`` Hz."""
    # Average multi-channel input across the channel axis (dim 0),
    # keeping the leading dim so the result stays (1, samples).
    if audio.dim() > 1:
        audio = audio.mean(dim=0, keepdim=True)
    # Skip the resampler entirely when the rates already match.
    if sr == target_sr:
        return audio
    return torchaudio.transforms.Resample(sr, target_sr)(audio)