import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
# Import the required scheduler.
# DPMSolverMultistepScheduler is replaced with DPMSolverSinglestepScheduler (for DPM++ SDE Karras).
from diffusers import StableDiffusionXLInpaintPipeline, AutoencoderKL, DPMSolverSinglestepScheduler
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 saves VRAM, which helps when handling slightly larger images.
dtype = torch.float16 if device == "cuda" else torch.float32
# ----------------------------
# Model configuration
# ----------------------------
# *** Fix 1: explicitly specify each model's variant in the config ***
MODELS = {
    "RealVisXL_V5.0_Lightning": {
        "id": "SG161222/RealVisXL_V5.0_Lightning",
        "vae": "madebyollin/sdxl-vae-fp16-fix",
        "variant": "fp16"  # this repo ships an fp16 variant
    },
    "Juggernaut-XL-Lightning": {
        "id": "RunDiffusion/Juggernaut-XL-Lightning",
        "vae": "madebyollin/sdxl-vae-fp16-fix",
        "variant": None  # this repo has no fp16 variant, so use None
    }
}
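# Note: `variant="fp16"` makes from_pretrained look for the repo's *.fp16.safetensors
# weight files; `variant=None` falls back to the default (full-precision) weights.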
DEFAULT_MODEL_NAME = "RealVisXL_V5.0_Lightning"
MAX_DIM = 1280
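# MAX_DIM caps the longest edge of an uploaded image before it reaches the editor
# (see preprocess_image_input below); inference itself runs at ~1024px via sd_resize_image.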
# ----------------------------
# Load all SDXL inpaint pipelines
# ----------------------------
pipelines = {}
# *** Fix 2: the loader signature now takes a variant parameter ***
def load_pipeline(model_id: str, vae_id: str, variant: str, dtype, device):
    """Load a single inpaint pipeline."""
    try:
        print(f"[LOAD] Loading {model_id} (variant: {variant})...")
        pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            variant=variant  # variant comes from the model config above
        )
        pipe = pipe.to(device)
        print(f"[INFO] Pipeline {model_id} loaded successfully")
        # Configure the recommended sampler (scheduler).
        # *** Change 1: switch to the DPM++ SDE Karras scheduler ***
        try:
            # Use DPMSolverSinglestepScheduler for DPM++ SDE Karras.
            # Note: some diffusers releases accept only "dpmsolver"/"dpmsolver++" for this
            # scheduler; if the call raises, the except below keeps the pipeline's default scheduler.
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
                pipe.scheduler.config,
                use_karras_sigmas=True,  # enable the Karras noise schedule
                algorithm_type="sde-dpmsolver++"  # request DPM++ SDE mode
            )
            print("[INFO] Scheduler set to DPMSolverSinglestepScheduler (DPM++ SDE Karras)")
        except Exception as e:
            print(f"[WARN] Scheduler configuration failed: {e}")
        # Swap in the replacement VAE (optional).
        try:
            print(f"[LOAD] Loading VAE {vae_id} for {model_id} ...")
            vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=dtype).to(device)
            pipe.vae = vae
            print("[INFO] VAE replaced successfully")
        except Exception as e:
            print(f"[WARN] VAE replacement failed, keeping the model's own VAE: {e}")
        return pipe
    except Exception as e:
        print(f"[ERROR] Failed to load pipeline {model_id}: {e}")
        raise
# *** Fix 3: pass each model's variant from the config when loading in the loop ***
for name, config in MODELS.items():
    print("-" * 30)
    print(f"Initializing model: {name}")
    pipelines[name] = load_pipeline(
        config["id"],
        config["vae"],
        config.get("variant"),  # .get() returns None if the key is missing
        dtype,
        device
    )
print("-" * 30)
print("All models loaded!")
# ----------------------------
# Helper functions (unchanged)
# ----------------------------
def to_binary_mask(mask_img: Image.Image) -> Image.Image:
    if mask_img is None:
        return None
    if mask_img.mode == "RGBA":
        alpha = mask_img.split()[3]
        bin_mask = alpha.point(lambda p: 255 if p > 0 else 0).convert("L")
    else:
        gray = mask_img.convert("L")
        bin_mask = gray.point(lambda p: 255 if p > 128 else 0).convert("L")
    return bin_mask
def mask_white_fraction(bin_mask: Image.Image) -> float:
    arr = np.array(bin_mask)
    return float((arr > 0).sum()) / max(1, arr.size)
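# SDXL's VAE downsamples by a factor of 8, so the inpainting input is resized to a
# longest edge of ~1024px with both sides rounded up to multiples of 8.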
def sd_resize_image(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
    # Note: target_w / target_h are currently unused; the output size depends only on `image`.
    INPAINT_TARGET_DIM = 1024
    w, h = image.size
    def round_to_8(x):
        return int(np.ceil(x / 8) * 8)
    if w > h:
        new_w = INPAINT_TARGET_DIM
        new_h = round_to_8(h * INPAINT_TARGET_DIM / w)
    else:
        new_h = INPAINT_TARGET_DIM
        new_w = round_to_8(w * INPAINT_TARGET_DIM / h)
    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    return resized_image
# ----------------------------
# Image-upload preprocessing (unchanged)
# ----------------------------
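# gr.ImageMask (type="pil") passes its value as a dict; in recent Gradio versions the keys
# are "background", "layers", and "composite". The function returns either a resized dict
# or gr.update() to leave the editor value unchanged.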
def preprocess_image_input(image_input):
    if image_input is None:
        return gr.update()
    original_source = image_input["background"].convert("RGB")
    w, h = original_source.size
    if w > MAX_DIM or h > MAX_DIM:
        print(f"[PREPROCESS] Original resolution {w}x{h} exceeds {MAX_DIM}; downscaling proportionally...")
        if w > h:
            new_w = MAX_DIM
            new_h = int(h * MAX_DIM / w)
        else:
            new_h = MAX_DIM
            new_w = int(w * MAX_DIM / h)
        resized_source = original_source.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[PREPROCESS] Downscaled display size: {resized_source.size}")
        return {
            "background": resized_source,
            "layers": None,
            "mask": None,
            "composite": resized_source
        }
    print(f"[PREPROCESS] Original resolution {w}x{h} is within {MAX_DIM}; no preprocessing needed.")
    return gr.update()
# ----------------------------
# Main logic: fill_image (unchanged)
# ----------------------------
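# @spaces.GPU requests a ZeroGPU allocation on Hugging Face Spaces for each call;
# `duration` is the maximum number of seconds the GPU is held (assumption: outside
# Spaces the decorator is effectively a no-op).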
@spaces.GPU(duration=24)
def fill_image(prompt, negative_prompt, image, model_selection, strength, guidance_scale, num_steps):
    if model_selection not in pipelines:
        raise gr.Error(f"The selected model '{model_selection}' is not loaded; check the startup logs.")
    pipe = pipelines[model_selection]
    print(f"[INFO] Using model: {model_selection}")
    source_original = image["background"].convert("RGB")
    original_size = source_original.size
    mask_layer = None
    if isinstance(image, dict):
        layers = image.get("layers", None)
        if layers and len(layers) > 0:
            mask_layer = layers[0]
        elif image.get("mask", None) is not None:
            mask_layer = image["mask"]
    if mask_layer is None:
        print("[WARN] No mask layer detected, returning the original image")
        yield source_original, source_original
        return
    bin_mask_original = to_binary_mask(mask_layer)
    white_frac = mask_white_fraction(bin_mask_original)
    # If almost nothing was painted (< 0.5% of pixels), skip inference and return the original.
    if white_frac < 0.005:
        yield source_original, source_original
        return
    current_guidance = guidance_scale
    source_sd = sd_resize_image(source_original, original_size[0], original_size[1])
    sd_size = source_sd.size
    bin_mask_sd = bin_mask_original.resize(sd_size, Image.Resampling.NEAREST)
    print(f"[INFO] Inpaint input resized to an SDXL-compatible size: {sd_size}")
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=source_sd,
        mask_image=bin_mask_sd,
        num_inference_steps=num_steps,
        guidance_scale=current_guidance,
        strength=strength
    )
    result_sd = out.images[0].convert("RGB")
    # Resize back to the original resolution; the (before, after) pair feeds the gr.ImageSlider.
    final = result_sd.resize(original_size, Image.Resampling.LANCZOS)
    yield source_original, final
# ----------------------------
# Gradio UI
# ----------------------------
def clear_result():
    return gr.update(value=None)
title = """<h2 align="center">SDXL Inpaint - Multi-Model</h2>
<div align="center">在图片上涂抹要替换区域,选择模型,输入描述并生成。</div>
"""
DEFAULT_NEGATIVE_PROMPT = "low quality, deformed, extra limbs, blur, watermark, lowres, bad anatomy, unnatural colors, text, logo"
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=1)
        with gr.Column():
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, value=DEFAULT_NEGATIVE_PROMPT)
    with gr.Row():
        strength_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            value=0.6,
            label="Strength (denoising strength / fidelity; lower values preserve more of the original image)",
        )
        guidance_slider = gr.Slider(
            minimum=1.0,
            maximum=15.0,
            step=0.5,
            value=6.0,
            label="Guidance Scale (higher values follow the prompt more closely)",
        )
    with gr.Row():
        # *** Change 2: default steps changed to 20, and the hint updated ***
        num_steps_slider = gr.Slider(
            minimum=4,
            maximum=50,
            step=1,
            value=20,  # 20-30 steps recommended for DPM++ SDE Karras
            label="Num Inference Steps (20-30 recommended for DPM++ SDE Karras)",
        )
    with gr.Row():
        run_button = gr.Button("Generate")
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label=f"Input Image (upload an image and paint over the area to replace; the longest edge is automatically downscaled to {MAX_DIM}px)",
            layers=False,
            height=512,
        )
        result = gr.ImageSlider(interactive=False, label="Generated Image")
    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_NAME,
        label="Model",
        interactive=True
    )
    input_image.upload(
        fn=preprocess_image_input,
        inputs=[input_image],
        outputs=[input_image]
    )
    input_args = [
        prompt, negative_prompt, input_image, model_selection,
        strength_slider, guidance_slider, num_steps_slider
    ]
    run_button.click(fn=clear_result, outputs=[result])\
        .then(fill_image,
              input_args,
              outputs=[result])
    prompt.submit(fn=clear_result, outputs=[result])\
        .then(fill_image,
              input_args,
              outputs=[result])
demo.queue(max_size=12).launch(share=False)
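# queue(max_size=12) caps how many requests may wait in the queue at once;
# share=False keeps the app from creating a public gradio.live link.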