Spaces:
Runtime error
Runtime error
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from zonos.model import Zonos | |
| from zonos.conditioning import make_cond_dict | |
# Load the hybrid checkpoint once at import time; tts() below reads this
# module-level `model` on every call.
MODEL_ID = "Zyphra/Zonos-v0.1-hybrid"
DEVICE = "cuda"

model = Zonos.from_pretrained(MODEL_ID, device=DEVICE)
# Optional, but recommended on GPU: switch the model weights to bfloat16.
model.bfloat16()
# Main inference function for Gradio
def _reference_to_tensor(wav_np):
    """Convert a Gradio audio array to a float32 tensor shaped (channels, samples).

    Gradio hands over arrays shaped (samples,) or (samples, channels), usually
    as integer PCM. Integer input is scaled by the dtype's maximum so the
    values land in roughly [-1, 1]; floating-point input is only cast.
    """
    wav = torch.from_numpy(wav_np)
    if wav.is_floating_point():
        wav = wav.float()
    else:
        # Read the full-scale value before the cast changes the dtype.
        full_scale = torch.iinfo(wav.dtype).max
        wav = wav.float() / full_scale

    if wav.dim() == 1:
        # Mono without a channel axis: add one in front.
        return wav.unsqueeze(0)
    # (samples, channels) -> (channels, samples)
    return wav.T.contiguous()


def tts(text, reference_audio):
    """Synthesize `text` in the voice of `reference_audio`.

    Args:
        text: The sentence(s) to speak.
        reference_audio: Value produced by a ``gr.Audio`` input in its default
            numpy mode, i.e. a ``(sample_rate, data)`` tuple, or ``None`` when
            the user supplied nothing.

    Returns:
        ``(sample_rate, samples)`` where ``samples`` is a NumPy array, the
        format a ``gr.Audio`` output accepts.

    Raises:
        gr.Error: If no (or empty) reference audio was provided. Raising lets
            Gradio show the message; returning a string to an audio output
            would fail instead.
    """
    if reference_audio is None:
        raise gr.Error("No reference audio provided.")

    # Gradio's numpy audio tuple is (sample_rate, data), in that order.
    sr, wav_np = reference_audio
    if wav_np is None or wav_np.size == 0:
        raise gr.Error("No reference audio provided.")

    wav_torch = _reference_to_tensor(wav_np)

    # Create speaker embedding
    spk_embedding = model.embed_spk_audio(wav_torch, sr)

    # Prepare conditioning; the embedding is cast to match the bfloat16 model.
    cond_dict = make_cond_dict(
        text=text,
        speaker=spk_embedding.to(torch.bfloat16),
        language="en-us",
    )
    conditioning = model.prepare_conditioning(cond_dict)

    with torch.no_grad():
        torch.manual_seed(421)  # Seeding for reproducible results
        codes = model.generate(conditioning)
        # Decode the codes into a waveform and bring it back to the CPU.
        wavs = model.autoencoder.decode(codes).cpu()

    # NOTE(review): assumes decode() yields (batch, channels, samples) - confirm.
    # Gradio wants (samples,) or (samples, channels), so drop a singleton
    # channel axis or move the channel axis last.
    out = wavs[0].float()
    if out.dim() == 2:
        out = out.squeeze(0) if out.shape[0] == 1 else out.T
    out_audio = out.numpy()

    # Return as (sample_rate, audio_ndarray) for Gradio's "audio" output
    return (model.autoencoder.sampling_rate, out_audio)
# Gradio UI wiring:
#   - a textbox for the prompt
#   - an audio input carrying the speaker reference
#   - an audio output carrying the generated speech
prompt_box = gr.Textbox(label="Text to Synthesize")
reference_box = gr.Audio(label="Reference Audio (for speaker embedding)")
result_box = gr.Audio(label="Generated Audio")

DEMO_DESCRIPTION = (
    "Provide a reference audio snippet for speaker embedding, "
    "enter text, and generate speech with Zonos TTS."
)

demo = gr.Interface(
    fn=tts,
    inputs=[prompt_box, reference_box],
    outputs=result_box,
    title="Zonos TTS Demo (Hybrid)",
    description=DEMO_DESCRIPTION,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)