# Spaces: Running on Zero
| import gradio as gr | |
| import spaces | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageOps | |
| # 导入所需的调度器 | |
| # 将 DPMSolverMultistepScheduler 替换为 DPMSolverSinglestepScheduler (用于 DPM++ SDE Karras) | |
| from diffusers import StableDiffusionXLInpaintPipeline, AutoencoderKL, DPMSolverSinglestepScheduler | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# float16 saves GPU memory, which helps with somewhat larger images;
# the CPU path stays in float32.
dtype = torch.float32 if device != "cuda" else torch.float16
# ----------------------------
# Model configuration
# ----------------------------
# Each entry names the checkpoint, the replacement VAE and the weight
# variant to request. The variant is spelled out per model because not
# every repository publishes fp16 weights.
MODELS = {
    "RealVisXL_V5.0_Lightning": dict(
        id="SG161222/RealVisXL_V5.0_Lightning",
        vae="madebyollin/sdxl-vae-fp16-fix",
        variant="fp16",  # this model has an fp16 variant
    ),
    "Juggernaut-XL-Lightning": dict(
        id="RunDiffusion/Juggernaut-XL-Lightning",
        vae="madebyollin/sdxl-vae-fp16-fix",
        variant=None,  # this model has no fp16 variant
    ),
}

# Model pre-selected in the UI dropdown.
DEFAULT_MODEL_NAME = "RealVisXL_V5.0_Lightning"

# Longest side (in pixels) an uploaded image may keep before it is downscaled.
MAX_DIM = 1280
# ----------------------------
# Load all SDXL inpaint pipelines
# ----------------------------
# Registry of loaded pipelines, keyed by model name.
pipelines = dict()
def load_pipeline(model_id: str, vae_id: str, variant: "str | None", dtype, device):
    """Load one SDXL inpainting pipeline and prepare it for inference.

    Once the pipeline is loaded and moved to ``device``, two optional
    customisations are attempted. Both are best-effort: a failure is logged
    and the pipeline keeps whatever it already had.

    * The scheduler is replaced by a ``DPMSolverSinglestepScheduler`` built
      from the current scheduler config, with Karras sigmas and the
      ``sde-dpmsolver++`` algorithm type (DPM++ SDE Karras).
    * The VAE is replaced by the one loaded from ``vae_id``.

    Args:
        model_id: Repository id or path passed to ``from_pretrained``.
        vae_id: Repository id or path of the replacement VAE.
        variant: Weight variant to request (for example ``"fp16"``), or
            ``None`` when no variant should be requested.
        dtype: Torch dtype used for both the pipeline and the VAE.
        device: Device the pipeline and the VAE are moved to.

    Returns:
        The loaded ``StableDiffusionXLInpaintPipeline``.

    Raises:
        Exception: Whatever loading the pipeline or moving it to ``device``
            raised; the error is logged and re-raised unchanged.
    """
    print(f"[LOAD] 尝试加载 {model_id} (variant: {variant})...")
    # Only the base pipeline is mandatory, so only this step may abort.
    try:
        pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            variant=variant,  # taken from the model's config entry
        )
        pipe = pipe.to(device)
    except Exception as e:
        print(f"[ERROR] Pipeline {model_id} 加载失败:{e}")
        # Bare raise re-raises the active exception without rebinding it.
        raise
    print(f"[INFO] Pipeline {model_id} 加载成功")

    # Configure the sampler (best-effort).
    try:
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
            pipe.scheduler.config,
            use_karras_sigmas=True,  # Karras noise schedule
            algorithm_type="sde-dpmsolver++",  # DPM++ SDE mode
        )
        print("[INFO] 采样器已设置为 DPMSolverSinglestepScheduler (DPM++ SDE Karras)")
    except Exception as e:
        print(f"[WARN] 采样器配置失败: {e}")

    # Swap in the replacement VAE (best-effort; the bundled VAE is the fallback).
    try:
        print(f"[LOAD] 尝试为 {model_id} 加载 VAE {vae_id} ...")
        vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=dtype).to(device)
        pipe.vae = vae
        print("[INFO] VAE 替换成功")
    except Exception as e:
        print(f"[WARN] VAE 替换失败,使用原模型 VAE: {e}")

    return pipe
# Load every configured model at import time, passing each entry's variant
# through to the loader.
for name in MODELS:
    config = MODELS[name]
    print("-" * 30)
    print(f"正在初始化模型: {name}")
    pipelines[name] = load_pipeline(
        model_id=config["id"],
        vae_id=config["vae"],
        variant=config.get("variant"),  # .get() tolerates a missing key
        dtype=dtype,
        device=device,
    )
print("-" * 30)
print("所有模型已加载完毕!")
# ----------------------------
# Helper functions
# ----------------------------
def to_binary_mask(mask_img: Image.Image) -> Image.Image:
    """Turn a mask image into a pure black/white ``L``-mode image.

    For an ``RGBA`` input the alpha channel decides: any pixel with
    non-zero alpha becomes white (255). Every other mode is converted to
    grayscale first and pixels brighter than 128 become white. All
    remaining pixels are black (0).

    Returns ``None`` when ``mask_img`` is ``None``.
    """
    if mask_img is None:
        return None

    if mask_img.mode != "RGBA":
        channel, cutoff = mask_img.convert("L"), 128
    else:
        # Band "A" is the fourth band of an RGBA image.
        channel, cutoff = mask_img.getchannel("A"), 0

    thresholded = channel.point(lambda value: 255 if value > cutoff else 0)
    return thresholded.convert("L")
def mask_white_fraction(bin_mask: Image.Image) -> float:
    """Return the fraction of entries in ``bin_mask`` that are non-zero.

    The divisor is clamped to at least 1, so an image without pixels
    yields 0.0 instead of a division error.
    """
    pixels = np.array(bin_mask)
    lit_count = int(np.count_nonzero(pixels > 0))
    total = max(1, pixels.size)
    return lit_count / total
def sd_resize_image(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
    """Resize ``image`` so that its longer side is exactly 1024 pixels.

    The other side is scaled by the same ratio and rounded up to the next
    multiple of 8. When width and height are equal, the height is treated
    as the longer side. Resampling uses LANCZOS.

    ``target_w`` and ``target_h`` are accepted but not used.
    """
    long_side = 1024
    src_w, src_h = image.size

    def scaled_short_side(short, long):
        # Scale to the 1024-long frame, then round up to a multiple of 8.
        return int(np.ceil(short * long_side / long / 8) * 8)

    if src_w > src_h:
        size = (long_side, scaled_short_side(src_h, src_w))
    else:
        size = (scaled_short_side(src_w, src_h), long_side)

    return image.resize(size, Image.Resampling.LANCZOS)
# ----------------------------
# Upload preprocessing
# ----------------------------
def preprocess_image_input(image_input):
    """Downscale an uploaded image whose longer side exceeds ``MAX_DIM``.

    Args:
        image_input: Editor value dict; only its ``"background"`` image is
            read here.

    Returns:
        A fresh editor value (``background`` and ``composite`` set to the
        downscaled RGB image, ``layers`` and ``mask`` cleared) when the
        image was too large; otherwise ``gr.update()``, which leaves the
        component untouched.
    """
    # Nothing to do without a value or without a background image.
    if image_input is None or image_input.get("background") is None:
        return gr.update()

    original_source = image_input["background"].convert("RGB")
    w, h = original_source.size

    if w <= MAX_DIM and h <= MAX_DIM:
        print(f"[PREPROCESS] 原始分辨率 {w}x{h} 小于等于 {MAX_DIM},无需处理。")
        return gr.update()

    print(f"[PREPROCESS] 原始分辨率 {w}x{h} 超过 {MAX_DIM},正在按比例缩小...")
    # Pin the longer side to MAX_DIM and scale the other proportionally.
    # Clamp to 1 so an extreme aspect ratio cannot truncate a side to 0,
    # which the resize call would reject.
    if w > h:
        new_w = MAX_DIM
        new_h = max(1, int(h * MAX_DIM / w))
    else:
        new_h = MAX_DIM
        new_w = max(1, int(w * MAX_DIM / h))

    resized_source = original_source.resize((new_w, new_h), Image.Resampling.LANCZOS)
    print(f"[PREPROCESS] 缩小后的显示尺寸为 {resized_source.size}")
    return {
        "background": resized_source,
        "layers": None,
        "mask": None,
        "composite": resized_source,
    }
# ----------------------------
# Main logic: fill_image
# ----------------------------
# NOTE(review): the Space banner says "Running on Zero"; per the ZeroGPU
# documentation, GPU work must run inside a function decorated with
# spaces.GPU. Confirm against the deployed hardware.
@spaces.GPU
def fill_image(prompt, negative_prompt, image, model_selection, strength, guidance_scale, num_steps):
    """Inpaint the painted region of ``image`` and yield a before/after pair.

    This is a generator that yields exactly once: a tuple
    ``(original, result)`` of RGB images. When no mask layer is found, or
    less than 0.5% of the mask is painted, the original is yielded for both
    positions and no inference runs.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        image: Editor value dict with a ``"background"`` image and the mask
            in ``"layers"[0]`` (or ``"mask"`` as a fallback).
        model_selection: Key into the module-level ``pipelines`` registry.
        strength: Denoising strength passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_steps: Number of inference steps; converted to ``int``.

    Raises:
        gr.Error: If the selected model is not loaded, no image was
            provided, or ``strength`` and ``num_steps`` together leave the
            pipeline with fewer than one denoising step.
    """
    if model_selection not in pipelines:
        raise gr.Error(f"选择的模型 '{model_selection}' 未加载,请检查启动日志。")
    pipe = pipelines[model_selection]
    print(f"[INFO] 当前使用模型: {model_selection}")

    # Validate the editor value before indexing into it.
    if not isinstance(image, dict) or image.get("background") is None:
        raise gr.Error("请先上传图片。")
    source_original = image["background"].convert("RGB")
    original_size = source_original.size

    # Prefer the first editor layer as the mask; fall back to a "mask" entry.
    mask_layer = None
    layers = image.get("layers", None)
    if layers and len(layers) > 0:
        mask_layer = layers[0]
    elif image.get("mask", None) is not None:
        mask_layer = image["mask"]
    if mask_layer is None:
        print("[WARN] 没有检测到 mask layer,返回原图")
        yield source_original, source_original
        return

    bin_mask_original = to_binary_mask(mask_layer)
    white_frac = mask_white_fraction(bin_mask_original)
    # A practically empty mask means there is nothing to inpaint.
    if white_frac < 0.005:
        yield source_original, source_original
        return

    # The pipeline keeps int(steps * strength) denoising steps and refuses
    # to run with fewer than one; report that case clearly up front.
    num_steps = int(num_steps)
    if int(num_steps * strength) < 1:
        raise gr.Error(
            f"Strength ({strength}) 与步数 ({num_steps}) 的组合不足 1 个去噪步,请提高 Strength 或步数。"
        )

    source_sd = sd_resize_image(source_original, original_size[0], original_size[1])
    sd_size = source_sd.size
    # NEAREST keeps the mask strictly binary after resizing.
    bin_mask_sd = bin_mask_original.resize(sd_size, Image.Resampling.NEAREST)
    print(f"[INFO] Inpaint 图像尺寸已调整为 SDXL 兼容尺寸: {sd_size}")

    sd_width, sd_height = sd_size
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=source_sd,
        mask_image=bin_mask_sd,
        # Pass the working size explicitly so the pipeline does not fall
        # back to its own default size and discard the aspect ratio.
        width=sd_width,
        height=sd_height,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )
    result_sd = out.images[0].convert("RGB")
    # Scale the result back to the size of the uploaded image.
    final = result_sd.resize(original_size, Image.Resampling.LANCZOS)
    yield source_original, final
# ----------------------------
# Gradio UI
# ----------------------------
def clear_result():
    """Return a Gradio update that resets the output component's value to None."""
    cleared = gr.update(value=None)
    return cleared
| title = """<h2 align="center">SDXL Inpaint - Multi-Model</h2> | |
| <div align="center">在图片上涂抹要替换区域,选择模型,输入描述并生成。</div> | |
| """ | |
| DEFAULT_NEGATIVE_PROMPT = "low quality, deformed, extra limbs, blur, watermark, lowres, bad anatomy, unnatural colors, text, logo" | |
# Page layout and event wiring.
with gr.Blocks() as demo:
    gr.HTML(title)

    # Prompt and negative prompt side by side.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=1)
        with gr.Column():
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                lines=1,
                value=DEFAULT_NEGATIVE_PROMPT,
            )

    # Sampling controls.
    with gr.Row():
        strength_slider = gr.Slider(
            label="Strength (去噪强度/保真度,越低越保留原图)",
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            value=0.6,
        )
        guidance_slider = gr.Slider(
            label="Guidance Scale (引导尺度,越高越遵循 Prompt)",
            minimum=1.0,
            maximum=15.0,
            step=0.5,
            value=6.0,
        )

    with gr.Row():
        # Default of 20 steps; the label recommends 20-30 for DPM++ SDE Karras.
        num_steps_slider = gr.Slider(
            label="Num Inference Steps (运行步数,DPM++ SDE Karras 推荐 20-30 步)",
            minimum=4,
            maximum=50,
            step=1,
            value=20,
        )

    with gr.Row():
        run_button = gr.Button("Generate")

    # Mask editor next to the before/after slider.
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label=f"Input Image (请上传图片并涂抹需要替换的区域,图片最长边将自动压缩至 {MAX_DIM}px)",
            layers=False,
            height=512,
        )
        result = gr.ImageSlider(interactive=False, label="Generated Image")

    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_NAME,
        label="Model",
        interactive=True,
    )

    # Downscale oversized uploads in place.
    input_image.upload(
        fn=preprocess_image_input,
        inputs=[input_image],
        outputs=[input_image],
    )

    input_args = [
        prompt,
        negative_prompt,
        input_image,
        model_selection,
        strength_slider,
        guidance_slider,
        num_steps_slider,
    ]

    # The button click and Enter in the prompt box run the same chain:
    # clear the previous result, then generate a new one.
    for trigger in (run_button.click, prompt.submit):
        cleared = trigger(fn=clear_result, outputs=[result])
        cleared.then(fn=fill_image, inputs=input_args, outputs=[result])

# Queue at most 12 pending requests, then start the app without a share link.
demo.queue(max_size=12)
demo.launch(share=False)