File size: 2,256 Bytes
a2e0537
438cd9d
 
 
 
 
 
a2e0537
438cd9d
 
 
a2e0537
 
 
 
 
438cd9d
a2e0537
 
 
 
 
 
53c4cc4
a2e0537
438cd9d
a2e0537
 
438cd9d
a2e0537
438cd9d
53c4cc4
438cd9d
 
 
 
a2e0537
 
438cd9d
a2e0537
438cd9d
a2e0537
438cd9d
 
 
 
a2e0537
 
438cd9d
a2e0537
438cd9d
a2e0537
438cd9d
a2e0537
 
438cd9d
a2e0537
 
 
 
 
438cd9d
 
a2e0537
 
438cd9d
a2e0537
438cd9d
53c4cc4
 
 
438cd9d
a2e0537
438cd9d
a2e0537
53c4cc4
 
 
438cd9d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image

# --- 1. 初始化模型 ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

hf_token = os.getenv("HF_TOKEN")

if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN,請檢查 Settings 中的 Secret。")

try:
    # 載入模型
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id, 
        trust_remote_code=True, 
        token=hf_token
    )
    # 強制使用 CPU
    device = torch.device("cpu") 
    model.to(device)
    model.eval()
    print("✅ 模型載入成功!")
except Exception as e:
    print(f"❌ 模型載入失敗: {e}")

# --- 2. 定義圖像處理邏輯 ---
def process_image(input_image):
    if input_image is None:
        return None
    
    # 準備圖像尺寸
    image_size = (1024, 1024)
    
    # 定義轉換
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # 轉換並增加維度
    input_images = transform_image(input_image).unsqueeze(0).to(device)
    
    # 開始預測
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    
    # 處理預測結果
    pred = preds[0].squeeze()
    
    # 轉回 PIL 圖片
    pred_pil = transforms.ToPILImage()(pred)
    
    # 調整回原始圖片的大小
    mask = pred_pil.resize(input_image.size)
    
    # 合成去背圖
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    
    return image

# --- 3. 建立介面 (移除所有可能導致錯誤的外觀設定) ---
# 這裡去掉了 theme=... 參數,這是最安全的寫法
with gr.Blocks() as app:
    
    gr.Markdown("## ✂️ AI 自動去背 (RMBG 2.0)")
    
    with gr.Column():
        input_img = gr.Image(type="pil", label="上傳圖片", source="upload")
        btn = gr.Button("開始去背")
        output_img = gr.Image(type="pil", label="去背結果")

    btn.click(fn=process_image, inputs=input_img, outputs=output_img)

if __name__ == "__main__":
    app.launch()