import os
import subprocess

# --- Ensure SAM checkpoint is present ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    subprocess.run(
        ["wget", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"],
        check=True,
    )
# -----------------------------------------

import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor

# --- CONFIG ---
SAM_MODEL_TYPE = "vit_h"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLUR_RADIUS = 10
# --------------

# Load SAM once at startup so every click reuses the same predictor
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)


def soft_alpha(mask_uint8, blur_radius=10):
    """Feather a hard 0/255 mask into a float alpha channel in [0, 1]."""
    blurred = cv2.GaussianBlur(mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius)
    return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)


def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
    """Return a purple-tinted overlay on the selected region."""
    # Ensure a binary mask
    mask = (mask_uint8 > 127).astype(np.uint8)
    overlay = img_rgb.copy()

    # Create a purple layer
    purple = np.zeros_like(img_rgb, dtype=np.uint8)
    purple[..., 0] = 180  # R
    purple[..., 1] = 0    # G
    purple[..., 2] = 180  # B

    # Blend only on selected pixels
    sel = mask.astype(bool)
    blended = cv2.addWeighted(overlay[sel], 0.6, purple[sel], 0.4, 0)
    overlay[sel] = blended
    return Image.fromarray(overlay)


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    # Guard: without an image or a click event there is nothing to segment
    if image is None or evt is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    # SAM expects input points as a numpy array [[x, y]]
    x, y = evt.index  # (x, y) from the Gradio click
    input_point = np.array([[x, y]], dtype=np.float32)
    input_label = np.array([1], dtype=np.int32)  # 1 = foreground

    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # If SAM didn't return masks, show the original plus a click marker
    if masks is None or len(masks) == 0:
        overlay = img_rgb.copy()
        cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
        return None, Image.fromarray(overlay)

    # Pick the highest-scoring mask
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(np.uint8) * 255

    # Soft alpha for the RGBA cutout
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)

    # Crop to the mask's bounding box (with a small pad)
    ys, xs = np.where(best_mask == 255)
    if len(xs) == 0 or len(ys) == 0:
        # Empty mask: still return an overlay with the click marker
        overlay = img_rgb.copy()
        cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
        return None, Image.fromarray(overlay)

    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    pad = int(max(img_rgb.shape[:2]) * 0.02)
    x0 = max(0, x0 - pad)
    x1 = min(img_rgb.shape[1] - 1, x1 + pad)
    y0 = max(0, y0 - pad)
    y1 = min(img_rgb.shape[0] - 1, y1 + pad)

    fg_rgb = img_rgb[y0:y1 + 1, x0:x1 + 1]
    fg_alpha = alpha[y0:y1 + 1, x0:x1 + 1]

    # Compose the RGBA cutout
    rgba = np.dstack((fg_rgb, (fg_alpha * 255.0).astype(np.uint8)))
    cutout = Image.fromarray(rgba)

    # Build the purple overlay on the full original image
    overlay_img = make_overlay(img_rgb, best_mask)
    return cutout, overlay_img


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown(
        "### SAM Object Isolation\n"
        "Upload an image (or pick the demo), then click on the object to isolate it. "
        "The last panel shows a purple overlay of the mask."
    )

    inp = gr.Image(type="pil", label="Upload image", interactive=True)
    out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
    out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")

    # Click-to-segment
    inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])

    # Demo example at the bottom that populates the upload image
    gr.Examples(
        examples=["demo.png"],  # ensure demo.png is in your repo
        inputs=inp,
        label="Try with demo image",
    )

demo.launch(share=True)
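
# --- Usage note (a sketch, not part of the app) ---
# Assumed setup for running this script; the script filename ("app.py") is an
# assumption, and package versions are unpinned:
#   pip install gradio numpy pillow opencv-python torch
#   pip install git+https://github.com/facebookresearch/segment-anything.git
#   python app.py
# The ViT-H checkpoint is ~2.4 GB and is downloaded automatically on first run;
# inference falls back to CPU when no CUDA device is available.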