File size: 2,256 Bytes
a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 53c4cc4 a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d 53c4cc4 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d 53c4cc4 438cd9d a2e0537 438cd9d a2e0537 53c4cc4 438cd9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# --- 1. Model initialisation ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# The access token is read from the environment; when it is absent we only
# warn and still attempt the download (token=None is passed through).
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN,請檢查 Settings 中的 Secret。")

# Inference is pinned to the CPU. The device is bound outside the try block
# so the name always exists, even when loading the model fails.
device = torch.device("cpu")

# Bound up front for the same reason: a failed load leaves `model` as None
# instead of an undefined name that would surface later as a NameError.
model = None

try:
    # Load the segmentation model; trust_remote_code=True lets the
    # repository's own modelling code run.
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()
    print("✅ 模型載入成功!")
except Exception as e:
    # Best effort: report the failure and keep the module importable.
    # Reset the name so a half-initialised model is never used.
    model = None
    print(f"❌ 模型載入失敗: {e}")
# --- 2. Image processing logic ---
def process_image(input_image):
    """Cut the foreground out of an image using the loaded segmentation model.

    Args:
        input_image: PIL image handed over by the UI, or ``None`` when
            nothing has been uploaded.

    Returns:
        An RGBA copy of ``input_image`` at its original size whose alpha
        channel is the predicted mask, or ``None`` when ``input_image`` is
        ``None``.
    """
    if input_image is None:
        return None

    # Resolution the image is resized to before being fed to the model.
    image_size = (1024, 1024)

    # Resize, convert to a tensor and normalise with three-channel statistics.
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Normalize above is configured for exactly three channels, so RGBA,
    # grayscale or palette uploads must be converted to RGB first; otherwise
    # the channel counts do not match and the transform raises.
    rgb_image = input_image.convert("RGB")

    # Add a leading batch dimension and move the tensor to the model device.
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    # Inference only: no gradients needed. The last element of the model
    # output is taken and squashed into (0, 1) with a sigmoid.
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()

    # First item of the batch with size-1 dimensions dropped
    # (presumably a single-channel map - confirm against the model output).
    pred = preds[0].squeeze()

    # Back to a PIL image, then scale the mask to the original image size.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)

    # Compose the cut-out: the mask becomes the alpha channel of the
    # original (not the RGB-converted) image.
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
# --- 3. Build the interface (no appearance settings that could raise) ---
# The theme=... argument is deliberately left out; this is the safest form.
with gr.Blocks() as app:
    gr.Markdown("## ✂️ AI 自動去背 (RMBG 2.0)")
    with gr.Column():
        # `source="upload"` is intentionally not passed: newer Gradio
        # releases replaced that keyword with `sources` and reject the old
        # one, while uploading is available by default either way.
        input_img = gr.Image(type="pil", label="上傳圖片")
        btn = gr.Button("開始去背")
        output_img = gr.Image(type="pil", label="去背結果")

    # Run the background removal when the button is clicked.
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)

if __name__ == "__main__":
    app.launch()