import base64
import os
from io import BytesIO
from typing import Tuple

import gradio as gr
import torch
import torchaudio
from einops import rearrange
from diffusers import DiffusionPipeline
from huggingface_hub import HfApi, InferenceClient, hf_hub_download
from PIL import Image
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Authentication
client = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.environ.get("api_key"))
# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stable Audio Open for music generation
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)

# Diffusion pipeline for cover art
pipeline = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-v2")
pipeline.load_lora_weights("ehristoforu/dalle-3-xl-v2")
pipeline = pipeline.to(device)  # run image generation on the same device
# --- Hugging Face Spaces Storage ---
api = HfApi()
repo_id = "kvikontent/suno-ai"  # Replace with your Hugging Face repository ID

# --- Global Variables ---
generated_songs = {}  # in-memory cache of songs; public songs are also persisted to the Space repo
# Function to generate audio (requires GPU)
def generate_audio(prompt: str) -> Tuple[bytes, bytes, str]:
    """Generates music and cover art, and names the song."""
    # --- Audio Generation ---
    # Stable Audio Open expects timing conditioning alongside the text prompt
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": 30,
    }]
    output = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        sample_size=sample_size,
        device=device,
    )
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip, and convert to int16
    output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save audio to memory (format must be given explicitly when writing to a file-like object)
    buffer = BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    audio_data = buffer.getvalue()

    # --- Image Generation ---
    image = pipeline(prompt).images[0]
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    image_data = buffer.getvalue()

    # --- Name Generation ---
    # Accumulate the streamed tokens into a single title
    song_name = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": "Name the song based on this prompt: " + prompt}],
        max_tokens=500,
        stream=True,
    ):
        delta = message.choices[0].delta.content
        if delta:
            song_name += delta
    song_name = song_name.strip()

    return audio_data, image_data, song_name
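# Usage sketch (illustrative only, not executed here): generate_audio returns raw
# WAV bytes, PNG bytes, and a model-suggested title for a single prompt, e.g.
#
#   audio_bytes, image_bytes, title = generate_audio("128 BPM tech house drum loop")
#   with open("preview.wav", "wb") as f:
#       f.write(audio_bytes)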
# Function to download generated audio and image
def download_audio_image(audio_data, image_data, song_name):
    """Encodes generated audio and image as data URLs for download/display."""
    audio_bytes = base64.b64encode(audio_data).decode('utf-8')
    image_bytes = base64.b64encode(image_data).decode('utf-8')
    audio_url = f"data:audio/wav;base64,{audio_bytes}"
    image_url = f"data:image/png;base64,{image_bytes}"
    return audio_url, image_url, song_name
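# Note (an assumption about the intended wiring, not part of the original code):
# gr.Audio and gr.Image generally accept file paths, URLs, or arrays rather than
# raw bytes, so an alternative to the data-URL approach above is to write the
# bytes to temporary files and hand Gradio the paths, e.g. with a small helper:
#
#   import tempfile
#
#   def to_tempfile(data: bytes, suffix: str) -> str:
#       f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
#       f.write(data)
#       f.close()
#       return f.name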
# Function to make a song public
def make_public(song_id, audio_data, image_data, song_name, user_id):
    """Marks a song public and persists it to the Space repo."""
    generated_songs[song_id]["public"] = True

    # Save the song data to the Space repository
    # (upload_file accepts raw bytes via path_or_fileobj)
    api.upload_file(
        path_or_fileobj=audio_data,
        path_in_repo=f"songs/{song_id}/audio.wav",
        repo_id=repo_id,
        repo_type="space",
    )
    api.upload_file(
        path_or_fileobj=image_data,
        path_in_repo=f"songs/{song_id}/image.png",
        repo_id=repo_id,
        repo_type="space",
    )

    # Save the song name as a text file
    api.upload_file(
        path_or_fileobj=song_name.encode("utf-8"),
        path_in_repo=f"songs/{song_id}/song_name.txt",
        repo_id=repo_id,
        repo_type="space",
    )
    return generated_songs
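# Resulting layout in the Space repo after make_public (from the paths above):
#
#   songs/
#     <song_id>/
#       audio.wav        # generated clip
#       image.png        # cover art
#       song_name.txt    # model-suggested title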
# Function to fetch songs from the Space repository
def fetch_songs(user_id=None):
    """Fetches songs stored under songs/ in the Space repo."""
    songs = {}
    # list_repo_files returns a flat list of paths like "songs/<id>/audio.wav"
    files = api.list_repo_files(repo_id=repo_id, repo_type="space")
    for path in files:
        if not path.startswith("songs/"):
            continue
        song_id = path.split("/")[1]
        songs.setdefault(song_id, {})
        # hf_hub_download returns a local cached copy of the file
        local_path = hf_hub_download(repo_id=repo_id, repo_type="space", filename=path)
        if path.endswith("audio.wav"):
            with open(local_path, "rb") as f:
                songs[song_id]["audio"] = f.read()
        elif path.endswith("image.png"):
            with open(local_path, "rb") as f:
                songs[song_id]["image"] = f.read()
        elif path.endswith("song_name.txt"):
            with open(local_path, "r") as f:
                songs[song_id]["name"] = f.read().strip()
    # The public/private status and owning user are not stored in the repo yet;
    # extend the layout (e.g. a per-song metadata file) to support that.
    return songs
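# Shape of the value returned by fetch_songs, derived from the code above:
#
#   {
#       "<song_id>": {"audio": b"...", "image": b"...", "name": "..."},
#       ...
#   }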
# --- User Interface ---
with gr.Blocks() as demo:
    gr.Markdown("## Neon Synth Music Generator")

    # Shared state holding generated songs (song_id -> song info)
    songs_state = gr.State(generated_songs)

    # Input area
    prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., 128 BPM tech house drum loop")
    generate_button = gr.Button("Generate")
    # Output area
    generated_audio = gr.Audio(label="Generated Audio")
    generated_image = gr.Image(label="Generated Image")
    song_name = gr.Textbox(label="Song Name")
    make_public_button = gr.Button("Make Public")

    # User authentication
    login_button = gr.Button("Login")
    logout_button = gr.Button("Logout", visible=False)
    user_name = gr.Textbox(label="Username", interactive=False, visible=False)

    # Feed area
    public_feed = gr.Gallery(label="Public Feed", show_label=False, elem_id="public-feed")
    user_feed = gr.Gallery(label="Your Feed", show_label=False, elem_id="user-feed")
    # --- Event Handlers ---
    generate_button.click(fn=generate_audio, inputs=prompt_input, outputs=[generated_audio, generated_image, song_name])

    # Note: the first input is bound to `song_id`, but this wiring (as in the
    # original) passes the whole song dict; a real song id still needs to be
    # tracked separately.
    make_public_button.click(fn=make_public, inputs=[songs_state, generated_audio, generated_image, song_name, user_name], outputs=[songs_state])

    # Login/logout toggle the auth controls and set/clear the username
    login_button.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True), gr.update(value="YourUsername", visible=True)),
        inputs=[],
        outputs=[login_button, logout_button, user_name],
    )
    logout_button.click(
        fn=lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)),
        inputs=[],
        outputs=[login_button, logout_button, user_name],
    )
    # --- Update the feed ---
    # (Writing the data URLs back into generated_audio re-triggers these .change
    # handlers; keep that in mind if the conversion is made stricter.)
    generated_audio.change(fn=download_audio_image, inputs=[generated_audio, generated_image, song_name], outputs=[generated_audio, generated_image, song_name])

    # Add the newly generated song to the shared state
    # (the incremental id below is a placeholder; the original code does not define ids)
    generated_audio.change(
        fn=lambda songs, audio_data, image_data, name, user: {
            **songs,
            str(len(songs)): {"audio": audio_data, "image": image_data, "name": name, "public": False, "user": user},
        },
        inputs=[songs_state, generated_audio, generated_image, song_name, user_name],
        outputs=[songs_state],
    )

    # Refresh the feeds when the song collection changes. The galleries show
    # (cover image, title) pairs, which is an assumption about the intended feed
    # contents; this also assumes a Gradio version in which gr.State supports
    # the .change listener.
    songs_state.change(
        fn=lambda songs, user: (
            [(Image.open(BytesIO(s["image"])), s["name"]) for s in songs.values() if s.get("public")],
            [(Image.open(BytesIO(s["image"])), s["name"]) for s in songs.values() if not s.get("public") and s.get("user") == user],
        ),
        inputs=[songs_state, user_name],
        outputs=[public_feed, user_feed],
    )
    # Fetch and display the feeds on startup. Only public songs are persisted to
    # the repo (see make_public), so everything fetched goes to the public feed;
    # per-user filtering is not implemented yet.
    demo.load(
        fn=lambda: (
            [(Image.open(BytesIO(s["image"])), s["name"]) for s in fetch_songs().values()],
            [],
        ),
        outputs=[public_feed, user_feed],
    )
    # --- Layout ---
    # Note: Gradio renders components where they are created, so the bare
    # references below do not move anything. To lay the app out this way,
    # either create the components inside these containers or create them
    # with render=False above and call .render() here.
    with gr.Row():
        with gr.Column():
            prompt_input
            generate_button
            login_button
            logout_button
            user_name
        with gr.Column():
            generated_audio
            generated_image
            song_name
            make_public_button
    with gr.Row():
        with gr.Column():
            public_feed
        with gr.Column():
            user_feed
# Run the Gradio interface
demo.launch()
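# Dependencies assumed by this script (a sketch for requirements.txt; versions
# are not pinned here): gradio, torch, torchaudio, einops, stable-audio-tools,
# diffusers, huggingface_hub, Pillow.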