# test_load_film_unet2d.py
import os

import torch
from transformers import AutoModel, AutoConfig, AutoImageProcessor
# ✅ point to your local folder (or your HF repo id after pushing)
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
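# trust_remote_code=True executes the custom model/processor classes bundled in
# the repo itself, so only enable it for sources you trust.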
print("Loading config...")
cfg = AutoConfig.from_pretrained(repo_or_path, trust_remote_code=True)
print(cfg)
print("Loading model and weights...")
proc = AutoImageProcessor.from_pretrained(repo_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
model.eval()
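# Optional GPU placement (a minimal sketch; if enabled, also move the processed
# inputs onto the same device with .to(device) before the forward pass):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)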
# --- quick synthetic forward ---
# x = torch.randn(1, cfg.in_channels, 512, 512)
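# Completed synthetic variant (a sketch; assumes the model accepts raw
# pixel_values of shape (B, cfg.in_channels, H, W), bypassing the processor):
# x = torch.randn(1, cfg.in_channels, 512, 512)
# with torch.no_grad():
#     out = model(pixel_values=x, organ_id=torch.tensor([0]))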
from PIL import Image
x = Image.open("/home/nicola/Downloads/45.png").convert("RGB")
inputs = proc(images=x, return_tensors="pt") # {'pixel_values': B,C,H,W}
organ_id = torch.tensor([4]) # any valid organ id < cfg.n_organs
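# The organ id conditions the U-Net (via FiLM modulation, per the model name),
# selecting which organ to segment. To sweep every organ instead, a sketch:
# for oid in range(cfg.n_organs):
#     with torch.no_grad():
#         out = model(**inputs, organ_id=torch.tensor([oid]))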
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)
# Post-process: undo letterbox & resize back to original, with threshold 0.7
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)
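# Manual alternative (a sketch; assumes the model exposes per-pixel logits as
# out.logits with shape (B, 1, H, W) -- check the repo's remote code first.
# Note this skips the letterbox/resize undo that the processor method performs):
# probs = torch.sigmoid(out.logits)
# mask = (probs[0, 0] > 0.7).to(torch.uint8).mul(255).cpu().numpy()
# Image.fromarray(mask).save("tmp_manual.png")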
# Save the first mask (a single input image yields a one-element list)
masks[0].save("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp.png")