Spaces:
Running
Running
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| #import subprocess | |
| #subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) | |
| # wan2.2-main/gradio_ti2v.py | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| import random | |
| import numpy as np | |
| import spaces | |
| import wan | |
| from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES | |
| from wan.utils.utils import cache_video | |
| import gc | |
# --- 1. Global Setup and Model Loading ---
print("Starting Gradio App for Wan 2.2 TI2V-5B...")

# Download the model snapshot from the Hugging Face Hub (or reuse the local cache).
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
# `local_dir_use_symlinks` only ever applied together with `local_dir` (not used
# here) and is deprecated/ignored in current huggingface_hub releases, so it is
# omitted: passing it had no effect beyond a deprecation warning.
ckpt_dir = snapshot_download(repo_id)
print(f"Using checkpoints from {ckpt_dir}")

# Load the model configuration for the text/image-to-video 5B task.
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]

# Frame rate used to turn the requested duration (seconds) into a frame count.
# NOTE(review): generate_video saves clips at cfg.sample_fps, not FIXED_FPS; if
# the two differ, the output length will not match the requested duration — confirm.
FIXED_FPS = 12
# Bounds applied to the computed frame count.
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121

# Instantiate the pipeline once at import time so every request reuses it.
print("Initializing WanTI2V pipeline...")
# NOTE(review): `device` is not referenced elsewhere in this file.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): confirm WanTI2V accepts device_id=-1 on hosts without a GPU.
device_id = 0 if torch.cuda.is_available() else -1
pipeline = wan.WanTI2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=device_id,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_sp=False,
    t5_cpu=False,
    init_on_cpu=False,
    convert_model_dtype=True,
)
print("Pipeline initialized and ready.")
# --- Helper Functions ---
def clear_gpu_memory():
    """Release as much cached GPU memory as possible after a generation attempt.

    Python garbage collection runs first, so tensors that are only kept alive by
    reference cycles are actually freed. Only after that can
    ``torch.cuda.empty_cache()`` return their blocks from the caching allocator.
    The CUDA calls are skipped when no GPU is available, so this is safe to call
    on CPU-only hosts.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def select_best_size_for_image(image, available_sizes):
    """Pick the option from *available_sizes* whose aspect ratio best fits *image*.

    Each option is a string such as ``"704*1280"``, which this function reads as
    ``height*width``. *image* only needs a PIL-style ``size`` attribute giving
    ``(width, height)``. If *image* is ``None``, the first option is returned.
    When several options fit equally well, the one listed first wins.
    """
    if image is None:
        return available_sizes[0]

    width_px, height_px = image.size
    wanted_ratio = height_px / width_px

    def _ratio_error(option):
        opt_height, opt_width = (int(part) for part in option.split('*'))
        return abs(wanted_ratio - opt_height / opt_width)

    # sorted() is stable, so among equal errors the earliest option stays first.
    return sorted(available_sizes, key=_ratio_error)[0]
def handle_image_upload(image):
    """Gradio callback that suggests an output resolution for the current image.

    Returns ``gr.update()``, which leaves the dropdown unchanged, when *image* is
    ``None`` (this callback is also wired to the image's clear event). Otherwise
    returns ``gr.update(value=...)`` with the supported size whose aspect ratio
    is closest to the image.
    """
    if image is not None:
        rgb_image = Image.fromarray(image).convert("RGB")
        candidates = list(SUPPORTED_SIZES[TASK_NAME])
        return gr.update(value=select_best_size_for_image(rgb_image, candidates))
    return gr.update()
def validate_inputs(image, prompt, duration_seconds):
    """Collect human-readable validation errors for a generation request.

    Args:
        image: Optional input image as an array of shape (H, W[, C]), as produced
            by ``gr.Image(type="numpy")``, or ``None`` for text-to-video.
        prompt: Text prompt; must be at least 5 characters after stripping.
        duration_seconds: Requested clip length in seconds.

    Returns:
        A list of error messages; empty when the inputs are acceptable.
    """
    errors = []
    if not prompt or len(prompt.strip()) < 5:
        errors.append("Prompt must be at least 5 characters long.")
    if image is not None:
        # Read the dimensions straight from the array shape instead of building a
        # PIL copy of a possibly huge image just to measure it.
        height, width = np.shape(image)[:2]
        if width * height > 4096 * 4096:
            errors.append("Image size is too large (maximum 4096x4096).")
    # NOTE(review): the UI duration slider tops out at round(121 / 12, 1) = 10.1 s,
    # so this branch is only reachable through direct calls.
    if duration_seconds > 10.1 and image is None:
        errors.append("Videos longer than 10.1 seconds require an input image.")
    return errors
def get_duration(image,
                 prompt,
                 size,
                 duration_seconds,
                 sampling_steps,
                 guide_scale,
                 shift,
                 seed,
                 progress):
    """Return a GPU time budget, in seconds, for one generation request.

    The parameter list mirrors ``generate_video``. It is presumably meant as a
    dynamic-duration callback (e.g. for ``spaces.GPU``), but it is not referenced
    in this file. Only *sampling_steps* and *duration_seconds* influence the
    result.

    NOTE(review): the 90 s fallback is reached only when sampling_steps == 35 and
    the clip is at least 2 s long. Such a request gets less time than lighter
    ones, which looks unintended — confirm.
    """
    if sampling_steps > 35 and duration_seconds >= 2:
        return 120
    light_request = sampling_steps < 35 or duration_seconds < 2
    return 105 if light_request else 90
def apply_template(template, current_prompt):
    """Combine a prompt template with the user's current prompt.

    If *template* contains ``{subject}``, the placeholder is replaced by the text
    before the first comma of *current_prompt*, a simple heuristic for the main
    subject, with surrounding whitespace trimmed. Otherwise the trimmed prompt is
    appended after a single space. ``None`` is treated as an empty prompt, and an
    empty prompt adds no dangling separator.
    """
    current_prompt = (current_prompt or "").strip()
    if "{subject}" in template:
        # str.split always yields at least one item, so no comma check is needed.
        subject = current_prompt.split(",", 1)[0].strip()
        return template.replace("{subject}", subject)
    if not current_prompt:
        return template
    return f"{template} {current_prompt}"
# --- 2. Gradio Inference Function ---
def generate_video(
    image,
    prompt,
    size,
    duration_seconds,
    sampling_steps,
    guide_scale,
    shift,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a video for the Gradio UI and return the path of the saved file.

    Args:
        image: Optional input image as a numpy array; ``None`` means text-to-video.
        prompt: Text prompt describing the clip.
        size: Resolution key such as ``"704*1280"``, used to index SIZE_CONFIGS
            and MAX_AREA_CONFIGS. The input image is resized reading it as
            ``height*width``.
        duration_seconds: Requested length. It is converted to a frame count at
            FIXED_FPS and clamped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL].
        sampling_steps, guide_scale, shift: Sampler settings passed to the pipeline.
        seed: RNG seed; ``-1`` picks a random one.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        The path returned by ``cache_video``.

    Raises:
        gr.Error: on invalid inputs, GPU out-of-memory, or any other generation
            or saving failure. The original exception is chained as the cause.
    """
    errors = validate_inputs(image, prompt, duration_seconds)
    if errors:
        raise gr.Error("\n".join(errors))

    progress(0, desc="Setting up...")
    if seed == -1:
        seed = random.randint(0, sys.maxsize)

    progress(0.1, desc="Processing image...")
    input_image = None
    if image is not None:
        input_image = Image.fromarray(image).convert("RGB")
        # Resize the image to the selected resolution. The size string is read as
        # "height*width" here.
        # NOTE(review): confirm this matches how SIZE_CONFIGS interprets the key.
        target_height, target_width = map(int, size.split('*'))
        input_image = input_image.resize((target_width, target_height))

    # Convert the requested duration into a clamped frame count (plain int).
    num_frames = int(np.clip(round(duration_seconds * FIXED_FPS), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

    progress(0.2, desc="Generating video...")
    video_tensor = None
    try:
        video_tensor = pipeline.generate(
            input_prompt=prompt,
            img=input_image,  # None for T2V, a PIL image for I2V
            size=SIZE_CONFIGS[size],
            max_area=MAX_AREA_CONFIGS[size],
            frame_num=num_frames,  # derived from the requested duration, not cfg.frame_num
            shift=shift,
            sample_solver='unipc',
            sampling_steps=int(sampling_steps),
            guide_scale=guide_scale,
            seed=seed,
            offload_model=True
        )
        progress(0.9, desc="Saving video...")
        # NOTE(review): frames are counted at FIXED_FPS but saved at cfg.sample_fps;
        # if the two differ, the clip length will not match duration_seconds.
        video_path = cache_video(
            tensor=video_tensor[None],  # add a batch dimension
            save_file=None,  # let cache_video choose a temporary file
            fps=cfg.sample_fps,
            normalize=True,
            value_range=(-1, 1)
        )
        progress(1.0, desc="Complete!")
    except torch.cuda.OutOfMemoryError as e:
        # GPU memory is released once, in the `finally` clause below.
        raise gr.Error("GPU out of memory. Please try with lower settings.") from e
    except Exception as e:
        raise gr.Error(f"Video generation failed: {e}") from e
    finally:
        # Drop our reference before emptying the cache so the tensor's memory
        # can actually be released.
        del video_tensor
        clear_gpu_memory()
    return video_path
# --- 3. Gradio Interface ---
# Page-level CSS: cap the layout width, fix the image/video panel heights and
# tighten spacing between the prompt-template buttons.
css = (
    "\n"
    ".gradio-container {max-width: 1100px !important; margin: 0 auto}\n"
    "#output_video {height: 500px;}\n"
    "#input_image {height: 500px;}\n"
    ".template-btn {margin: 2px !important;}\n"
)

# Initial prompt shown in the textbox; it deliberately describes camera/scene motion.
DEFAULT_PROMPT = (
    "Two friends are paddling a kayak across a tranquil alpine lake, its surface as still as a mirror. "
    "Snow-capped peaks and dense forests are reflected in the crystal-clear water, and gentle ripples spread outward as the kayak moves. "
    "The camera slowly pans to the right, capturing the serene beauty of the scene. "
    "Distant mountains and trees are clearly visible in the background, adding a sense of natural harmony and peace."
)

# Prompt templates keyed by button label; "{subject}" is filled in by apply_template.
templates = dict([
    ("Cinematic", "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality"),
    ("Animation", "animated style {subject}, vibrant colors, fluid motion, dynamic movement"),
    ("Nature", "nature documentary footage of {subject}, wildlife photography, natural movement"),
    ("Slow Motion", "slow motion capture of {subject}, high speed camera, detailed motion"),
    ("Action", "dynamic action shot of {subject}, fast paced movement, energetic motion"),
])
with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
    gr.Markdown("""
    # Wan 2.2 TI2V Enhanced running on AMD MI355
    Generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**
    [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)
    ### 💡 Tips for best results:
    - 🖼️ Upload an image for better control over the video content
    - ⏱️ Longer videos require more processing time
    - 🎯 Be specific and descriptive in your prompts
    - 🎬 Include motion-related keywords for dynamic videos
    """)
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=2):
            image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the video you want to generate..."
            )
            # Prompt templates section
            with gr.Accordion("Prompt Templates", open=False):
                gr.Markdown("Click a template to apply it to your prompt:")
                with gr.Row():
                    template_buttons = {}
                    for name, template in templates.items():
                        btn = gr.Button(name, size="sm", elem_classes=["template-btn"])
                        template_buttons[name] = (btn, template)
            # Connect template buttons. Gradio passes the current prompt text as the
            # first positional argument. Each button's template is bound through a
            # default argument so it is captured per iteration, not late-bound.
            for btn, template in template_buttons.values():
                btn.click(
                    fn=lambda current_prompt, t=template: apply_template(t, current_prompt),
                    inputs=[prompt_input],
                    outputs=prompt_input
                )
            duration_input = gr.Slider(
                minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
                step=0.1,
                value=2.0,
                label="Duration (seconds)",
                info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
            )
            size_input = gr.Dropdown(
                label="Output Resolution",
                choices=list(SUPPORTED_SIZES[TASK_NAME]),
                value="704*1280"
            )
        # Right column: output, status and sampler settings.
        with gr.Column(scale=2):
            video_output = gr.Video(label="Generated Video", elem_id="output_video")
            # Status indicators
            with gr.Row():
                status_text = gr.Textbox(
                    label="Status",
                    value="Ready",
                    interactive=False,
                    max_lines=1
                )
            with gr.Accordion("Advanced Settings", open=False):
                steps_input = gr.Slider(
                    label="Sampling Steps",
                    minimum=10,
                    maximum=50,
                    value=38,
                    step=1,
                    info="Higher values = better quality but slower"
                )
                scale_input = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=cfg.sample_guide_scale,
                    step=0.1,
                    info="Higher values = closer to prompt but less creative"
                )
                shift_input = gr.Slider(
                    label="Sample Shift",
                    minimum=1.0,
                    maximum=20.0,
                    value=cfg.sample_shift,
                    step=0.1,
                    info="Affects the sampling process dynamics"
                )
                seed_input = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0,
                    info="Use same seed for reproducible results"
                )
            run_button = gr.Button("Generate Video", variant="primary", size="lg")

    # Suggest an output resolution when an image is uploaded. Clearing the image
    # calls the same handler, which leaves the dropdown unchanged for None.
    image_input.upload(
        fn=handle_image_upload,
        inputs=[image_input],
        outputs=[size_input]
    )
    image_input.clear(
        fn=handle_image_upload,
        inputs=[image_input],
        outputs=[size_input]
    )

    # With cache_examples=False, clicking an example only fills in the inputs.
    gr.Examples(
        examples=[
            [None, "Golden hour, soft lighting, warm colors, saturated colors, wide shot, left-heavy composition. A weathered gondolier stands in a flat-bottomed boat, propelling it forward with a long wooden pole through the flooded ruins of Venice. The decaying buildings on either side are cloaked in creeping vines and marked by rusted metalwork, their once-proud facades now crumbling into the water. The camera moves slowly forward and tilts left, revealing behind him the majestic remnants of the city bathed in the amber glow of the setting sun. Silhouettes of collapsed archways and broken domes rise against the golden skyline, while the still water reflects the warm hues of the sky and surrounding structures.", "1280*704", 4.0],
            [None, "In a surreal video, four miniature skiers glide down a winding, three-dimensional trail of thick white paint on a plain white canvas-like background. The textured paint mimics snow, with visible brushstrokes and uneven edges, enhanced by light and shadow. The skiers, in colorful gear, are posed dynamically from top to bottom, each casting a shadow that heightens the illusion of depth. This scene miniaturizes a grand outdoor sport into a vivid, imaginative artwork.", "1280*704", 2.0],
            [None, "In a time-lapse video, a crane slowly lifts a steel beam on a construction site. The camera pulls back slowly from a close-up, revealing details of the crane and the steel beam. The skyline transitions from day to night, with buildings and machinery in the background constantly operating. As the camera pulls further back, the busy scene of the entire construction site comes into view; cranes and other equipment continue working under the night sky, shaping the city's outline.", "704*1280", 2.5],
            [None, "Cinematic racetrack scene: Low-angle medium long shot of jockey-horse leap. High-contrast backlighting, warm tones, silhouettes. Slow-motion freeze with dust for dynamic tension. Scoreboard detail. Optimized for immersive video generation.", "1280*704", 3.0],
        ],
        inputs=[image_input, prompt_input, size_input, duration_input],
        outputs=video_output,
        fn=generate_video,
        cache_examples=False,
    )

    # Keep the status box in sync with generation. Mutating `status_text.value`
    # inside a handler never reaches the browser, so the status is routed through
    # event outputs instead:
    #   1. show "Generating..." and clear the previous video;
    #   2. run the generation (generate_video is wired directly so Gradio still
    #      injects its progress tracker);
    #   3. report the outcome. `.then` runs even if step 2 raised, and in that
    #      case the video output is still empty.
    run_button.click(
        fn=lambda: (None, "Generating..."),
        outputs=[video_output, status_text],
    ).then(
        fn=generate_video,
        inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
        outputs=video_output,
    ).then(
        fn=lambda video: "Complete!" if video else "Error occurred",
        inputs=video_output,
        outputs=status_text,
    )
if __name__ == "__main__":
    # Serve on every network interface, port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)