import json
from threading import Thread

import gradio as gr
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer

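# Assumed dependencies (the original pins no versions): Optimum's OpenVINO
# extra plus Transformers and Gradio, e.g.
#   pip install optimum[openvino] transformers gradio
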
# Main model: INT4-quantized Qwen2.5-7B-Instruct exported for OpenVINO.
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
# Draft model used to propose tokens for speculative decoding.
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

print("🚀 Starting engine...")

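# Added note: speculative decoding only pays off when the draft model shares
# the main model's tokenizer/vocabulary, as the Qwen2.5 0.5B and 7B
# checkpoints do. The small model cheaply proposes a few tokens; the 7B model
# verifies them in a single forward pass, keeping the accepted prefix and
# resampling from the first mismatch.
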
try:
    tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
    model = OVModelForCausalLM.from_pretrained(
        MAIN_MODEL_ID,
        # LATENCY hint plus a single stream favors per-request latency;
        # an empty CACHE_DIR disables the compiled-model cache.
        ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
    )

    # The draft model is optional: if it cannot be loaded, fall back to
    # plain generation with the main model alone.
    try:
        draft_model = OVModelForCausalLM.from_pretrained(
            DRAFT_MODEL_ID,
            ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
        )
        print("✅ Speculative decoding enabled")
    except Exception:
        draft_model = None
        print("⚠️ Using the main model only for inference")

except Exception as e:
    print(f"❌ Model loading failed: {e}")
    model = None
    draft_model = None

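# Optional sanity check (a sketch; assumes both model repos ship tokenizer
# files): confirm the draft model tokenizes identically to the main model
# before relying on assisted generation.
#
#   draft_tok = AutoTokenizer.from_pretrained(DRAFT_MODEL_ID)
#   assert draft_tok.get_vocab() == tokenizer.get_vocab()
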
def parse_system_prompt(mode, text_content, json_file):
    """Resolve the system prompt from the textbox or an uploaded JSON file."""
    if mode == "Text mode":
        return text_content
    elif mode == "JSON mode":
        if json_file is None:
            return "You are a helpful assistant."
        try:
            with open(json_file.name, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, str):
                return data
            # Accept several common key names for the prompt.
            return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
        except Exception:
            return "Error parsing JSON."
    return "You are a helpful assistant."

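# Illustrative upload (hypothetical file contents; any one of the keys
# "system_prompt", "system", or "prompt" is accepted, tried in that order):
#
#   {"system_prompt": "You are a terse coding assistant."}
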
def predict(message, history, mode, prompt_text, prompt_json):
    if model is None:
        yield history + [[message, "Model failed to load"]]
        return

    sys_prompt = parse_system_prompt(mode, prompt_text, prompt_json)

    # Rebuild the full conversation in chat-message format.
    model_inputs = [{"role": "system", "content": sys_prompt}]
    for user_msg, bot_msg in history:
        model_inputs.append({"role": "user", "content": user_msg})
        model_inputs.append({"role": "assistant", "content": bot_msg})
    model_inputs.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(model_inputs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt")

    # skip_prompt=True so the streamer yields only newly generated text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.6,
        do_sample=True,
        top_p=0.9,
    )
    if draft_model is not None:
        # Hand the draft model to generate() to enable assisted generation.
        gen_kwargs["assistant_model"] = draft_model

    # Generate on a background thread so tokens can be streamed to the UI.
    t = Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield history + [[message, partial_text]]

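# Minimal sketch of driving predict() outside the UI (assumes the models
# loaded; history uses the same [user, assistant] pairs the Chatbot passes):
#
#   for chat in predict("Hello!", [], "Text mode", "You are a helpful assistant.", None):
#       pass
#   print(chat[-1][1])  # final assistant reply
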
with gr.Blocks(title="Qwen Extreme") as demo:
    gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Settings", open=True):
                mode = gr.Radio(["Text mode", "JSON mode"], value="Text mode", label="Prompt mode")
                p_text = gr.Textbox(value="You are a helpful assistant.", lines=3, label="System Prompt")
                p_json = gr.File(label="JSON file", file_types=[".json"], visible=False)

            def toggle(m):
                # Show the textbox in text mode and the file picker in JSON mode.
                return {
                    p_text: gr.update(visible=m == "Text mode"),
                    p_json: gr.update(visible=m == "JSON mode"),
                }

            mode.change(toggle, mode, [p_text, p_json])

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B")
            msg = gr.Textbox(label="Input")
            with gr.Row():
                btn = gr.Button("Send", variant="primary")
                clear = gr.ClearButton([msg, chatbot])

    # Each trigger fires two events: predict() streams the reply into the
    # chatbot, and a second handler clears the textbox. Gradio snapshots input
    # values when an event fires, so predict() still sees the original message.
    submit_event = msg.submit(
        predict,
        inputs=[msg, chatbot, mode, p_text, p_json],
        outputs=[chatbot],
    )
    msg.submit(lambda: "", None, msg)

    btn_event = btn.click(
        predict,
        inputs=[msg, chatbot, mode, p_text, p_json],
        outputs=[chatbot],
    )
    btn.click(lambda: "", None, msg)

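# Launch note (assumption: local serving is the intent): queue() is what lets
# the generator-based predict() stream partial output; pass share=True to
# launch() if a temporary public link is needed.
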
if __name__ == "__main__":
    demo.queue().launch()