thanhhungtakeshi's picture
add grammar correction endpoint
69d1fbe
import io
import numpy as np
import torch
import torchaudio
import soundfile as sf
import av
def decode_wav_flac(audio_bytes: bytes):
    """Decode in-memory WAV/FLAC data by handing it to ``torchaudio.load``.

    Args:
        audio_bytes: Raw contents of the audio file.

    Returns:
        The ``(waveform, sample_rate)`` pair produced by ``torchaudio.load``.
    """
    # torchaudio.load accepts a file-like object, so wrap the bytes in memory.
    buffer = io.BytesIO(audio_bytes)
    waveform, sample_rate = torchaudio.load(buffer)
    return waveform, sample_rate
def decode_mp3(audio_bytes: bytes):
    """Decode in-memory MP3 data with soundfile and return a torch tensor.

    Args:
        audio_bytes: Raw contents of the MP3 file.

    Returns:
        ``(waveform, sample_rate)``. A 1-D result from ``sf.read`` gains a
        leading dimension; a multi-dimensional result is transposed so that
        its second axis comes first.
    """
    with io.BytesIO(audio_bytes) as stream:
        samples, sample_rate = sf.read(stream, dtype="float32")

    if samples.ndim == 1:
        # Single-channel read: add a leading axis.
        waveform = torch.from_numpy(samples).unsqueeze(0)
    else:
        # Multi-channel read: swap axes before wrapping as a tensor.
        waveform = torch.from_numpy(samples.T)
    return waveform, sample_rate
def _pcm_to_float32(samples: np.ndarray) -> np.ndarray:
    """Scale decoded PCM samples to float32 in the nominal [-1.0, 1.0) range."""
    if not np.issubdtype(samples.dtype, np.integer):
        # Float sample formats are already normalized by the decoder; dividing
        # them by 32768 (as for 16-bit PCM) would make the audio near-silent.
        return samples.astype(np.float32)
    limits = np.iinfo(samples.dtype)
    if limits.min == 0:
        # Unsigned PCM is offset-binary: the midpoint of the range is silence.
        midpoint = (limits.max + 1) / 2.0
        return (samples.astype(np.float32) - midpoint) / midpoint
    # Signed PCM: divide by the magnitude of the most negative value
    # (32768 for 16-bit samples, matching the previous behavior).
    return samples.astype(np.float32) / np.float32(-float(limits.min))


def decode_webm(audio_bytes: bytes):
    """Decode WEBM audio using PyAV (FFmpeg libs, no CLI) and downmix to mono.

    Args:
        audio_bytes: Raw contents of the WEBM file.

    Returns:
        ``(audio, sr)`` where ``audio`` is a float32 tensor of shape
        ``(1, num_samples)`` and ``sr`` is the rate reported by the first
        audio stream.

    Raises:
        ValueError: If no audio frames could be decoded.
    """
    with io.BytesIO(audio_bytes) as bio:
        # Use the container as a context manager so the demuxer is released
        # even when decoding raises.
        with av.open(bio, format="webm") as container:
            chunks = []
            for frame in container.decode(audio=0):
                block = frame.to_ndarray()
                if not frame.format.is_planar:
                    # Packed formats come back as one interleaved row of
                    # samples * channels values; de-interleave them into
                    # (channels, samples) so the downmix below is correct.
                    block = block.reshape(-1, len(frame.layout.channels)).T
                chunks.append(block)
            if not chunks:
                raise ValueError("No audio frames decoded from WEBM.")
            sr = container.streams.audio[0].rate

    data = _pcm_to_float32(np.concatenate(chunks, axis=1))
    audio = torch.from_numpy(data)
    if audio.ndim > 1:
        # Average across channels, keeping a leading channel dimension.
        audio = torch.mean(audio, dim=0, keepdim=True)
    return audio, sr
def decode_audio_bytes(audio_bytes: bytes, content_type: str):
    """Pick a decoder from the request's Content-Type and run it.

    The media type is taken from the part of ``content_type`` before any
    ``;`` parameters, trimmed and lower-cased. MP3 and WEBM types go to their
    dedicated decoders; everything else (including a missing content type)
    falls back to the WAV/FLAC decoder.

    Args:
        audio_bytes: Raw contents of the uploaded audio.
        content_type: Content-Type header value; may be empty or ``None``.

    Returns:
        Whatever the selected decoder returns.
    """
    media_type, _, _ = (content_type or "").partition(";")
    fmt = media_type.strip().lower()
    print(f"Content type: {fmt}")

    # Map each recognized media type to its decoder.
    decoders = {
        "audio/mp3": decode_mp3,
        "audio/mpeg": decode_mp3,
        "audio/webm": decode_webm,
        "video/webm": decode_webm,
    }
    decoder = decoders.get(fmt, decode_wav_flac)
    return decoder(audio_bytes)
def preprocess_audio(audio: torch.Tensor,
                     sr: int,
                     target_sr: int = 16000) -> torch.Tensor:
    """Downmix to mono and resample to ``target_sr`` when needed.

    Args:
        audio: Input waveform. Tensors with more than one dimension are
            averaged over dim 0 (treated as the channel axis); 1-D tensors
            are passed through untouched.
        sr: Sample rate of ``audio``.
        target_sr: Desired sample rate.

    Returns:
        The mono waveform, resampled only if ``sr`` differs from
        ``target_sr``.
    """
    has_channel_axis = audio.ndim > 1
    mono = audio.mean(dim=0, keepdim=True) if has_channel_axis else audio

    if sr == target_sr:
        # Already at the requested rate; skip building a resampler.
        return mono
    return torchaudio.transforms.Resample(sr, target_sr)(mono)