File size: 3,686 Bytes
a2e0537
438cd9d
 
 
 
 
5e0ef09
6e3bd1b
5e0ef09
0babf83
438cd9d
a2e0537
438cd9d
 
 
a2e0537
 
 
6e3bd1b
a2e0537
438cd9d
a2e0537
 
 
 
 
6e3bd1b
438cd9d
a2e0537
6e3bd1b
438cd9d
a2e0537
438cd9d
6e3bd1b
438cd9d
 
 
6e3bd1b
 
a2e0537
438cd9d
a2e0537
438cd9d
 
 
6e3bd1b
a2e0537
6e3bd1b
438cd9d
a2e0537
6e3bd1b
a2e0537
 
 
6e3bd1b
a2e0537
 
 
438cd9d
0babf83
3768461
b1df3da
3768461
6bd12d9
3768461
 
 
b1df3da
6bd12d9
b1df3da
 
 
 
 
 
 
 
0babf83
 
c0bb6b3
6bd12d9
3768461
a6b051c
5a52511
 
 
 
 
a6b051c
6af4e47
3768461
5a52511
 
 
 
 
 
 
9a43858
438cd9d
5a52511
 
 
 
 
3768461
645fd65
438cd9d
6e3bd1b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm
import io
import sys
import psutil # 記得 import 這個,才能調用後台硬體資料

# --- 1. Model initialisation ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# Hugging Face access token from the environment. It is needed if the model
# repository is gated; loading without it may then fail.
hf_token = os.getenv("HF_TOKEN")

if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN,如果是 Gated Model 可能會失敗")

# Choose the device before attempting the download so that `device` is always
# bound, even when loading fails below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sentinel: stays None unless the model is fully loaded, moved to `device` and
# switched to eval mode. Code using the model can test `model is None` instead
# of hitting a NameError.
model = None

try:
    loaded_model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    loaded_model.to(device)
    loaded_model.eval()
    # Publish only after every step succeeded, so a failure in .to() (e.g. on
    # the GPU) does not leave a half-initialised model behind.
    model = loaded_model
    print(f"✅ 模型載入成功!使用裝置: {device}")
except Exception as e:
    # Best-effort: keep the app (including its system-monitor tab) running even
    # if the model cannot be loaded; `model` remains None in that case.
    print(f"❌ 模型載入失敗: {e}")

# --- 2. Image processing ---
# Preprocessing applied before the segmentation model: resize to 1024x1024,
# convert to a tensor and normalise with the standard ImageNet mean/std.
# Built once at import time instead of on every request.
_RMBG_INPUT_SIZE = (1024, 1024)
_rmbg_transform = transforms.Compose([
    transforms.Resize(_RMBG_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_image(input_image):
    """Remove the background from a PIL image.

    Args:
        input_image: PIL image supplied by the Gradio input component, or
            None when nothing was uploaded.

    Returns:
        An RGBA copy of ``input_image`` whose alpha channel is the predicted
        foreground mask (resized back to the input size), or None when no
        image was supplied.

    Raises:
        gr.Error: if the segmentation model is not loaded.
    """
    if input_image is None:
        return None
    if model is None:
        raise gr.Error("Model is not loaded; check the server logs.")

    # Normalize above uses 3-channel statistics, so RGBA / grayscale / palette
    # inputs would fail there; feed the network an RGB copy instead.
    rgb_image = input_image.convert("RGB")
    batch = _rmbg_transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The last element of the model output is taken as the final
        # prediction; sigmoid maps logits to [0, 1] mask values.
        preds = model(batch)[-1].sigmoid().cpu()

    # preds[0] drops the batch dimension; squeeze() then removes the singleton
    # channel. Assumes the prediction is (batch, 1, H, W) -- TODO confirm.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(input_image.size)

    result = input_image.convert("RGBA")
    result.putalpha(mask)
    return result

# --- 3. System-status helpers ---
# JSON variant, consumed by the React frontend.
def get_system_stats_api():
    """Sample host CPU and RAM utilisation for the React frontend.

    Blocks for about one second while psutil measures CPU usage over that
    interval.

    Returns:
        dict with keys "cpu" and "ram", each the usage percentage reported
        by psutil.
    """
    cpu_load = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()
    return {"cpu": cpu_load, "ram": memory.percent}
# Markdown variant, rendered on the Gradio "System Monitor" tab.
def get_system_stats_ui():
    """Return a Markdown table of current CPU and RAM usage for the Gradio UI.

    Delegates the measurement to get_system_stats_api so the UI and the JSON
    API always report the same metrics from a single code path.

    Returns:
        A Markdown string: a heading followed by a two-row usage table.
    """
    stats = get_system_stats_api()
    # Emit the lines without leading indentation: in standard Markdown, lines
    # indented by four or more spaces render as a code block rather than a
    # heading and table.
    return "\n".join([
        "## 🖥️ System Status",
        "| Metric | Usage |",
        "|--------|-------|",
        f"| **CPU** | {stats['cpu']}% |",
        f"| **RAM** | {stats['ram']}% |",
    ])

# --- 4. Interface ---
# Build the Gradio app. Gradio places components in the order they are
# created inside the nested `with` contexts, so statement order here
# determines the page layout.
with gr.Blocks(title="去背服務測試") as app:
    gr.Markdown("## ✂️ 去背服務測試RM2")
    
    with gr.Tabs():
        # Tab 1: Image Processing -- upload an image, click the button to run
        # process_image on it.
        with gr.Tab("✂️ Remove Background"):
            with gr.Row():
                # Both components exchange PIL images; the output uses
                # format="png" so the RGBA result keeps its transparency.
                img_in = gr.Image(type="pil", label="Input Image")
                img_out = gr.Image(type="pil", label="Result (PNG)", format="png")
            btn = gr.Button("Remove Background", variant="primary")
            btn.click(process_image, inputs=img_in, outputs=img_out)
            
        # Tab 2: System Monitor (UI for Space Page)
        with gr.Tab("📊 System Monitor"):
            gr.Markdown("Click the button below to check current server load.")
            stats_output = gr.Markdown("### Status: Waiting...")
            refresh_btn = gr.Button("🔄 Refresh Stats")
            refresh_btn.click(get_system_stats_ui, outputs=stats_output)
            # Auto-load stats when page opens
            app.load(get_system_stats_ui, outputs=stats_output)

    # Hidden API Route for React Frontend
    # The frontend calls this via client.predict("/status")
    # The JSON output and button are both invisible; they exist only so that
    # get_system_stats_api is exposed as an API endpoint named "status".
    api_status = gr.JSON(visible=False, label="API Response")
    api_btn = gr.Button("API Status", visible=False)
    api_btn.click(get_system_stats_api, outputs=api_status, api_name="status")
    
# --- 5. Launch --- this section is required for the Space to be callable externally.
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) on port 7860.
    # Author's note: newer Gradio versions enable CORS for the API by default,
    # so no cors_allowed_origins option is passed here.
    # NOTE(review): confirm this against the installed Gradio version.
    app.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )