import gradio as gr
import json
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
from threading import Thread

# --- Model configuration (unchanged; the logs show these load successfully) ---
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

print("🚀 Starting engine...")

# --- 1. Load the models ---
try:
    tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
    model = OVModelForCausalLM.from_pretrained(
        MAIN_MODEL_ID,
        ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
    )
    try:
        draft_model = OVModelForCausalLM.from_pretrained(
            DRAFT_MODEL_ID,
            ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
        )
        print("✅ Speculative decoding enabled")
    except Exception:
        draft_model = None
        print("⚠️ Draft model unavailable, running inference with the main model only")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    model = None
    draft_model = None


# --- 2. Helpers ---
def parse_system_prompt(mode, text_content, json_file):
    """Resolve the system prompt from either the textbox or an uploaded JSON file."""
    if mode == "Text mode":
        return text_content
    elif mode == "JSON mode":
        if json_file is None:
            return "You are a helpful assistant."
        try:
            with open(json_file.name, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, str):
                return data
            return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
        except Exception:
            return "Error parsing JSON."
    return "You are a helpful assistant."


# --- 3. Core logic (uses the tuple history format for compatibility with older Gradio) ---
def predict(message, history, mode, prompt_text, prompt_json):
    # history format: [[User1, Bot1], [User2, Bot2]]
    # message: the current user input (str)
    if model is None:
        yield history + [[message, "Model failed to load."]]
        return

    # 1. Resolve the system prompt
    sys_prompt = parse_system_prompt(mode, prompt_text, prompt_json)

    # 2. Convert the tuple history into the list of dicts the chat template expects
    model_inputs = [{"role": "system", "content": sys_prompt}]
    for user_msg, bot_msg in history:
        model_inputs.append({"role": "user", "content": user_msg})
        model_inputs.append({"role": "assistant", "content": bot_msg})
    model_inputs.append({"role": "user", "content": message})

    # 3. Build the model input
    text = tokenizer.apply_chat_template(model_inputs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.6,
        do_sample=True,
        top_p=0.9,
    )
    if draft_model is not None:
        gen_kwargs["assistant_model"] = draft_model

    # 4. Generate in a background thread
    t = Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    # 5. Stream the output in the Chatbot's tuple format
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # Each yield must be the full history list:
        # [[old_u, old_b], ..., [current_u, current_partial_b]]
        yield history + [[message, partial_text]]

# --- 4. Build the UI ---
with gr.Blocks(title="Qwen Extreme") as demo:
    gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Settings", open=True):
                mode = gr.Radio(["Text mode", "JSON mode"], value="Text mode", label="Prompt mode")
                p_text = gr.Textbox(value="You are a helpful assistant.", lines=3, label="System Prompt")
                p_json = gr.File(label="JSON file", file_types=[".json"], visible=False)

                def toggle(m):
                    return {
                        p_text: gr.update(visible=m == "Text mode"),
                        p_json: gr.update(visible=m == "JSON mode"),
                    }

                mode.change(toggle, mode, [p_text, p_json])
        with gr.Column(scale=3):
            # Key change: type="messages" was removed, so the Chatbot uses the default
            # tuple format, which matches the history handling in predict().
            chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B")
            msg = gr.Textbox(label="Input")
            with gr.Row():
                btn = gr.Button("Send", variant="primary")
                clear = gr.ClearButton([msg, chatbot])

    # Event bindings (kept deliberately simple)
    # On send:
    # 1. Call predict with msg and chatbot (the history)
    # 2. Write predict's output (the new history) back to chatbot
    # 3. Clear msg
    msg.submit(
        predict,
        inputs=[msg, chatbot, mode, p_text, p_json],
        outputs=[chatbot],
    )
    msg.submit(lambda: "", None, msg)  # clear the input box

    btn.click(
        predict,
        inputs=[msg, chatbot, mode, p_text, p_json],
        outputs=[chatbot],
    )
    btn.click(lambda: "", None, msg)  # clear the input box

if __name__ == "__main__":
    demo.queue().launch()