# test_load_film_unet2d.py
import os
import torch
from transformers import AutoModel, AutoConfig, AutoImageProcessor

# ✅ point to your local folder (or your HF repo id after pushing)
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")

print("Loading config...")
cfg = AutoConfig.from_pretrained(repo_or_path, trust_remote_code=True)
print(cfg)

print("Loading model and weights...")
proc  = AutoImageProcessor.from_pretrained(repo_or_path, trust_remote_code=True)

model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
model.eval()

# --- quick forward pass on a real image (synthetic alternative commented below) ---
# x = torch.randn(1, cfg.in_channels, 512, 512)
from PIL import Image
x = Image.open("/home/nicola/Downloads/45.png").convert("RGB")
inputs = proc(images=x, return_tensors="pt")  # {'pixel_values': tensor of shape (B, C, H, W)}
organ_id = torch.tensor([4])  # any valid organ id < cfg.n_organs
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)

# Post-process: undo letterbox & resize back to original, with threshold 0.7
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)

# Save the first mask (with a single input image, that's masks[0])
masks[0].save("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp.png")
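
# --- Optional: masks for a few different organ ids ---
# A minimal sketch, not part of the original test: it assumes cfg.n_organs exists
# (as the comment above suggests) and reuses the same forward /
# post_process_semantic_segmentation API as above. The /tmp output paths are
# placeholders.
for oid in range(min(3, cfg.n_organs)):
    with torch.no_grad():
        out_oid = model(**inputs, organ_id=torch.tensor([oid]))
    mask_oid = proc.post_process_semantic_segmentation(out_oid, inputs, threshold=0.7, return_as_pil=True)[0]
    mask_oid.save(f"/tmp/mask_organ_{oid}.png")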