import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
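# timm is imported because the RMBG-2.0 remote modeling code depends on it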
import timm
import io
import sys
import psutil  # needed to read host CPU/RAM stats for the system monitor

# --- 1. Initialize the model ---
model_id = "briaai/RMBG-2.0"
print(f"Loading model: {model_id} ...")

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ Warning: HF_TOKEN not detected; loading a gated model may fail")

try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print(f"✅ Model loaded successfully! Device: {device}")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise  # fail fast: the app cannot serve requests without the model

# --- 2. Image processing ---
def process_image(input_image):
    if input_image is None:
        return None
    # Preprocess: resize to 1024x1024, convert to tensor, normalize with ImageNet stats
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_images = transform_image(input_image).unsqueeze(0).to(device)
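    # Run inference without gradients; the model returns a list of predictions,
    # the last of which is used as the mask (sigmoid maps logits to [0, 1])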
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
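    # Convert the mask tensor to a PIL image, resize it back to the original
    # resolution, and attach it as the alpha channel of an RGBA copy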
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image

# --- 3. System status functions ---
# For React Frontend (JSON)
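# Note: cpu_percent(interval=1) blocks for roughly one second per call while sampling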
def get_system_stats_api():
    return {
        "cpu": psutil.cpu_percent(interval=1),
        "ram": psutil.virtual_memory().percent
    }

# For Gradio UI (Visual Markdown)
def get_system_stats_ui():
    cpu = psutil.cpu_percent(interval=1)
    ram = psutil.virtual_memory().percent
    return f"""
## 🖥️ System Status

| Metric | Usage |
|--------|-------|
| **CPU** | {cpu}% |
| **RAM** | {ram}% |
"""

# --- 4. Interface ---
with gr.Blocks(title="Background Removal Service Test") as app:
    gr.Markdown("## ✂️ Background Removal Service Test RM2")
    with gr.Tabs():
        # Tab 1: Image processing
        with gr.Tab("✂️ Remove Background"):
            with gr.Row():
                img_in = gr.Image(type="pil", label="Input Image")
                img_out = gr.Image(type="pil", label="Result (PNG)", format="png")
            btn = gr.Button("Remove Background", variant="primary")
            btn.click(process_image, inputs=img_in, outputs=img_out)
        # Tab 2: System monitor (UI for the Space page)
        with gr.Tab("📊 System Monitor"):
            gr.Markdown("Click the button below to check current server load.")
            stats_output = gr.Markdown("### Status: Waiting...")
            refresh_btn = gr.Button("🔄 Refresh Stats")
            refresh_btn.click(get_system_stats_ui, outputs=stats_output)
            # Auto-load stats when the page opens
            app.load(get_system_stats_ui, outputs=stats_output)
    # Hidden API route for the React frontend
    # The frontend calls this via client.predict("/status")
    api_status = gr.JSON(visible=False, label="API Response")
    api_btn = gr.Button("API Status", visible=False)
    api_btn.click(get_system_stats_api, outputs=api_status, api_name="status")
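
# A minimal client-side sketch (not executed here) of how an external frontend can
# reach the hidden "/status" endpoint via gradio_client; "user/space-name" is a
# placeholder for this Space's actual id:
#
#   from gradio_client import Client
#   client = Client("user/space-name")
#   stats = client.predict(api_name="/status")  # -> {"cpu": ..., "ram": ...}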

# --- 5. Launch --- (required so this Space can be called externally)
if __name__ == "__main__":
    # Recent Gradio versions enable CORS for the API by default, so cors_allowed_origins is not needed
    app.launch(server_name="0.0.0.0", server_port=7860)