# test_load_film_unet2d.py
"""Smoke-test a local FILMUnet2D transformers repo.

Loads the config, image processor, and model (with remote code trusted),
runs a single RGB image through the model conditioned on an organ id, then
post-processes the logits back to the original image geometry and saves
the resulting mask as a PNG.

All paths/values default to the original hard-coded ones, so running the
script with no arguments reproduces the previous behavior.
"""
import argparse
import os

import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModel

# Defaults mirror the original hard-coded script values.
DEFAULT_REPO = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers"
DEFAULT_IMAGE = "/home/nicola/Downloads/45.png"
DEFAULT_OUTPUT = os.path.join(DEFAULT_REPO, "tmp.png")


def main() -> None:
    """Run one image through the model and save the predicted mask."""
    parser = argparse.ArgumentParser(description="Smoke-test FILMUnet2D loading and inference.")
    parser.add_argument("--repo", default=DEFAULT_REPO,
                        help="Local model folder (or HF repo id after pushing).")
    parser.add_argument("--image", default=DEFAULT_IMAGE, help="Input image path.")
    parser.add_argument("--output", default=DEFAULT_OUTPUT, help="Where to save the mask PNG.")
    parser.add_argument("--organ-id", type=int, default=4,
                        help="Conditioning organ id; must be a valid id < cfg.n_organs.")
    parser.add_argument("--threshold", type=float, default=0.7,
                        help="Binarization threshold for post-processing.")
    args = parser.parse_args()

    repo_or_path = os.path.abspath(args.repo)

    print("Loading config...")
    cfg = AutoConfig.from_pretrained(repo_or_path, trust_remote_code=True)
    print(cfg)

    print("Loading model and weights...")
    proc = AutoImageProcessor.from_pretrained(repo_or_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
    model.eval()

    # Preprocess a real image; proc returns {'pixel_values': (B, C, H, W)}.
    image = Image.open(args.image).convert("RGB")
    inputs = proc(images=image, return_tensors="pt")
    # NOTE(review): organ_id is assumed to be validated against cfg.n_organs
    # inside the model — not checked here.
    organ_id = torch.tensor([args.organ_id])
    with torch.no_grad():
        out = model(**inputs, organ_id=organ_id)

    # Post-process: undo letterbox & resize back to original size.
    masks = proc.post_process_semantic_segmentation(
        out, inputs, threshold=args.threshold, return_as_pil=True
    )
    # A single input image yields a single mask at index 0.
    masks[0].save(args.output)


if __name__ == "__main__":
    main()