Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,390 +1,162 @@
|
|
| 1 |
-
import
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
# wan2.2-main/gradio_ti2v.py
|
| 9 |
import gradio as gr
|
| 10 |
-
import torch
|
| 11 |
-
from huggingface_hub import snapshot_download
|
| 12 |
-
from PIL import Image
|
| 13 |
-
import random
|
| 14 |
-
import numpy as np
|
| 15 |
-
import spaces
|
| 16 |
-
|
| 17 |
-
import wan
|
| 18 |
-
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
|
| 19 |
-
from wan.utils.utils import cache_video
|
| 20 |
-
|
| 21 |
-
import gc
|
| 22 |
-
|
| 23 |
-
torch.backends.cudnn.enabled = False
|
| 24 |
-
|
| 25 |
-
# --- 1. Global Setup and Model Loading ---
|
| 26 |
-
|
| 27 |
-
print("Starting Gradio App for Wan 2.2 TI2V-5B...")
|
| 28 |
-
|
| 29 |
-
# Download model snapshots from Hugging Face Hub
|
| 30 |
-
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
|
| 31 |
-
print(f"Downloading/loading checkpoints for {repo_id}...")
|
| 32 |
-
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
|
| 33 |
-
print(f"Using checkpoints from {ckpt_dir}")
|
| 34 |
-
|
| 35 |
-
# Load the model configuration
|
| 36 |
-
TASK_NAME = 'ti2v-5B'
|
| 37 |
-
cfg = WAN_CONFIGS[TASK_NAME]
|
| 38 |
-
FIXED_FPS = 12
|
| 39 |
-
MIN_FRAMES_MODEL = 8
|
| 40 |
-
MAX_FRAMES_MODEL = 121
|
| 41 |
-
|
| 42 |
-
# Instantiate the pipeline in the global scope
|
| 43 |
-
print("Initializing WanTI2V pipeline...")
|
| 44 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
-
device_id = 0 if torch.cuda.is_available() else -1
|
| 46 |
-
pipeline = wan.WanTI2V(
|
| 47 |
-
config=cfg,
|
| 48 |
-
checkpoint_dir=ckpt_dir,
|
| 49 |
-
device_id=device_id,
|
| 50 |
-
rank=0,
|
| 51 |
-
t5_fsdp=False,
|
| 52 |
-
dit_fsdp=False,
|
| 53 |
-
use_sp=False,
|
| 54 |
-
t5_cpu=False,
|
| 55 |
-
init_on_cpu=False,
|
| 56 |
-
convert_model_dtype=True,
|
| 57 |
-
)
|
| 58 |
-
print("Pipeline initialized and ready.")
|
| 59 |
-
|
| 60 |
-
# --- Helper Functions ---
|
| 61 |
-
def clear_gpu_memory():
|
| 62 |
-
"""Clear GPU memory more thoroughly"""
|
| 63 |
-
if torch.cuda.is_available():
|
| 64 |
-
torch.cuda.empty_cache()
|
| 65 |
-
torch.cuda.ipc_collect()
|
| 66 |
-
gc.collect()
|
| 67 |
-
|
| 68 |
-
def select_best_size_for_image(image, available_sizes):
|
| 69 |
-
"""Select the size option with aspect ratio closest to the input image."""
|
| 70 |
-
if image is None:
|
| 71 |
-
return available_sizes[0] # Return first option if no image
|
| 72 |
-
|
| 73 |
-
img_width, img_height = image.size
|
| 74 |
-
img_aspect_ratio = img_height / img_width
|
| 75 |
-
|
| 76 |
-
best_size = available_sizes[0]
|
| 77 |
-
best_diff = float('inf')
|
| 78 |
-
|
| 79 |
-
for size_str in available_sizes:
|
| 80 |
-
# Parse size string like "704*1280"
|
| 81 |
-
height, width = map(int, size_str.split('*'))
|
| 82 |
-
size_aspect_ratio = height / width
|
| 83 |
-
diff = abs(img_aspect_ratio - size_aspect_ratio)
|
| 84 |
-
|
| 85 |
-
if diff < best_diff:
|
| 86 |
-
best_diff = diff
|
| 87 |
-
best_size = size_str
|
| 88 |
-
|
| 89 |
-
return best_size
|
| 90 |
-
|
| 91 |
-
def handle_image_upload(image):
|
| 92 |
-
"""Handle image upload and return the best matching size."""
|
| 93 |
-
if image is None:
|
| 94 |
-
return gr.update()
|
| 95 |
-
|
| 96 |
-
pil_image = Image.fromarray(image).convert("RGB")
|
| 97 |
-
available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
|
| 98 |
-
best_size = select_best_size_for_image(pil_image, available_sizes)
|
| 99 |
-
|
| 100 |
-
return gr.update(value=best_size)
|
| 101 |
-
|
| 102 |
-
def validate_inputs(image, prompt, duration_seconds):
|
| 103 |
-
"""Validate user inputs"""
|
| 104 |
-
errors = []
|
| 105 |
-
|
| 106 |
-
if not prompt or len(prompt.strip()) < 5:
|
| 107 |
-
errors.append("Prompt must be at least 5 characters long.")
|
| 108 |
-
|
| 109 |
-
if image is not None:
|
| 110 |
-
img = Image.fromarray(image)
|
| 111 |
-
if img.size[0] * img.size[1] > 4096 * 4096:
|
| 112 |
-
errors.append("Image size is too large (maximum 4096x4096).")
|
| 113 |
-
|
| 114 |
-
if duration_seconds > 10.1 and image is None:
|
| 115 |
-
errors.append("Videos longer than 10.1 seconds require an input image.")
|
| 116 |
-
|
| 117 |
-
return errors
|
| 118 |
-
|
| 119 |
-
def get_duration(image,
|
| 120 |
-
prompt,
|
| 121 |
-
size,
|
| 122 |
-
duration_seconds,
|
| 123 |
-
sampling_steps,
|
| 124 |
-
guide_scale,
|
| 125 |
-
shift,
|
| 126 |
-
seed,
|
| 127 |
-
progress):
|
| 128 |
-
"""Calculate dynamic GPU duration based on parameters."""
|
| 129 |
-
if sampling_steps > 35 and duration_seconds >= 2:
|
| 130 |
-
return 120
|
| 131 |
-
elif sampling_steps < 35 or duration_seconds < 2:
|
| 132 |
-
return 105
|
| 133 |
-
else:
|
| 134 |
-
return 90
|
| 135 |
-
|
| 136 |
-
def apply_template(template, current_prompt):
|
| 137 |
-
"""Apply prompt template"""
|
| 138 |
-
if "{subject}" in template:
|
| 139 |
-
# Extract the main subject from current prompt (simple heuristic)
|
| 140 |
-
subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
|
| 141 |
-
return template.replace("{subject}", subject)
|
| 142 |
-
return template + " " + current_prompt
|
| 143 |
-
|
| 144 |
-
# --- 2. Gradio Inference Function ---
|
| 145 |
-
@spaces.GPU(duration=get_duration)
|
| 146 |
-
def generate_video(
|
| 147 |
-
image,
|
| 148 |
-
prompt,
|
| 149 |
-
size,
|
| 150 |
-
duration_seconds,
|
| 151 |
-
sampling_steps,
|
| 152 |
-
guide_scale,
|
| 153 |
-
shift,
|
| 154 |
-
seed,
|
| 155 |
-
progress=gr.Progress(track_tqdm=True)
|
| 156 |
-
):
|
| 157 |
-
torch.backends.cudnn.enabled = False
|
| 158 |
-
"""The main function to generate video, called by the Gradio interface."""
|
| 159 |
-
# Validate inputs
|
| 160 |
-
errors = validate_inputs(image, prompt, duration_seconds)
|
| 161 |
-
if errors:
|
| 162 |
-
raise gr.Error("\n".join(errors))
|
| 163 |
-
|
| 164 |
-
progress(0, desc="Setting up...")
|
| 165 |
-
|
| 166 |
-
if seed == -1:
|
| 167 |
-
seed = random.randint(0, sys.maxsize)
|
| 168 |
-
|
| 169 |
-
progress(0.1, desc="Processing image...")
|
| 170 |
-
|
| 171 |
-
input_image = None
|
| 172 |
-
if image is not None:
|
| 173 |
-
input_image = Image.fromarray(image).convert("RGB")
|
| 174 |
-
# Resize image to match selected size
|
| 175 |
-
target_height, target_width = map(int, size.split('*'))
|
| 176 |
-
input_image = input_image.resize((target_width, target_height))
|
| 177 |
-
|
| 178 |
-
# Calculate number of frames based on duration
|
| 179 |
-
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
|
| 180 |
-
|
| 181 |
-
progress(0.2, desc="Generating video...")
|
| 182 |
-
|
| 183 |
-
try:
|
| 184 |
-
video_tensor = pipeline.generate(
|
| 185 |
-
input_prompt=prompt,
|
| 186 |
-
img=input_image, # Pass None for T2V, Image for I2V
|
| 187 |
-
size=SIZE_CONFIGS[size],
|
| 188 |
-
max_area=MAX_AREA_CONFIGS[size],
|
| 189 |
-
frame_num=num_frames, # Use calculated frames instead of cfg.frame_num
|
| 190 |
-
shift=shift,
|
| 191 |
-
sample_solver='unipc',
|
| 192 |
-
sampling_steps=int(sampling_steps),
|
| 193 |
-
guide_scale=guide_scale,
|
| 194 |
-
seed=seed,
|
| 195 |
-
offload_model=True
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
progress(0.9, desc="Saving video...")
|
| 199 |
-
|
| 200 |
-
# Save the video to a temporary file
|
| 201 |
-
video_path = cache_video(
|
| 202 |
-
tensor=video_tensor[None], # Add a batch dimension
|
| 203 |
-
save_file=None, # cache_video will create a temp file
|
| 204 |
-
fps=cfg.sample_fps,
|
| 205 |
-
normalize=True,
|
| 206 |
-
value_range=(-1, 1)
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
progress(1.0, desc="Complete!")
|
| 210 |
-
|
| 211 |
-
except torch.cuda.OutOfMemoryError:
|
| 212 |
-
clear_gpu_memory()
|
| 213 |
-
raise gr.Error("GPU out of memory. Please try with lower settings.")
|
| 214 |
-
except Exception as e:
|
| 215 |
-
raise gr.Error(f"Video generation failed: {str(e)}")
|
| 216 |
-
finally:
|
| 217 |
-
if 'video_tensor' in locals():
|
| 218 |
-
del video_tensor
|
| 219 |
-
clear_gpu_memory()
|
| 220 |
-
|
| 221 |
-
return video_path
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
# --- 3. Gradio Interface ---
|
| 225 |
-
css = """
|
| 226 |
-
.gradio-container {max-width: 1100px !important; margin: 0 auto}
|
| 227 |
-
#output_video {height: 500px;}
|
| 228 |
-
#input_image {height: 500px;}
|
| 229 |
-
.template-btn {margin: 2px !important;}
|
| 230 |
-
"""
|
| 231 |
-
|
| 232 |
-
# Default prompt with motion emphasis
|
| 233 |
-
DEFAULT_PROMPT = "Two friends are paddling a kayak across a tranquil alpine lake, its surface as still as a mirror. Snow-capped peaks and dense forests are reflected in the crystal-clear water, and gentle ripples spread outward as the kayak moves. The camera slowly pans to the right, capturing the serene beauty of the scene. Distant mountains and trees are clearly visible in the background, adding a sense of natural harmony and peace."
|
| 234 |
-
|
| 235 |
-
# Prompt templates
|
| 236 |
-
templates = {
|
| 237 |
-
"Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
|
| 238 |
-
"Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
|
| 239 |
-
"Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
|
| 240 |
-
"Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
|
| 241 |
-
"Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
|
| 245 |
-
gr.Markdown("""
|
| 246 |
-
# Wan 2.2 TI2V Enhanced running on AMD MI355
|
| 247 |
-
|
| 248 |
-
Generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**
|
| 249 |
-
[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)
|
| 250 |
-
|
| 251 |
-
### 💡 Tips for best results:
|
| 252 |
-
- 🖼️ Upload an image for better control over the video content
|
| 253 |
-
- ⏱️ Longer videos require more processing time
|
| 254 |
-
- 🎯 Be specific and descriptive in your prompts
|
| 255 |
-
- 🎬 Include motion-related keywords for dynamic videos
|
| 256 |
-
""")
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
with gr.Row():
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
)
|
| 293 |
-
size_input = gr.Dropdown(
|
| 294 |
-
label="Output Resolution",
|
| 295 |
-
choices=list(SUPPORTED_SIZES[TASK_NAME]),
|
| 296 |
-
value="704*1280"
|
| 297 |
)
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
# Add image upload handler
|
| 346 |
-
image_input.upload(
|
| 347 |
-
fn=handle_image_upload,
|
| 348 |
-
inputs=[image_input],
|
| 349 |
-
outputs=[size_input]
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
-
image_input.clear(
|
| 353 |
-
fn=handle_image_upload,
|
| 354 |
-
inputs=[image_input],
|
| 355 |
-
outputs=[size_input]
|
| 356 |
-
)
|
| 357 |
-
|
| 358 |
-
# Update status when generating
|
| 359 |
-
def update_status_and_generate(*args):
|
| 360 |
-
status_text.value = "Generating..."
|
| 361 |
-
try:
|
| 362 |
-
result = generate_video(*args)
|
| 363 |
-
status_text.value = "Complete!"
|
| 364 |
-
return result
|
| 365 |
-
except Exception as e:
|
| 366 |
-
status_text.value = "Error occurred"
|
| 367 |
-
raise e
|
| 368 |
-
|
| 369 |
-
example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
|
| 370 |
-
gr.Examples(
|
| 371 |
-
examples=[
|
| 372 |
-
[None, "Golden hour, soft lighting, warm colors, saturated colors, wide shot, left-heavy composition. A weathered gondolier stands in a flat-bottomed boat, propelling it forward with a long wooden pole through the flooded ruins of Venice. The decaying buildings on either side are cloaked in creeping vines and marked by rusted metalwork, their once-proud facades now crumbling into the water. The camera moves slowly forward and tilts left, revealing behind him the majestic remnants of the city bathed in the amber glow of the setting sun. Silhouettes of collapsed archways and broken domes rise against the golden skyline, while the still water reflects the warm hues of the sky and surrounding structures.", "1280*704", 4.0],
|
| 373 |
-
[None, "In a surreal video, four miniature skiers glide down a winding, three-dimensional trail of thick white paint on a plain white canvas-like background. The textured paint mimics snow, with visible brushstrokes and uneven edges, enhanced by light and shadow. The skiers, in colorful gear, are posed dynamically from top to bottom, each casting a shadow that heightens the illusion of depth. This scene miniaturizes a grand outdoor sport into a vivid, imaginative artwork.", "1280*704", 2.0],
|
| 374 |
-
[None, "In a time-lapse video, a crane slowly lifts a steel beam on a construction site. The camera pulls back slowly from a close-up, revealing details of the crane and the steel beam. The skyline transitions from day to night, with buildings and machinery in the background constantly operating. As the camera pulls further back, the busy scene of the entire construction site comes into view; cranes and other equipment continue working under the night sky, shaping the city's outline.", "704*1280", 2.5],
|
| 375 |
-
[None, "Cinematic racetrack scene: Low-angle medium long shot of jockey-horse leap. High-contrast backlighting, warm tones, silhouettes. Slow-motion freeze with dust for dynamic tension. Scoreboard detail. Optimized for immersive video generation.", "1280*704", 3.0],
|
| 376 |
-
],
|
| 377 |
-
inputs=[image_input, prompt_input, size_input, duration_input],
|
| 378 |
-
outputs=video_output,
|
| 379 |
-
fn=generate_video,
|
| 380 |
-
cache_examples=False,
|
| 381 |
-
)
|
| 382 |
-
|
| 383 |
-
run_button.click(
|
| 384 |
-
fn=generate_video,
|
| 385 |
-
inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
|
| 386 |
-
outputs=video_output
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
if __name__ == "__main__":
|
| 390 |
-
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional, Dict, Any
|
| 3 |
+
import json, uuid, time, os
|
| 4 |
+
import requests
|
| 5 |
+
import websocket
|
| 6 |
+
from urllib.parse import urlencode
|
|
|
|
|
|
|
| 7 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
COMFY_HOST = os.getenv("COMFY_HOST", "134.199.132.159")
|
| 10 |
+
|
| 11 |
+
with open("workflow.json", "r", encoding="utf-8") as f:
|
| 12 |
+
WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(f)
|
| 13 |
+
|
| 14 |
+
class T2VReq(BaseModel):
|
| 15 |
+
token: str = Field(...)
|
| 16 |
+
text: str = Field(...)
|
| 17 |
+
negative: Optional[str] = None
|
| 18 |
+
seed: Optional[int] = None
|
| 19 |
+
steps: Optional[int] = 20
|
| 20 |
+
cfg: Optional[float] = 5.0
|
| 21 |
+
width: Optional[int] = 1280
|
| 22 |
+
height: Optional[int] = 704
|
| 23 |
+
length: Optional[int] = 121
|
| 24 |
+
fps: Optional[int] = 24
|
| 25 |
+
filename_prefix: Optional[str] = "video/ComfyUI"
|
| 26 |
+
|
| 27 |
+
def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]:
|
| 28 |
+
p = json.loads(json.dumps(prompt))
|
| 29 |
+
p["6"]["inputs"]["text"] = r.text
|
| 30 |
+
if r.seed is None:
|
| 31 |
+
r.seed = int.from_bytes(os.urandom(8), "big") & ((1 << 63) - 1)
|
| 32 |
+
p["3"]["inputs"]["seed"] = r.seed
|
| 33 |
+
if r.steps is not None: p["3"]["inputs"]["steps"] = r.steps
|
| 34 |
+
if r.cfg is not None: p["3"]["inputs"]["cfg"] = r.cfg
|
| 35 |
+
if r.width is not None: p["55"]["inputs"]["width"] = r.width
|
| 36 |
+
if r.height is not None: p["55"]["inputs"]["height"] = r.height
|
| 37 |
+
if r.length is not None: p["55"]["inputs"]["length"] = r.length
|
| 38 |
+
if r.fps is not None: p["57"]["inputs"]["fps"] = r.fps
|
| 39 |
+
if r.filename_prefix:
|
| 40 |
+
p["58"]["inputs"]["filename_prefix"] = r.filename_prefix
|
| 41 |
+
return p
|
| 42 |
+
|
| 43 |
+
def _open_ws(client_id: str, token: str):
|
| 44 |
+
ws = websocket.WebSocket()
|
| 45 |
+
ws.connect(f"ws://{COMFY_HOST}/ws?clientId={client_id}&token={token}", timeout=1800)
|
| 46 |
+
return ws
|
| 47 |
+
|
| 48 |
+
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
|
| 49 |
+
payload = {"prompt": prompt, "client_id": client_id}
|
| 50 |
+
resp = requests.post(f"http://{COMFY_HOST}/prompt?token={token}", json=payload, timeout=1800)
|
| 51 |
+
if resp.status_code != 200:
|
| 52 |
+
raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
|
| 53 |
+
data = resp.json()
|
| 54 |
+
if "prompt_id" not in data:
|
| 55 |
+
raise RuntimeError(f"/prompt no prompt_id: {data}")
|
| 56 |
+
return data["prompt_id"]
|
| 57 |
+
|
| 58 |
+
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
|
| 59 |
+
r = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}?token={token}", timeout=1800)
|
| 60 |
+
r.raise_for_status()
|
| 61 |
+
hist = r.json()
|
| 62 |
+
return hist.get(prompt_id, {})
|
| 63 |
+
|
| 64 |
+
def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]:
|
| 65 |
+
outputs = history.get("outputs", {})
|
| 66 |
+
for _, node_out in outputs.items():
|
| 67 |
+
if "images" in node_out:
|
| 68 |
+
for it in node_out["images"]:
|
| 69 |
+
if all(k in it for k in ("filename", "subfolder", "type")):
|
| 70 |
+
fn = it["filename"]
|
| 71 |
+
if fn.lower().endswith((".mp4", ".webm", ".gif", ".mov", ".mkv")):
|
| 72 |
+
return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
|
| 73 |
+
for key in ("videos", "files"):
|
| 74 |
+
if key in node_out and node_out[key]:
|
| 75 |
+
it = node_out[key][0]
|
| 76 |
+
if all(k in it for k in ("filename", "subfolder", "type")):
|
| 77 |
+
return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
|
| 78 |
+
raise RuntimeError("No video file found in history outputs")
|
| 79 |
+
|
| 80 |
+
with gr.Blocks(title="Wan 2.2 T2V UI running on AMD MI300x") as demo:
|
| 81 |
+
st_token = gr.State()
|
| 82 |
with gr.Row():
|
| 83 |
+
text = gr.Textbox(label="Prompt", placeholder="Text to generate", lines=3)
|
| 84 |
+
with gr.Row():
|
| 85 |
+
width = gr.Number(label="Width", value=1280, precision=0)
|
| 86 |
+
height = gr.Number(label="Height", value=704, precision=0)
|
| 87 |
+
length = gr.Number(label="FPS", value=121, precision=0)
|
| 88 |
+
fps = gr.Number(label="FPS", value=24, precision=0)
|
| 89 |
+
with gr.Row():
|
| 90 |
+
steps = gr.Number(label="Steps", value=20, precision=0)
|
| 91 |
+
cfg = gr.Number(label="CFG", value=5.0)
|
| 92 |
+
seed = gr.Number(label="Seed", value=None)
|
| 93 |
+
filename_prefix = gr.Textbox(label="Prefix of video", value="video/ComfyUI")
|
| 94 |
+
run_btn = gr.Button("Generate")
|
| 95 |
+
prog_bar = gr.Slider(label="Step", minimum=0, maximum=100, value=0, step=1, interactive=False)
|
| 96 |
+
out_video = gr.Video(label="Result")
|
| 97 |
+
|
| 98 |
+
def _init_token():
|
| 99 |
+
return str(uuid.uuid4())
|
| 100 |
+
|
| 101 |
+
demo.load(_init_token, outputs=st_token)
|
| 102 |
+
|
| 103 |
+
def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
|
| 104 |
+
def _runner():
|
| 105 |
+
req = T2VReq(
|
| 106 |
+
token=token,
|
| 107 |
+
text=text,
|
| 108 |
+
seed=int(seed) if seed is not None else None,
|
| 109 |
+
steps=int(steps) if steps is not None else None,
|
| 110 |
+
cfg=float(cfg) if cfg is not None else None,
|
| 111 |
+
width=int(width) if width is not None else None,
|
| 112 |
+
height=int(height) if height is not None else None,
|
| 113 |
+
length=int(length) if length is not None else None,
|
| 114 |
+
fps=int(fps) if fps is not None else None,
|
| 115 |
+
filename_prefix=filename_prefix if filename_prefix else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
+
prompt = _inject_params(WORKFLOW_TEMPLATE, req)
|
| 118 |
+
client_id = str(uuid.uuid4())
|
| 119 |
+
ws = _open_ws(client_id, req.token)
|
| 120 |
+
prompt_id = _queue_prompt(prompt, client_id, req.token)
|
| 121 |
+
total_nodes = max(1, len(prompt))
|
| 122 |
+
seen = set()
|
| 123 |
+
p = 0
|
| 124 |
+
last_emit = -1
|
| 125 |
+
start = time.time()
|
| 126 |
+
ws.settimeout(60)
|
| 127 |
+
while True:
|
| 128 |
+
out = ws.recv()
|
| 129 |
+
if isinstance(out, (bytes, bytearray)):
|
| 130 |
+
if p < 95 and time.time() - start > 2:
|
| 131 |
+
p = min(95, p + 1)
|
| 132 |
+
if p != last_emit:
|
| 133 |
+
last_emit = p
|
| 134 |
+
yield p, None
|
| 135 |
+
continue
|
| 136 |
+
msg = json.loads(out)
|
| 137 |
+
if msg.get("type") == "executing":
|
| 138 |
+
data = msg.get("data", {})
|
| 139 |
+
if data.get("prompt_id") != prompt_id:
|
| 140 |
+
continue
|
| 141 |
+
node = data.get("node")
|
| 142 |
+
if node is None:
|
| 143 |
+
break
|
| 144 |
+
if node not in seen:
|
| 145 |
+
seen.add(node)
|
| 146 |
+
p = min(99, int(len(seen) / total_nodes * 100))
|
| 147 |
+
if p != last_emit:
|
| 148 |
+
last_emit = p
|
| 149 |
+
yield p, None
|
| 150 |
+
ws.close()
|
| 151 |
+
hist = _get_history(prompt_id, req.token)
|
| 152 |
+
info = _extract_video_from_history(hist)
|
| 153 |
+
q = urlencode(info)
|
| 154 |
+
video_url = f"http://{COMFY_HOST}/view?{q}"
|
| 155 |
+
yield 100, video_url
|
| 156 |
+
return _runner()
|
| 157 |
+
|
| 158 |
+
run_btn.click(
|
| 159 |
+
generate_fn,
|
| 160 |
+
inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
|
| 161 |
+
outputs=[prog_bar, out_video]
|
| 162 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|