import os
import sys

import gradio as gr
import timm
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# --- 🔍 Version check (printed to the server log at startup) ---
print("=" * 30)
print(f"Python version: {sys.version}")
print(f"Gradio version: {gr.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Timm version: {timm.__version__}")
print("=" * 30)
# -----------------------------

# --- 1. Model initialisation ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN")

# Bind both names before attempting the load so that process_image can detect
# a failed load (model is None) instead of crashing with NameError.
device = torch.device("cpu")
model = None

try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id, trust_remote_code=True, token=hf_token
    )
    model.to(device)
    model.eval()
    print("✅ 模型載入成功!")
except Exception as e:
    # Best-effort startup: keep the UI up, but make sure a half-initialised
    # model (e.g. if .to()/.eval() raised) is never used for inference.
    model = None
    print(f"❌ 模型載入失敗: {e}")

# --- 2. Image processing (follows the official RMBG-2.0 example) ---
IMAGE_SIZE = (1024, 1024)

# Stateless preprocessing pipeline, built once instead of on every request.
transform_image = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def process_image(input_image):
    """Remove the background from an uploaded image.

    Args:
        input_image: PIL image from the Gradio upload widget, or None when
            nothing has been uploaded.

    Returns:
        An RGBA copy of ``input_image`` whose alpha channel is the predicted
        foreground mask, resized back to the original image size; or None
        when ``input_image`` is None.

    Raises:
        gr.Error: if the segmentation model failed to load at startup.
    """
    if input_image is None:
        return None
    if model is None:
        raise gr.Error("❌ 模型未成功載入,無法處理圖片。請檢查伺服器日誌與 HF_TOKEN 設定。")

    # Normalize above uses 3-channel mean/std, so RGBA / L / P uploads must be
    # converted to RGB first, or preprocessing fails on a channel mismatch.
    rgb_image = input_image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Use the last element of the model output, as in the official example.
        preds = model(input_images)[-1].sigmoid().cpu()

    # sigmoid output lies in [0, 1]; ToPILImage turns it into a single-channel mask.
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)

    # convert() returns a new image, so the caller's input is not mutated.
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image


# --- 3. Interface ---
# Also show the running versions on the web page for easy verification.
version_info = f"目前運行版本 - Gradio: {gr.__version__} | Torch: {torch.__version__}"

with gr.Blocks(title="版本檢查") as app:
    gr.Markdown("## ✂️ AI 自動去背")
    gr.Markdown(f"ℹ️ **{version_info}**")  # rendered directly on the page

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="上傳圖片")
            btn = gr.Button("開始去背")
        with gr.Column():
            output_img = gr.Image(type="pil", label="去背結果")

    btn.click(fn=process_image, inputs=input_img, outputs=output_img)

if __name__ == "__main__":
    app.launch()