File size: 2,401 Bytes
a2e0537 438cd9d 5e0ef09 438cd9d a2e0537 438cd9d a2e0537 5e0ef09 a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d 5e0ef09 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d 5e0ef09 5f28e7c 438cd9d 21d4f19 5e0ef09 438cd9d 5e0ef09 438cd9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm
# --- 🔍 Version check (look here) ---
# Print the interpreter and key library versions to the startup log.
import sys


def _print_version_banner():
    """Print the Python/Gradio/Torch/Timm versions between two separator rules."""
    rule = "=" * 30
    print(rule)
    entries = [
        ("Python", sys.version),
        ("Gradio", gr.__version__),
        ("Torch", torch.__version__),
        ("Timm", timm.__version__),
    ]
    for label, value in entries:
        print(f"{label} version: {value}")
    print(rule)


_print_version_banner()
# -----------------------------
# --- 1. Initialize the model ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# The Hub access token is read from the HF_TOKEN environment variable.
# Loading is still attempted without it (presumably failing if the repo
# requires authentication).
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN")

# Choose the device before attempting the load so that `device` is always
# defined, even if loading fails. Falls back to CPU (the previous hard-coded
# choice) when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用裝置: {device}")

# `model` stays None unless loading fully succeeds, so the app can still start
# (e.g. to show the version banner) and later code can detect the failure with
# `model is None` instead of hitting a NameError.
model = None
try:
    # trust_remote_code=True executes Python code shipped in the Hub repo;
    # only use it with a trusted model_id.
    _loaded_model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    _loaded_model.to(device)
    _loaded_model.eval()
    # Publish the model only after it is on the device and in eval mode.
    model = _loaded_model
    print("✅ 模型載入成功!")
except Exception as e:
    # Deliberate best-effort boundary: keep the app running, but log the full
    # traceback so the root cause is visible in the startup logs.
    import traceback

    model = None
    print(f"❌ 模型載入失敗: {e}")
    traceback.print_exc()
# --- 2. Image processing (follows the official model logic) ---
def process_image(input_image):
    """Remove the background of a PIL image with the segmentation model.

    The image is resized to 1024x1024 and normalized with ImageNet mean/std.
    The model's last output is passed through a sigmoid to get a mask; that
    mask, resized back to the input size, becomes the alpha channel of an
    RGBA copy of the input.

    Args:
        input_image: PIL image from the Gradio input, or None when nothing
            has been uploaded.

    Returns:
        A new RGBA PIL image whose alpha channel is the predicted mask, or
        None if ``input_image`` is None.

    Raises:
        gr.Error: if the module-level ``model`` is None (model not loaded).
    """
    if input_image is None:
        return None
    if model is None:
        raise gr.Error("模型尚未載入,請查看啟動日誌。")

    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Normalize uses 3-channel mean/std, so RGBA, grayscale ("L") or palette
    # ("P") images would produce a tensor with the wrong channel count.
    # Force RGB for the model input only.
    if input_image.mode == "RGB":
        rgb_image = input_image
    else:
        rgb_image = input_image.convert("RGB")

    input_images = transform_image(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last of the model's outputs, as in the official logic.
        preds = model(input_images)[-1].sigmoid().cpu()
    # Drop the batch/channel dims; presumably leaves a (1024, 1024) mask.
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)

    # convert() returns a new image, so the caller's input is not modified.
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
# --- 3. User interface ---
# Also show the runtime versions on the page, so they can be verified without
# reading the logs.
version_info = f"目前運行版本 - Gradio: {gr.__version__} | Torch: {torch.__version__}"

app = gr.Blocks(title="版本檢查")
with app:
    gr.Markdown("## ✂️ AI 自動去背")
    # Rendered directly on the web page.
    gr.Markdown(f"ℹ️ **{version_info}**")
    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            input_img = gr.Image(label="上傳圖片", type="pil")
            btn = gr.Button("開始去背")
        # Right column: the background-removed result.
        with gr.Column():
            output_img = gr.Image(label="去背結果", type="pil")
    btn.click(process_image, inputs=input_img, outputs=output_img)
if __name__ == "__main__":
    # Start the Gradio server when the script is executed directly.
    app.launch()