Update process_input function in app.py to handle audio generation output more robustly, introducing a fallback mechanism for text generation in case of unexpected output formats. Improve error handling during audio and text generation processes. Additionally, update requirements.txt to include flash-attn for enhanced performance.
c98fc82
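# app.py — Gradio demo for Qwen/Qwen2.5-Omni-7B on Hugging Face Spaces.
# As the commit message above notes, generate() may return either a
# (text_ids, audio) tuple (when return_audio=True) or plain token ids
# (when return_audio=False); process_input() below guards for both shapes
# and falls back to text-only generation whenever the audio path fails.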
import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info  # pip package: qwen-omni-utils
import soundfile as sf
import tempfile
import spaces  # provides the @spaces.GPU decorator used on ZeroGPU hardware
import gc
# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; float32 on CPU (half precision is poorly supported for CPU inference)
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

def get_model():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    )
    return model

model = get_model()
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
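# Note: attn_implementation="flash_attention_2" requires the flash-attn package
# to be installed (hence the commit's addition of flash-attn to requirements.txt);
# without it, from_pretrained raises an ImportError on GPU instances.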
# System prompt
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}

# Voice options
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan"
}
@spaces.GPU  # allocate a GPU for this handler on ZeroGPU Spaces (the reason `spaces` is imported)
def process_input(image, audio, video, text, chat_history, voice_type, enable_audio_output):
    # Build the display message up front so it is always defined for the error handlers below
    user_message_for_display = str(text) if text is not None else ""
    if image is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
    if audio is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
    if video is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"
    # If empty, provide a default message
    if not user_message_for_display.strip():
        user_message_for_display = "Multimodal input"
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        # Combine multimodal inputs
        user_input = {
            "text": text,
            "image": image,
            "audio": audio,
            "video": video
        }

        # Prepare conversation history for model processing
        conversation = [SYSTEM_PROMPT]

        # Add previous chat history
        if isinstance(chat_history, list):
            for message in chat_history:
                if isinstance(message, dict) and "role" in message and "content" in message:
                    # Messages are already in the correct format
                    conversation.append(message)
                elif isinstance(message, list) and len(message) == 2:
                    # Convert the legacy [user, bot] pair format to message dicts
                    user_msg, bot_msg = message
                    if bot_msg is not None:  # Only add complete message pairs
                        # Strip all display-only media markers before re-feeding the text
                        cleaned = user_msg
                        for marker in ("[Image]", "[Audio]", "[Video]"):
                            cleaned = cleaned.replace(marker, "")
                        if cleaned != user_msg:
                            processed_msg = {"type": "text", "text": cleaned.strip()}
                        else:
                            processed_msg = user_msg
                        conversation.append({"role": "user", "content": processed_msg})
                        conversation.append({"role": "assistant", "content": bot_msg})

        # Add current user input
        conversation.append({"role": "user", "content": user_input_to_content(user_input)})
        # Prepare for inference
        model_input = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)  # Default to no audio in video
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = [], [], []  # Fall back to empty lists

        inputs = processor(
            text=model_input,
            audios=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True
        )

        # Move tensors to the model's device; cast only floating-point tensors to the
        # model dtype (casting integer tensors such as input_ids would corrupt them)
        inputs = {
            k: (v.to(device=model.device, dtype=model.dtype)
                if isinstance(v, torch.Tensor) and v.is_floating_point()
                else v.to(model.device) if isinstance(v, torch.Tensor) else v)
            for k, v in inputs.items()
        }
        # Generate the response, streaming text tokens to the console as they arrive
        try:
            text_ids = None
            audio_path = None
            generation_output = None
            # Sampling parameters shared by every generate() call below
            gen_kwargs = dict(
                use_audio_in_video=False,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
            if enable_audio_output:
                voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
                try:
                    generation_output = model.generate(
                        **inputs,
                        **gen_kwargs,
                        return_audio=True,
                        spk=voice_type_value,
                        streamer=TextStreamer(processor, skip_prompt=True)
                    )
                    if generation_output is not None and isinstance(generation_output, tuple) and len(generation_output) == 2:
                        text_ids, audio = generation_output
                        if audio is not None:
                            # Save the generated speech to a temporary WAV file (24 kHz)
                            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                                sf.write(
                                    tmp_file.name,
                                    audio.reshape(-1).detach().cpu().numpy(),
                                    samplerate=24000,
                                )
                                audio_path = tmp_file.name
                    else:
                        print("Warning: Unexpected generation output format")
                        # Fall back to text-only generation
                        text_ids = model.generate(
                            **inputs,
                            **gen_kwargs,
                            return_audio=False,
                            streamer=TextStreamer(processor, skip_prompt=True)
                        )
                except Exception as e:
                    print(f"Error during audio generation: {str(e)}")
                    # Fall back to text-only generation
                    try:
                        text_ids = model.generate(
                            **inputs,
                            **gen_kwargs,
                            return_audio=False,
                            streamer=TextStreamer(processor, skip_prompt=True)
                        )
                    except Exception as e:
                        print(f"Error during fallback text generation: {str(e)}")
                        text_ids = None
            else:
                try:
                    text_ids = model.generate(
                        **inputs,
                        **gen_kwargs,
                        return_audio=False,
                        streamer=TextStreamer(processor, skip_prompt=True)
                    )
                except Exception as e:
                    print(f"Error during text generation: {str(e)}")
                    text_ids = None
            # Process the response
            if text_ids is not None and len(text_ids) > 0:
                try:
                    text_response = processor.batch_decode(
                        text_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False
                    )[0]
                    # Clean up the decoded text: keep only the assistant turn and
                    # drop any residual chat-template control tokens
                    text_response = text_response.strip()
                    if "<|im_start|>assistant" in text_response:
                        text_response = text_response.split("<|im_start|>assistant")[-1]
                    text_response = text_response.replace("<|im_end|>", "").replace("<|im_start|>", "")
                    if text_response.startswith(":"):
                        text_response = text_response[1:].strip()
                except Exception as e:
                    print(f"Error during text decoding: {str(e)}")
                    text_response = "I apologize, but I encountered an error processing the response."
            else:
                text_response = "I apologize, but I encountered an error generating a response."
            # Update chat history with properly formatted entries
            if not isinstance(chat_history, list):
                chat_history = []

            user_message = {"role": "user", "content": user_message_for_display}
            assistant_message = {"role": "assistant", "content": text_response}

            # If the user message was already echoed into the history, only append the reply
            if chat_history and isinstance(chat_history[-1], dict) and chat_history[-1].get("role") == "user":
                chat_history.append(assistant_message)
            else:
                chat_history.extend([user_message, assistant_message])

            # Clear GPU memory after processing
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            # Prepare output
            if enable_audio_output and audio_path:
                return chat_history, text_response, audio_path
            else:
                return chat_history, text_response, None
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            error_msg = "I apologize, but I encountered an error processing your request. Please try again."
            if not isinstance(chat_history, list):
                chat_history = []
            chat_history.append(
                {"role": "assistant", "content": error_msg}
            )
            return chat_history, error_msg, None
    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        if not isinstance(chat_history, list):
            chat_history = []
        error_msg = "I apologize, but I encountered an error processing your request. Please try again."
        chat_history.extend([
            {"role": "user", "content": user_message_for_display},
            {"role": "assistant", "content": error_msg}
        ])
        return chat_history, error_msg, None
def user_input_to_content(user_input):
    if isinstance(user_input, str):
        return user_input
    elif isinstance(user_input, dict):
        # Convert the upload dict into the content-list format the processor expects
        content = []
        if user_input.get("text"):
            content.append({"type": "text", "text": user_input["text"]})
        if user_input.get("image"):
            content.append({"type": "image", "image": user_input["image"]})
        if user_input.get("audio"):
            content.append({"type": "audio", "audio": user_input["audio"]})
        if user_input.get("video"):
            content.append({"type": "video", "video": user_input["video"]})
        return content
    return user_input
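# Illustrative example (hypothetical path): calling
#   user_input_to_content({"text": "hi", "image": "/tmp/cat.png", "audio": None, "video": None})
# returns
#   [{"type": "text", "text": "hi"}, {"type": "image", "image": "/tmp/cat.png"}],
# i.e. the content-list format consumed by processor.apply_chat_template above.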
def create_demo():
    with gr.Blocks(title="Qwen2.5-Omni Chat Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Qwen2.5-Omni Multimodal Chat Demo")
        gr.Markdown("Experience the omni-modal capabilities of Qwen2.5-Omni through text, images, audio, and video interactions.")

        # Hidden placeholder components for text-only input
        placeholder_image = gr.Image(type="filepath", visible=False)
        placeholder_audio = gr.Audio(type="filepath", visible=False)
        placeholder_video = gr.Video(visible=False)

        # Chat interface
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=600,
                    show_label=False,
                    avatar_images=["user.png", "assistant.png"],
                    type="messages"
                )
                with gr.Accordion("Advanced Options", open=False):
                    voice_type = gr.Dropdown(
                        choices=list(VOICE_OPTIONS.keys()),
                        value="Chelsie (Female)",
                        label="Voice Type"
                    )
                    enable_audio_output = gr.Checkbox(
                        value=True,
                        label="Enable Audio Output"
                    )
                # Multimodal input components
                with gr.Tabs():
                    with gr.TabItem("Text Input"):
                        text_input = gr.Textbox(
                            placeholder="Type your message here...",
                            label="Text Input",
                            autofocus=True,
                            container=False,
                        )
                        text_submit = gr.Button("Send Text", variant="primary")
                    with gr.TabItem("Multimodal Input"):
                        with gr.Row():
                            image_input = gr.Image(
                                type="filepath",
                                label="Upload Image"
                            )
                            audio_input = gr.Audio(
                                type="filepath",
                                label="Upload Audio"
                            )
                        with gr.Row():
                            video_input = gr.Video(
                                label="Upload Video"
                            )
                            additional_text = gr.Textbox(
                                placeholder="Additional text message...",
                                label="Additional Text",
                                container=False,
                            )
                        multimodal_submit = gr.Button("Send Multimodal Input", variant="primary")
                clear_button = gr.Button("Clear Chat")
            with gr.Column(scale=1):
                gr.Markdown("## Model Capabilities")
                gr.Markdown("""
                **Qwen2.5-Omni can:**
                - Process and understand text
                - Analyze images and answer questions about them
                - Transcribe and understand audio
                - Analyze video content (with or without audio)
                - Generate natural speech responses
                """)
                gr.Markdown("### Example Prompts")
                gr.Examples(
                    examples=[
                        ["Describe what you see in this image", "image"],
                        ["What is being said in this audio clip?", "audio"],
                        ["What's happening in this video?", "video"],
                        ["Explain quantum computing in simple terms", "text"],
                        ["Generate a short story about a robot learning to paint", "text"]
                    ],
                    inputs=[text_input, gr.Textbox(visible=False)],
                    label="Text Examples"
                )
                audio_output = gr.Audio(
                    label="Model Speech Output",
                    visible=True,
                    autoplay=True
                )
                text_output = gr.Textbox(
                    label="Model Text Response",
                    interactive=False
                )
        # Text input handling: echo the user message into the existing history
        # (appending rather than replacing it), run the model, then clear the box
        text_submit.click(
            fn=lambda history, text: (history or []) + [{"role": "user", "content": text if text is not None else ""}],
            inputs=[chatbot, text_input],
            outputs=[chatbot],
            queue=False
        ).then(
            fn=process_input,
            inputs=[placeholder_image, placeholder_audio, placeholder_video, text_input, chatbot, voice_type, enable_audio_output],
            outputs=[chatbot, text_output, audio_output]
        ).then(
            fn=lambda: "",  # Clear input after submission
            outputs=text_input
        )
        # Multimodal input handling
        def prepare_multimodal_input(history, image, audio, video, text):
            # Build a display message that indicates what was uploaded
            display_message = str(text) if text is not None else ""
            if image is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Image]"
            if audio is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Audio]"
            if video is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Video]"
            if not display_message.strip():
                display_message = "Multimodal content"
            # Append to the existing history rather than replacing it
            return (history or []) + [{"role": "user", "content": display_message}]

        multimodal_submit.click(
            fn=prepare_multimodal_input,
            inputs=[chatbot, image_input, audio_input, video_input, additional_text],
            outputs=[chatbot],
            queue=False
        ).then(
            fn=process_input,
            inputs=[image_input, audio_input, video_input, additional_text,
                    chatbot, voice_type, enable_audio_output],
            outputs=[chatbot, text_output, audio_output]
        ).then(
            fn=lambda: (None, None, None, ""),  # Clear inputs after submission
            outputs=[image_input, audio_input, video_input, additional_text]
        )
        # Clear chat
        def clear_chat():
            return [], None, None

        clear_button.click(
            fn=clear_chat,
            outputs=[chatbot, text_output, audio_output]
        )

        # Update audio output visibility when the checkbox changes
        def toggle_audio_output(enable_audio):
            return gr.Audio(visible=enable_audio)

        enable_audio_output.change(
            fn=toggle_audio_output,
            inputs=enable_audio_output,
            outputs=audio_output
        )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)