# app.py — Gradio background-removal demo using the briaai/RMBG-2.0 model.
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# --- 1. Initialize the model ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# RMBG-2.0 is a gated repo; an HF access token must be supplied via the
# HF_TOKEN secret for from_pretrained to succeed.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN,請檢查 Settings 中的 Secret。")

try:
    # Load the segmentation model (trust_remote_code: the repo ships its
    # own model class).
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token
    )
    # Force CPU inference — presumably this Space has no GPU (TODO confirm).
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    print("✅ 模型載入成功!")
except Exception as e:
    print(f"❌ 模型載入失敗: {e}")
    # Without a model the app cannot do anything useful; re-raise so the
    # Space fails fast instead of crashing later with a NameError on
    # `model`/`device` inside process_image.
    raise
# --- 2. Image-processing logic ---
def process_image(input_image):
    """Remove the background from an image using the RMBG-2.0 model.

    Args:
        input_image: A PIL.Image in any mode, or None (nothing uploaded).

    Returns:
        An RGBA PIL.Image with the predicted foreground mask applied as
        the alpha channel, or None when no image was supplied.
    """
    if input_image is None:
        return None

    # The model expects a fixed-size, normalized 3-channel input.
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Standard ImageNet mean/std — 3 channels exactly.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Convert to RGB first: uploaded PNGs may be RGBA (4 channels) and
    # some images are grayscale (1 channel); either would make the
    # 3-channel Normalize above raise. Add the batch dimension after.
    input_images = transform_image(input_image.convert("RGB")).unsqueeze(0).to(device)

    # Run inference; the model returns a list of predictions and the last
    # entry is the final mask. sigmoid() maps logits to [0, 1].
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()

    # Drop batch/channel dims, convert back to a PIL grayscale mask,
    # and resize it to match the original image.
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)

    # Composite: use the mask as the alpha channel of the original image.
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
# --- 3. Build the interface ---
# NOTE(review): the original passed source="upload" to gr.Image, which was
# removed in Gradio 4.x (renamed to `sources`). Uploading is the default
# input source for gr.Image, so dropping the argument keeps the same
# behavior on 3.x while remaining compatible with 4.x.
with gr.Blocks() as app:
    gr.Markdown("## ✂️ AI 自動去背 (RMBG 2.0)")
    with gr.Column():
        input_img = gr.Image(type="pil", label="上傳圖片")
        btn = gr.Button("開始去背")
        output_img = gr.Image(type="pil", label="去背結果")
    # Wire the button: run the background-removal pipeline on click.
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)

if __name__ == "__main__":
    app.launch()