import os
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
from huggingface_hub import login
from icrawler.builtin import BingImageCrawler

MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
MODEL = "mistralai/Mistral-Nemo-Instruct-2407"

TITLE = "<h1><center>Mistral-Nemo</center></h1>"

PLACEHOLDER = """
<center>
<p>Mistral-Nemo is a pretrained generative text model with 12B parameters, trained jointly by Mistral AI and NVIDIA.</p>
</center>
"""
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
#output_video {
    display: block;
    margin-left: auto !important;
    margin-right: auto !important;
    width: 20vw !important;
}
footer { visibility: hidden; }
"""
device = "cuda"  # or "cpu"

# Recommended flag for this tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL, fix_mistral_regex=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,  # torch_dtype is deprecated in newer transformers
    device_map="auto",
    ignore_mismatched_sizes=True,
)
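# Note (assumption): this Space runs on ZeroGPU, where a GPU is attached per
# call rather than held by the process. The @spaces.GPU decorator on
# get_response() below requests a device for the duration of each generation;
# outside ZeroGPU the decorator is documented to be a no-op.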
def _system_prompt_for(name: str) -> str:
    return (
        f"You should respond like {name}. "
        "You should have a meaningful conversation. Don't repeat yourself. "
        "You should only output your response. "
        "You don't need to put quotes around what you're saying. "
        "You don't need to put your name at the beginning of your response."
    )
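# For instance (hypothetical name), _system_prompt_for("Socrates") yields:
#   "You should respond like Socrates. You should have a meaningful
#    conversation. Don't repeat yourself. ..."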
def normalize_history(history):
    """
    Gradio may send messages where `content` is a list of rich parts:
        {"role": "assistant",
         "content": [{"type": "text", "text": "hello"}]}
    We convert everything into:
        {"role": ..., "content": "plain string"}
    """
    if history is None:
        return []
    norm = []
    for msg in history:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, list):
            # e.g. [{"type":"text","text":"..."}, ...]
            parts = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    parts.append(part.get("text", ""))
                else:
                    parts.append(str(part))
            content = "\n".join(parts)
        else:
            content = str(content)
        norm.append({"role": role, "content": content})
    return norm
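# Example of the normalization (hypothetical Gradio payload):
#   normalize_history([
#       {"role": "assistant", "content": [{"type": "text", "text": "hello"}]},
#   ])
#   -> [{"role": "assistant", "content": "hello"}]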
@spaces.GPU  # on ZeroGPU, attaches a GPU for the duration of this call
def get_response(conversation):
    """
    conversation: list of {"role": "system" | "user" | "assistant", "content": str}
    """
    temperature = 0.3
    max_new_tokens = 512
    top_p = 1.0
    top_k = 20
    penalty = 1.2

    # Tokenize via the chat template directly; rendering to text and re-encoding
    # would prepend a second BOS token. add_generation_prompt=True appends the
    # assistant header so the model continues as the assistant.
    inputs = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        streamer=streamer,
        repetition_penalty=penalty,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )
    # generate() already disables gradients internally, so no torch.no_grad()
    # wrapper is needed here (a wrapper wouldn't cross the thread boundary
    # anyway). Run generation on a worker thread and drain the streamer here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
    thread.join()
    return buffer
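# Example of a self-contained call (hypothetical messages):
#   get_response([
#       {"role": "system", "content": _system_prompt_for("Socrates")},
#       {"role": "user", "content": "Introduce yourself."},
#   ])
#   -> one complete reply as a plain string (the streamer is drained
#      internally, so nothing is streamed to the UI from here).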
def stream_chat(history, character_a, character_b):
    """
    history: list of messages (messages format):
        [{"role": "user" | "assistant", "content": ...}, ...]
    In the UI:
        - user messages = Character B
        - assistant messages = Character A
    Each click:
        1. B says something new (as 'user')
        2. A replies (as 'assistant')
    """
    # 🔑 Normalize history coming from Gradio into plain strings
    history = normalize_history(history)

    # ---------- B speaks (user side) ----------
    if len(history) == 0:
        # First turn: B introduces themselves to A
        b_user_prompt = (
            f"You are {character_b}. You are having a conversation with {character_a}. "
            "Introduce yourself and start the conversation."
        )
    else:
        # Last assistant message (A) to respond to
        last_msg = history[-1]
        last_text = last_msg["content"]
        b_user_prompt = (
            f"{character_a} just said: \"{last_text}\". "
            f"Respond in character as {character_b} and continue the conversation."
        )

    conv_for_b = [
        {"role": "system", "content": _system_prompt_for(character_b)},
        *history,
        {"role": "user", "content": b_user_prompt},
    ]
    response_b = get_response(conv_for_b)
    print("response_b:", response_b)

    # ---------- A speaks (assistant side) ----------
    conv_for_a = [
        {"role": "system", "content": _system_prompt_for(character_a)},
        *history,
        {"role": "user", "content": response_b},
    ]
    response_a = get_response(conv_for_a)
    print("response_a:", response_a)

    # ---------- Append to chat history ----------
    new_history = history + [
        {"role": "user", "content": response_b},  # B's line
        {"role": "assistant", "content": response_a},  # A's line
    ]
    print("history:", new_history)
    return new_history
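# After one click with an empty history, the list looks like (illustrative):
#   [{"role": "user", "content": "<B's opening line>"},
#    {"role": "assistant", "content": "<A's reply>"}]
# and each further click appends one more user/assistant pair.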
def get_img(keyword):
    path = "./" + keyword
    os.makedirs(path, exist_ok=True)
    bing_crawler = BingImageCrawler(storage={"root_dir": path})
    bing_crawler.crawl(keyword=keyword, max_num=1)
    for file_name in os.listdir(path):
        if file_name.lower().endswith(
            (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff")
        ):
            return os.path.join(path, file_name)
    return None
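# Example (hypothetical): get_img("Socrates") downloads at most one Bing result
# into ./Socrates/ and returns a path like "./Socrates/000001.jpg", or None if
# the crawl produced no usable image file.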
def set_characters(a, b):
    img_a = get_img(a)
    img_b = get_img(b)
    # avatar_images=(user_avatar, assistant_avatar) => (B, A)
    # also reset chat history when characters change
    return img_a, img_b, gr.update(avatar_images=(img_b, img_a), value=[])
theme = gr.themes.Base().set(
    body_background_fill="#e1fceb",
    color_accent_soft="#ffffff",
    border_color_accent="#e1fceb",
    border_color_primary="#e1fceb",
    background_fill_secondary="#e1fceb",
    button_secondary_background_fill="#ffffff",
    button_primary_background_fill="#ffffff",
    button_primary_text_color="#1f2937",
    input_background_fill="#f8f8f8",
)
# css and theme are gr.Blocks() arguments, not launch() arguments
with gr.Blocks(css=CSS, theme=theme) as demo:
    gr.HTML(
        """
        <center> <h1> Bot vs Bot </h1> </center>
        <center> by <a href="https://www.tonyassi.com/">Tony Assi</a> </center>
        <center> <h3> Pick two icons and watch them have a conversation </h3> </center>
        """
    )
    with gr.Row():
        character_a = gr.Textbox(
            label="Character A",
            info="Choose a person",
            placeholder="Socrates, Edgar Allan Poe, George Washington",
        )
        character_b = gr.Textbox(
            label="Character B",
            info="Choose a person",
            placeholder="Madonna, Paris Hilton, Liza Minnelli",
        )
    character_button = gr.Button("Initiate Characters")

    with gr.Row():
        image_a = gr.Image(show_label=False, interactive=False)
        gr.Markdown(" ")
        image_b = gr.Image(show_label=False, interactive=False)

    # No 'type' kwarg – this Gradio build doesn't support it, but it does use
    # messages format
    chat = gr.Chatbot(show_label=False)
    submit_button = gr.Button("Start Conversation")
    character_button.click(
        set_characters,
        inputs=[character_a, character_b],
        outputs=[image_a, image_b, chat],
    )
    submit_button.click(
        stream_chat,
        inputs=[chat, character_a, character_b],
        outputs=[chat],
    )
if __name__ == "__main__":
    demo.launch()