# learn / app.py — Hugging Face Space by urnotwen
# Commit 5f28e7c ("Update app.py", verified) · 2.4 kB · viewer links: raw | history | blame
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm
# --- 🔍 Version check: print the runtime library versions at startup ---
import sys

print("\n".join([
    "=" * 30,
    f"Python version: {sys.version}",
    f"Gradio version: {gr.__version__}",
    f"Torch version: {torch.__version__}",
    f"Timm version: {timm.__version__}",
    "=" * 30,
]))
# -----------------------------
# --- 1. Model initialisation ---
# Loads the segmentation model once at import time. The Hub token is read from
# the HF_TOKEN environment variable; a missing token only prints a warning
# (presumably the repo requires one, in which case loading will fail below).
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN")

# Defined outside the try so `device` always exists, even when loading fails.
device = torch.device("cpu")
# Stays None when loading fails, so callers can detect the missing model
# instead of hitting a NameError on an undefined global.
model = None
try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()  # inference mode (disables dropout / batch-norm updates)
    print("✅ 模型載入成功!")
except Exception as e:
    # Top-level boundary: keep the UI up, but never leave a half-initialised
    # model (e.g. loaded but failed to move to the device) in the global.
    model = None
    print(f"❌ 模型載入失敗: {e}")
# --- 2. Image processing (follows the official RMBG-2.0 inference recipe) ---
# Pre-processing pipeline, built once at import time rather than per request:
# resize to 1024x1024, convert to a tensor, then normalise with the given
# per-channel mean/std.
_IMAGE_SIZE = (1024, 1024)
_TRANSFORM_IMAGE = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_image(input_image):
    """Remove the background of *input_image* with the segmentation model.

    Args:
        input_image: PIL image from the Gradio input component, or None when
            nothing has been uploaded.

    Returns:
        An RGBA copy of *input_image* whose alpha channel is the predicted
        mask (resized back to the input's size), or None if *input_image*
        is None.

    Raises:
        gr.Error: if the module-level model is missing or failed to load.
    """
    if input_image is None:
        return None

    # `model` is a module-level global set at startup; it may be undefined or
    # None when loading failed, so report that to the user explicitly.
    if globals().get("model") is None:
        raise gr.Error("模型尚未成功載入,請檢查 HF_TOKEN 與啟動日誌。")

    # The 3-channel Normalize requires an RGB tensor; uploads may be RGBA,
    # grayscale, palette or CMYK, so convert before transforming.
    rgb_image = input_image.convert("RGB")
    input_images = _TRANSFORM_IMAGE(rgb_image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        # Use the model's last output and squash it into (0, 1) with sigmoid.
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()  # first (only) batch item; assumes a 2-D map remains — TODO confirm

    # Turn the probability map into a PIL mask at the original resolution.
    mask = transforms.ToPILImage()(pred).resize(input_image.size)

    image = input_image.convert("RGBA")  # convert() returns a copy; input is untouched
    image.putalpha(mask)
    return image
# --- 3. User interface ---
# The running versions are also shown on the page, for verification.
version_info = f"目前運行版本 - Gradio: {gr.__version__} | Torch: {torch.__version__}"
with gr.Blocks(title="版本檢查") as app:
    gr.Markdown("## ✂️ AI 自動去背")
    gr.Markdown(f"ℹ️ **{version_info}**")  # rendered directly on the page
    with gr.Row():
        # Left column: upload + trigger button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="上傳圖片")
            btn = gr.Button("開始去背")
        # Right column: background-removed result.
        with gr.Column():
            output_img = gr.Image(type="pil", label="去背結果")
    # Event listeners are attached inside the Blocks context.
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)
def main():
    """Start the Gradio server hosting the background-removal UI."""
    app.launch()


# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    main()