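# Gradio front end for a ComfyUI text-to-video backend (Wan2.2 14B on AMD MI300X).
# Flow: read parameters from the UI, patch them into an exported workflow graph,
# queue the graph on the ComfyUI server, follow progress over its websocket,
# then hand the finished video back to the player via the /view endpoint.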
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import json, uuid, time, os
import requests
import websocket
from urllib.parse import urlencode
import gradio as gr

COMFY_HOST = os.getenv("COMFY_HOST", "134.199.132.159")

with open("workflow.json", "r", encoding="utf-8") as f:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(f)
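# workflow.json is assumed to be a ComfyUI workflow exported in API format:
# a dict keyed by node ID, each entry carrying "class_type" and "inputs".
# The hard-coded node IDs in _inject_params must match this exact file.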
class T2VReq(BaseModel):
    token: str = Field(...)
    text: str = Field(...)
    negative: Optional[str] = None
    seed: Optional[int] = None
    steps: Optional[int] = 4
    cfg: Optional[float] = 1
    width: Optional[int] = 640
    height: Optional[int] = 640
    length: Optional[int] = 81
    fps: Optional[int] = 16
    filename_prefix: Optional[str] = "video/ComfyUI"
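# Patch request values into a deep copy of the workflow graph. The node IDs are
# specific to this workflow.json; judging by the fields written below, "89" is
# presumably the positive-prompt encoder, "74" the video latent size, "88" the
# frame-rate input, and "80" the save-video node. Seed/steps/cfg injection is
# left commented out, so those UI fields currently have no effect.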
def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]:
    p = json.loads(json.dumps(prompt))
    p["89"]["inputs"]["text"] = r.text
    # if r.seed is None:
    #     r.seed = int.from_bytes(os.urandom(8), "big") & ((1 << 63) - 1)
    # p["3"]["inputs"]["seed"] = r.seed
    # if r.steps is not None: p["78"]["inputs"]["steps"] = r.steps
    # if r.cfg is not None: p["78"]["inputs"]["cfg"] = r.cfg
    if r.width is not None: p["74"]["inputs"]["width"] = r.width
    if r.height is not None: p["74"]["inputs"]["height"] = r.height
    if r.length is not None: p["74"]["inputs"]["length"] = r.length
    if r.fps is not None: p["88"]["inputs"]["fps"] = r.fps
    if r.filename_prefix:
        p["80"]["inputs"]["filename_prefix"] = r.filename_prefix
    return p
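# ComfyUI streams execution events over ws://<host>/ws?clientId=<id>. The same
# client_id is sent when queueing the prompt so events can be matched to this
# job; the extra token query parameter is assumed to be auth required by the
# proxy in front of this particular server.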
def _open_ws(client_id: str, token: str):
    ws = websocket.WebSocket()
    ws.connect(f"ws://{COMFY_HOST}/ws?clientId={client_id}&token={token}", timeout=1800)
    return ws
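# POST /prompt enqueues the patched graph and returns a prompt_id, which is
# later used to filter websocket events and to look up the history entry.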
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    payload = {"prompt": prompt, "client_id": client_id}
    resp = requests.post(f"http://{COMFY_HOST}/prompt?token={token}", json=payload, timeout=1800)
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    data = resp.json()
    if "prompt_id" not in data:
        raise RuntimeError(f"/prompt no prompt_id: {data}")
    return data["prompt_id"]
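# GET /history/<prompt_id> returns a mapping keyed by prompt_id; once execution
# finishes, the entry contains per-node "outputs" describing the saved files.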
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    r = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}?token={token}", timeout=1800)
    r.raise_for_status()
    hist = r.json()
    return hist.get(prompt_id, {})
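# Save nodes report files under "images", "videos", or "files", each item with
# filename/subfolder/type. Scan every node's outputs and return the first entry
# that looks like a video.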
def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]:
    outputs = history.get("outputs", {})
    for _, node_out in outputs.items():
        if "images" in node_out:
            for it in node_out["images"]:
                if all(k in it for k in ("filename", "subfolder", "type")):
                    fn = it["filename"]
                    if fn.lower().endswith((".mp4", ".webm", ".gif", ".mov", ".mkv")):
                        return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
        for key in ("videos", "files"):
            if key in node_out and node_out[key]:
                it = node_out[key][0]
                if all(k in it for k in ("filename", "subfolder", "type")):
                    return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
    raise RuntimeError("No video file found in history outputs")
sample_prompts = [
    "A golden retriever running across a beach at sunset, cinematic",
    "A cyberpunk city street at night with neon lights, light rain, slow pan",
    "An astronaut walking on an alien planet covered in glowing crystals, purple sky with two moons, dust particles floating, slow panning shot, highly detailed, cinematic atmosphere.",
    "A cat gracefully jumping between rooftops in slow motion, warm sunset lighting, camera tracking the cat midair, cinematic composition, natural movement."
]
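# UI: prompt box with example prompts, an "Advanced Settings" accordion for the
# generation parameters, a progress slider next to the output video player, and
# a token textbox that demo.load fills with a fresh UUID on page load.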
with gr.Blocks(
    title="T2V UI",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
) as demo:
    # st_token = gr.State()
    gr.Markdown("# Experience Wan2.2 14B Text-to-Video on AMD MI300X — Free Trial")
    gr.Markdown("Powered by [AMD Devcloud](https://oneclickamd.ai/) and [ComfyUI](https://github.com/comfyanonymous/ComfyUI)")
    gr.Markdown("### Prompt")
    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)
    gr.Examples(examples=sample_prompts, inputs=text)
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=640, precision=0)
            height = gr.Number(label="Height", value=640, precision=0)
        with gr.Row():
            length = gr.Number(label="Frames", value=81, precision=0)
            fps = gr.Number(label="FPS", value=8, precision=0)
        with gr.Row():
            steps = gr.Number(label="Steps", value=4, precision=0)
            cfg = gr.Number(label="CFG", value=5.0)
        seed = gr.Number(label="Seed (optional)", value=None)
        filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
        st_token = gr.Textbox(label="token", placeholder="name")
    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("Generate", variant="primary", scale=1)
            prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result", height=480)
    def _init_token():
        return str(uuid.uuid4())

    demo.load(_init_token, outputs=st_token)
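    # generate_fn is a generator: every yield streams a (progress, video) pair
    # into the click handler's outputs, so the slider updates while the job runs.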
    def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
        req = T2VReq(
            token=token,
            text=text,
            seed=int(seed) if seed is not None else None,
            steps=int(steps) if steps is not None else None,
            cfg=float(cfg) if cfg is not None else None,
            width=int(width) if width is not None else None,
            height=int(height) if height is not None else None,
            length=int(length) if length is not None else None,
            fps=int(fps) if fps is not None else None,
            filename_prefix=filename_prefix if filename_prefix else None,
        )
        prompt = _inject_params(WORKFLOW_TEMPLATE, req)
        client_id = str(uuid.uuid4())
        ws = _open_ws(client_id, req.token)
        prompt_id = _queue_prompt(prompt, client_id, req.token)
        total_nodes = max(1, len(prompt))
        seen = set()
        p = 0
        last_emit = -1
        start = time.time()
        ws.settimeout(180)
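        # Progress heuristic: binary frames (typically ComfyUI preview images)
        # only nudge the bar toward 95%, while "executing" messages advance it
        # by the share of workflow nodes seen so far. An "executing" event with
        # node None for our prompt_id marks the end of the run.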
        while True:
            out = ws.recv()
            if isinstance(out, (bytes, bytearray)):
                if p < 95 and time.time() - start > 2:
                    p = min(95, p + 1)
                if p != last_emit:
                    last_emit = p
                    yield p, None
                continue
            msg = json.loads(out)
            if msg.get("type") == "executing":
                data = msg.get("data", {})
                if data.get("prompt_id") != prompt_id:
                    continue
                node = data.get("node")
                if node is None:
                    break
                if node not in seen:
                    seen.add(node)
                    p = min(99, int(len(seen) / total_nodes * 100))
                    if p != last_emit:
                        last_emit = p
                        yield p, None
        ws.close()
        hist = _get_history(prompt_id, req.token)
        info = _extract_video_from_history(hist)
        q = urlencode(info)
        video_url = f"http://{COMFY_HOST}/view?{q}&token={req.token}"
        yield 100, video_url
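    # The yielded URL points at ComfyUI's /view endpoint, which serves a stored
    # output given filename/subfolder/type; gr.Video can load that URL directly.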
    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
        outputs=[prog_bar, out_video]
    )
if __name__ == "__main__":
    demo.queue().launch()
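# Rough usage sketch (assumed dependencies: gradio, requests, websocket-client,
# pydantic, plus a reachable ComfyUI server and its workflow.json alongside):
#   COMFY_HOST=<host[:port]> python app.py
# The file name app.py is an assumption; run whichever module holds this script.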