"""SDXL inpainting Gradio demo serving several Lightning SDXL checkpoints.

All pipelines are loaded once at import time. The UI lets the user upload an
image, paint over the region to replace, pick a model and regenerate that
region from a text prompt.
"""

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps

# DPMSolverSinglestepScheduler (rather than the multistep variant) is used to
# get the "DPM++ SDE Karras" sampler.
from diffusers import StableDiffusionXLInpaintPipeline, AutoencoderKL, DPMSolverSinglestepScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 saves GPU memory, which helps with slightly larger images; CPU runs in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# Model configuration
# ----------------------------
# "variant" is passed straight to from_pretrained(); it must be None for repos
# that do not publish an fp16 weight variant.
MODELS = {
    "RealVisXL_V5.0_Lightning": {
        "id": "SG161222/RealVisXL_V5.0_Lightning",
        "vae": "madebyollin/sdxl-vae-fp16-fix",
        "variant": "fp16",  # this repo provides an fp16 variant
    },
    "Juggernaut-XL-Lightning": {
        "id": "RunDiffusion/Juggernaut-XL-Lightning",
        "vae": "madebyollin/sdxl-vae-fp16-fix",
        "variant": None,  # no fp16 variant in this repo; load the default weights
    },
}

DEFAULT_MODEL_NAME = "RealVisXL_V5.0_Lightning"
# Longest edge (in px) an uploaded image is downscaled to for display/editing.
MAX_DIM = 1280

# ----------------------------
# Load all SDXL inpaint pipelines
# ----------------------------
pipelines = {}


def load_pipeline(model_id: str, vae_id: str, variant: "str | None", dtype, device):
    """Load one SDXL inpaint pipeline, move it to ``device`` and tune it.

    The scheduler swap and the VAE swap are best-effort: if either fails, a
    warning is printed and the checkpoint's own component is kept.

    Args:
        model_id: Hugging Face repo id of the SDXL inpaint-capable checkpoint.
        vae_id: Repo id of the VAE that should replace the checkpoint's VAE.
        variant: Weight variant passed to ``from_pretrained`` (e.g. ``"fp16"``),
            or None to load the default weights.
        dtype: torch dtype for the pipeline and the VAE.
        device: Device string the pipeline and VAE are moved to.

    Returns:
        The loaded ``StableDiffusionXLInpaintPipeline``.

    Raises:
        Exception: Re-raises whatever failed while loading the base pipeline.
    """
    try:
        print(f"[LOAD] 尝试加载 {model_id} (variant: {variant})...")
        pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            variant=variant,
        )
        pipe = pipe.to(device)
        print(f"[INFO] Pipeline {model_id} 加载成功")

        # Switch to DPM++ SDE Karras: Karras sigmas + SDE DPM-Solver++.
        # Best-effort: the checkpoint's default scheduler stays if this fails.
        try:
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
                pipe.scheduler.config,
                use_karras_sigmas=True,
                algorithm_type="sde-dpmsolver++",
            )
            print("[INFO] 采样器已设置为 DPMSolverSinglestepScheduler (DPM++ SDE Karras)")
        except Exception as e:
            print(f"[WARN] 采样器配置失败: {e}")

        # Replace the VAE with ``vae_id``; keep the checkpoint's VAE on failure.
        try:
            print(f"[LOAD] 尝试为 {model_id} 加载 VAE {vae_id} ...")
            vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=dtype).to(device)
            pipe.vae = vae
            print("[INFO] VAE 替换成功")
        except Exception as e:
            print(f"[WARN] VAE 替换失败,使用原模型 VAE: {e}")

        return pipe
    except Exception as e:
        print(f"[ERROR] Pipeline {model_id} 加载失败:{e}")
        raise


for name, config in MODELS.items():
    print("-" * 30)
    print(f"正在初始化模型: {name}")
    pipelines[name] = load_pipeline(
        config["id"],
        config["vae"],
        config.get("variant"),  # .get(): entries without "variant" load default weights
        dtype,
        device,
    )

print("-" * 30)
print("所有模型已加载完毕!")


# ----------------------------
# Helper functions
# ----------------------------
def to_binary_mask(mask_img: Image.Image) -> Image.Image:
    """Convert a painted mask image into a binary "L" mask (255 = repaint).

    For RGBA input (brush strokes on a transparent layer), every pixel with
    non-zero alpha becomes 255. Any other mode is converted to grayscale and
    thresholded at > 128. Returns None when ``mask_img`` is None.
    """
    if mask_img is None:
        return None
    if mask_img.mode == "RGBA":
        alpha = mask_img.split()[3]
        bin_mask = alpha.point(lambda p: 255 if p > 0 else 0).convert("L")
    else:
        gray = mask_img.convert("L")
        bin_mask = gray.point(lambda p: 255 if p > 128 else 0).convert("L")
    return bin_mask


def mask_white_fraction(bin_mask: Image.Image) -> float:
    """Return the fraction of non-zero pixels in ``bin_mask`` (0.0 for an empty image)."""
    arr = np.array(bin_mask)
    return float((arr > 0).sum()) / max(1, arr.size)


def sd_resize_image(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
    """Resize ``image`` so its longest edge is 1024 px for SDXL inference.

    The other edge preserves the aspect ratio and is rounded up to a multiple
    of 8. Images smaller than 1024 px are upscaled.

    ``target_w`` / ``target_h`` are accepted for call-site compatibility but are
    not used.
    """
    INPAINT_TARGET_DIM = 1024
    w, h = image.size

    def round_to_8(x):
        return int(np.ceil(x / 8) * 8)

    if w > h:
        new_w = INPAINT_TARGET_DIM
        new_h = round_to_8(h * INPAINT_TARGET_DIM / w)
    else:
        new_h = INPAINT_TARGET_DIM
        new_w = round_to_8(w * INPAINT_TARGET_DIM / h)
    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    return resized_image


# ----------------------------
# Upload preprocessing
# ----------------------------
def preprocess_image_input(image_input):
    """Downscale a freshly uploaded image whose longest edge exceeds ``MAX_DIM``.

    ``image_input`` is the ImageMask editor value (a dict with a "background"
    image). When resizing happens, a new editor value is returned with the
    resized image as background/composite and no painted layers. Otherwise
    ``gr.update()`` is returned, leaving the editor untouched.
    """
    # The upload event can fire with no usable background; nothing to resize then.
    if not image_input or image_input.get("background") is None:
        return gr.update()

    original_source = image_input["background"].convert("RGB")
    w, h = original_source.size

    if w > MAX_DIM or h > MAX_DIM:
        print(f"[PREPROCESS] 原始分辨率 {w}x{h} 超过 {MAX_DIM},正在按比例缩小...")
        # max(1, ...) keeps extremely thin images from collapsing to a 0-px edge.
        if w > h:
            new_w = MAX_DIM
            new_h = max(1, int(h * MAX_DIM / w))
        else:
            new_h = MAX_DIM
            new_w = max(1, int(w * MAX_DIM / h))
        resized_source = original_source.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[PREPROCESS] 缩小后的显示尺寸为 {resized_source.size}")
        return {
            "background": resized_source,
            "layers": [],
            "mask": None,
            "composite": resized_source,
        }

    print(f"[PREPROCESS] 原始分辨率 {w}x{h} 小于等于 {MAX_DIM},无需处理。")
    return gr.update()


# ----------------------------
# Main inpainting logic
# ----------------------------
@spaces.GPU(duration=24)
def fill_image(prompt, negative_prompt, image, model_selection, strength, guidance_scale, num_steps):
    """Inpaint the painted region of ``image`` with the selected model.

    Yields a single ``(original, result)`` pair for the ImageSlider. When no
    mask layer is found, or less than 0.5% of the pixels are painted, the
    original image is yielded as both halves.

    Raises:
        gr.Error: The model is not loaded, no image was uploaded, or
            ``strength`` x ``num_steps`` leaves zero denoising steps.
    """
    if model_selection not in pipelines:
        raise gr.Error(f"选择的模型 '{model_selection}' 未加载,请检查启动日志。")

    pipe = pipelines[model_selection]
    print(f"[INFO] 当前使用模型: {model_selection}")

    # Clicking Generate before uploading yields None or a dict without a
    # background; report that instead of crashing on the dereference below.
    if not isinstance(image, dict) or image.get("background") is None:
        raise gr.Error("请先上传图片。")

    source_original = image["background"].convert("RGB")
    original_size = source_original.size

    # Brush strokes arrive in "layers"; a "mask" key is used as a fallback.
    mask_layer = None
    layers = image.get("layers")
    if layers:
        mask_layer = layers[0]
    elif image.get("mask") is not None:
        mask_layer = image["mask"]

    if mask_layer is None:
        print("[WARN] 没有检测到 mask layer,返回原图")
        yield source_original, source_original
        return

    bin_mask_original = to_binary_mask(mask_layer)
    # An (almost) empty mask means there is nothing to repaint.
    if mask_white_fraction(bin_mask_original) < 0.005:
        yield source_original, source_original
        return

    # Slider values may arrive as floats; the scheduler needs an int step count.
    num_steps = int(num_steps)
    # The img2img/inpaint pipeline runs only int(num_steps * strength) denoising
    # steps and raises ValueError when that is 0 (e.g. strength=0.0 on the slider).
    if int(num_steps * strength) < 1:
        raise gr.Error("Strength 与步数的组合过小(实际去噪步数为 0),请调高 Strength 或步数。")

    source_sd = sd_resize_image(source_original, original_size[0], original_size[1])
    sd_size = source_sd.size
    # NEAREST keeps the resized mask strictly binary.
    bin_mask_sd = bin_mask_original.resize(sd_size, Image.Resampling.NEAREST)
    print(f"[INFO] Inpaint 图像尺寸已调整为 SDXL 兼容尺寸: {sd_size}")

    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=source_sd,
        mask_image=bin_mask_sd,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )

    result_sd = out.images[0].convert("RGB")
    # Scale back so both slider halves have the original dimensions.
    final = result_sd.resize(original_size, Image.Resampling.LANCZOS)
    yield source_original, final


# ----------------------------
# Gradio UI
# ----------------------------
def clear_result():
    """Clear the result slider before a new generation starts."""
    return gr.update(value=None)


title = """

SDXL Inpaint - Multi-Model

在图片上涂抹要替换区域,选择模型,输入描述并生成。
"""

DEFAULT_NEGATIVE_PROMPT = "low quality, deformed, extra limbs, blur, watermark, lowres, bad anatomy, unnatural colors, text, logo"

with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=1)
        with gr.Column():
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, value=DEFAULT_NEGATIVE_PROMPT)

    with gr.Row():
        strength_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            value=0.6,
            label="Strength (去噪强度/保真度,越低越保留原图)",
        )
        guidance_slider = gr.Slider(
            minimum=1.0,
            maximum=15.0,
            step=0.5,
            value=6.0,
            label="Guidance Scale (引导尺度,越高越遵循 Prompt)",
        )

    with gr.Row():
        # Default of 20 steps for the DPM++ SDE Karras sampler configured above.
        num_steps_slider = gr.Slider(
            minimum=4,
            maximum=50,
            step=1,
            value=20,
            label="Num Inference Steps (运行步数,DPM++ SDE Karras 推荐 20-30 步)",
        )

    with gr.Row():
        run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label=f"Input Image (请上传图片并涂抹需要替换的区域,图片最长边将自动压缩至 {MAX_DIM}px)",
            layers=False,
            height=512,
        )
        result = gr.ImageSlider(interactive=False, label="Generated Image")

    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_NAME,
        label="Model",
        interactive=True,
    )

    # Downscale oversized uploads before the user starts painting.
    input_image.upload(
        fn=preprocess_image_input,
        inputs=[input_image],
        outputs=[input_image],
    )

    input_args = [
        prompt,
        negative_prompt,
        input_image,
        model_selection,
        strength_slider,
        guidance_slider,
        num_steps_slider,
    ]

    # Clear the previous result first, then run the (GPU) generation.
    run_button.click(fn=clear_result, outputs=[result]).then(
        fill_image, input_args, outputs=[result]
    )
    prompt.submit(fn=clear_result, outputs=[result]).then(
        fill_image, input_args, outputs=[result]
    )

demo.queue(max_size=12).launch(share=False)