import gradio as gr
import json
import os
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
from threading import Thread
import time
# --- Model configuration (unchanged, since the logs show these load successfully) ---
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"
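# The 0.5B model below serves as the draft (assistant) model for speculative decoding;
# the 7B model is still the one whose output is returned to the user.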
print("🚀 启动引擎...")
# --- 1. 加载模型 ---
try:
tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
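    # ov_config notes: the LATENCY performance hint tunes OpenVINO for single-request
    # response time, NUM_STREAMS=1 keeps a single inference stream, and an empty
    # CACHE_DIR disables the on-disk compiled-model cache.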
model = OVModelForCausalLM.from_pretrained(
MAIN_MODEL_ID,
ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
)
try:
draft_model = OVModelForCausalLM.from_pretrained(
DRAFT_MODEL_ID,
ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
)
print("✅ 投机采样 (Speculative Decoding) 已激活")
except:
draft_model = None
print("⚠️ 仅使用主模型推理")
except Exception as e:
print(f"❌ 加载失败: {e}")
model = None
# --- 2. Helper functions ---
def parse_system_prompt(mode, text_content, json_file):
if mode == "文本模式":
return text_content
elif mode == "JSON模式":
if json_file is None: return "You are a helpful assistant."
try:
with open(json_file.name, 'r', encoding='utf-8') as f:
data = json.load(f)
if isinstance(data, str): return data
return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
        except Exception:
            return "Error parsing JSON."
return "You are a helpful assistant."
# --- 3. Core logic (kept in the legacy Gradio tuple-based history format) ---
def predict(message, history, mode, prompt_text, prompt_json):
    # history format: [[user_1, bot_1], [user_2, bot_2], ...]
    # message: the current user input (str)
if model is None:
        yield history + [[message, "Model failed to load"]]
return
    # 1. Resolve the system prompt
sys_prompt = parse_system_prompt(mode, prompt_text, prompt_json)
    # 2. Convert the tuple history into the list of role/content dicts the chat template expects
model_inputs = [{"role": "system", "content": sys_prompt}]
for user_msg, bot_msg in history:
model_inputs.append({"role": "user", "content": user_msg})
model_inputs.append({"role": "assistant", "content": bot_msg})
model_inputs.append({"role": "user", "content": message})
    # 3. Build the model inputs
text = tokenizer.apply_chat_template(model_inputs, tokenize=False, add_generation_prompt=True)
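    # For Qwen2.5 the rendered prompt follows the ChatML layout, roughly:
    #   <|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n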
inputs = tokenizer(text, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=512,
temperature=0.6,
do_sample=True,
top_p=0.9,
)
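    # Speculative decoding: the draft model proposes a few tokens per step and the main
    # model verifies them in a single forward pass, preserving the main model's output
    # distribution while usually decoding faster.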
if draft_model is not None:
gen_kwargs["assistant_model"] = draft_model
    # 4. Run generation in a background thread so the main thread can stream results
t = Thread(target=model.generate, kwargs=gen_kwargs)
t.start()
    # 5. Stream the output in the Chatbot tuple format
partial_text = ""
for new_text in streamer:
partial_text += new_text
        # Each yield must be the full history list:
        # [[old_u, old_b], ..., [current_u, current_partial_b]]
yield history + [[message, partial_text]]
# --- 4. Build the UI ---
with gr.Blocks(title="Qwen Extreme") as demo:
gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")
with gr.Row():
with gr.Column(scale=1):
            with gr.Accordion("Settings", open=True):
                mode = gr.Radio(["Text Mode", "JSON Mode"], value="Text Mode", label="Prompt mode")
p_text = gr.Textbox(value="You are a helpful assistant.", lines=3, label="System Prompt")
                p_json = gr.File(label="JSON file", file_types=[".json"], visible=False)
def toggle(m):
                    return {p_text: gr.update(visible=m == "Text Mode"), p_json: gr.update(visible=m == "JSON Mode")}
mode.change(toggle, mode, [p_text, p_json])
with gr.Column(scale=3):
            # Key change: type="messages" is not set, so the Chatbot uses the default tuple format that predict() yields
chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B")
            msg = gr.Textbox(label="Input")
with gr.Row():
                btn = gr.Button("Send", variant="primary")
clear = gr.ClearButton([msg, chatbot])
    # Event wiring (kept deliberately simple)
    # On send:
    #   1. call predict with msg and chatbot (i.e. the history)
    #   2. write predict's output (the new history) back to chatbot
    #   3. clear msg
submit_event = msg.submit(
predict,
inputs=[msg, chatbot, mode, p_text, p_json],
outputs=[chatbot]
)
    msg.submit(lambda: "", None, msg)  # Clear the input box
btn_event = btn.click(
predict,
inputs=[msg, chatbot, mode, p_text, p_json],
outputs=[chatbot]
)
    btn.click(lambda: "", None, msg)  # Clear the input box
if __name__ == "__main__":
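    # queue() enables Gradio's request queue so the generator output of predict()
    # streams to the Chatbot incrementally.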
demo.queue().launch()