import gradio as gr
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Global model cache so the 10B checkpoint is only loaded once per process
model = None
processor = None
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    global model, processor
    if model is None:
        repo_id = "MERaLiON/MERaLiON-2-10B"
        print("Loading MERaLiON-2-10B model...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            repo_id,
            use_safetensors=True,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        print("Model loaded successfully!")
    return model, processor


def meralion_inference(prompt, uploaded_file):
    global model, processor
    if uploaded_file is None:
        return "Please upload an audio file."

    # Load model on first run
    model, processor = load_model()

    try:
        # gr.File passes a filepath string in Gradio 4.x but a tempfile-like
        # object (with .name) in 3.x; support both.
        audio_path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name

        # Load audio at the 16 kHz sample rate the model expects
        audio_array, _ = librosa.load(audio_path, sr=16000)

        # Prompt template; the <SpeechHere> placeholder marks where the
        # processor splices in the audio embeddings
        prompt_template = (
            "Instruction: {query}\n"
            "Follow the text instruction based on the following audio: <SpeechHere>"
        )
        conversation = [
            {"role": "user", "content": prompt_template.format(query=prompt)}
        ]
        chat_prompt = processor.tokenizer.apply_chat_template(
            conversation=conversation, tokenize=False, add_generation_prompt=True
        )

        # Process inputs, then move tensors to the model device and cast
        # float32 tensors to bfloat16 to match the model weights
        inputs = processor(text=chat_prompt, audios=audio_array)
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(device)
                if value.dtype == torch.float32:
                    inputs[key] = inputs[key].to(torch.bfloat16)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, do_sample=True, temperature=0.7
            )

        # Strip the prompt tokens and decode only the newly generated tokens
        generated_ids = outputs[:, inputs["input_ids"].size(1):]
        response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        return f"Error during inference: {str(e)}"


with gr.Blocks() as demo:
    gr.Markdown("# MERaLiON-2-10B Audio Demo")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt", value="Please transcribe this speech.", lines=2
        )
        file_input = gr.File(
            label="Upload Audio File (WAV/MP3/M4A, max 300s)",
            file_types=[".wav", ".mp3", ".m4a"],
        )
    output_text = gr.Textbox(label="Model Output", lines=8)
    submit_btn = gr.Button("Run Inference", variant="primary")
    submit_btn.click(
        meralion_inference, inputs=[prompt_input, file_input], outputs=output_text
    )

demo.launch()
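
# ---------------------------------------------------------------------------
# Notes on running this demo (assumptions, not part of the original script):
# - attn_implementation="flash_attention_2" requires the `flash-attn` package
#   and a recent NVIDIA GPU; if it is unavailable, "sdpa" is a drop-in
#   alternative value when loading the model.
# - To expose the app on a temporary public URL, Gradio supports
#   demo.launch(share=True) in place of the plain launch above.
# ---------------------------------------------------------------------------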