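# Gradio Space: text chat with Gemma 3 27B Instruct. Messages arrive through a
# FastRTC WebRTC textbox and the reply is streamed token-by-token into a
# gr.Chatbot.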
import gradio as gr
import spaces
import torch
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    WebRTC,
    WebRTCData,
    get_cloudflare_turn_credentials_async,
)
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
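
# Load the tokenizer and model once at startup; device_map="auto" places the
# model on the available GPU(s).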
MODEL_ID = "google/gemma-3-27b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # Gemma weights are published in bfloat16; float16 is prone to overflow
    # with these models, so bfloat16 is the safer dtype on supported GPUs.
    torch_dtype=torch.bfloat16,
)
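

# ReplyOnPause handler: called when the user submits text (or pauses after
# speaking); it streams UI updates by yielding AdditionalOutputs.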
@spaces.GPU  # request GPU time per call when running on ZeroGPU hardware
def generate(data: WebRTCData, history, system_prompt="", max_new_tokens=512):
    # The "textbox" variant of the WebRTC component exposes the typed message
    # as `data.textbox`.
    text = data.textbox
    history.append({"role": "user", "content": text})
    # Show the user's message in the chatbot before generation starts.
    yield AdditionalOutputs(history)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend(history)
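
    # Render the conversation with the model's chat template and move the
    # resulting token ids to the model's device.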
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)
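
    # TextIteratorStreamer yields decoded text fragments as soon as the model
    # produces them; skipping the prompt and special tokens leaves only the
    # new assistant text.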
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding
    )
    # Run generation in a background thread so this generator can keep
    # yielding partial output while the model is still producing tokens.
    Thread(target=model.generate, kwargs=gen_kwargs).start()
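
    # Accumulate streamed tokens into one assistant message and re-emit the
    # chat history on every update.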
    new_message = {"role": "assistant", "content": ""}
    for token in streamer:
        new_message["content"] += token
        yield AdditionalOutputs(history + [new_message])
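

# UI: a chatbot, the WebRTC textbox, and a settings accordion.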
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    webrtc = WebRTC(
        modality="audio",
        mode="send",
        variant="textbox",
        # Cloudflare TURN credentials so the connection works behind NAT.
        rtc_configuration=get_cloudflare_turn_credentials_async,
    )
    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            "You are a helpful assistant.", label="System prompt"
        )
        max_new_tokens = gr.Slider(50, 1500, 700, label="Max new tokens")
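
    # Wire the component to the handler: `generate` receives the WebRTC data
    # plus the current chat history and settings, and its AdditionalOutputs
    # are forwarded to the chatbot by on_additional_outputs below.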
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=[webrtc, chatbot, system_prompt, max_new_tokens],
        outputs=[chatbot],
        concurrency_limit=100,
    )
    webrtc.on_additional_outputs(
        lambda old, new: new, inputs=[chatbot], outputs=[chatbot]
    )


if __name__ == "__main__":
    demo.launch(ssr_mode=False)