Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
# --- Ensure SAM checkpoint is present ---
# File name of the ViT-H Segment Anything weights. It is also the last path
# component of the download URL, so the URL is derived from it.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/" + SAM_CHECKPOINT
)
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # Download under a temporary name and rename only on success, so a failed
    # or interrupted transfer never leaves a truncated file that the
    # existence check above would accept on the next start.
    _partial_checkpoint = SAM_CHECKPOINT + ".part"
    try:
        # check=True raises CalledProcessError when wget exits non-zero
        # instead of silently continuing without a usable checkpoint.
        subprocess.run(
            ["wget", "-O", _partial_checkpoint, SAM_CHECKPOINT_URL],
            check=True,
        )
        os.replace(_partial_checkpoint, SAM_CHECKPOINT)
    finally:
        # After a successful rename the partial file no longer exists; this
        # only removes leftovers from a failed attempt.
        if os.path.exists(_partial_checkpoint):
            os.remove(_partial_checkpoint)
# ----------------------------------------
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import torch | |
| from segment_anything import sam_model_registry, SamPredictor | |
# --- CONFIG ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"  # weights file, relative to the working directory
SAM_MODEL_TYPE = "vit_h"  # key looked up in sam_model_registry
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
BLUR_RADIUS = 10  # forwarded to soft_alpha() as its blur_radius
# --------------

# The model and predictor are built once at import time and shared by every
# click handler invocation.
_build_sam = sam_model_registry[SAM_MODEL_TYPE]
sam = _build_sam(checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)
def soft_alpha(mask_uint8, blur_radius=10):
    """Turn a hard uint8 mask into a feathered alpha matte in [0.0, 1.0].

    Args:
        mask_uint8: Mask array whose values are scaled by 1/255 (the caller
            in this file passes a 0/255 uint8 mask).
        blur_radius: Gaussian sigma, in pixels, applied on both axes. A
            value of ``0``, ``None`` or anything negative disables the blur
            and returns the mask as a hard-edged matte.

    Returns:
        A float32 array with the same shape as ``mask_uint8`` and values
        clipped to [0.0, 1.0].
    """
    if not blur_radius or blur_radius <= 0:
        # With ksize=(0, 0) OpenCV derives the kernel size from sigma, so a
        # non-positive sigma is rejected; skip the blur instead of crashing.
        feathered = np.asarray(mask_uint8)
    else:
        feathered = cv2.GaussianBlur(
            mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius
        )
    return (feathered.astype(np.float32) / 255.0).clip(0.0, 1.0)
def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
    """Return a purple-tinted overlay on the selected region.

    Args:
        img_rgb: Image array; channels 0, 1, 2 are treated as R, G, B.
            Assumed to be H x W x 3 uint8, as produced by the caller in this
            file -- TODO confirm for other callers.
        mask_uint8: Mask array; pixels with a value above 127 are selected.

    Returns:
        A PIL image built from a copy of ``img_rgb`` in which every selected
        pixel is blended 60% original / 40% purple (180, 0, 180). Unselected
        pixels, and the input array itself, are left untouched.
    """
    selected = mask_uint8 > 127
    overlay = img_rgb.copy()

    # Nothing selected: skip the blend so OpenCV is never handed
    # zero-length arrays.
    if not selected.any():
        return Image.fromarray(overlay)

    # Blend only the selected pixels (an N x channels array) against a tint
    # of the same shape, rather than allocating a full-frame purple layer.
    picked = overlay[selected]
    tint = np.zeros_like(picked)
    tint[..., 0] = 180  # R
    tint[..., 2] = 180  # B (G stays 0)
    overlay[selected] = cv2.addWeighted(picked, 0.6, tint, 0.4, 0)
    return Image.fromarray(overlay)
def _click_marker_overlay(img_rgb, x, y):
    """Return a copy of ``img_rgb`` as a PIL image with a purple ring at (x, y)."""
    marked = img_rgb.copy()
    cv2.circle(marked, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
    return Image.fromarray(marked)


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    """Segment the object under a click and return its cutout and an overlay.

    Args:
        image: The PIL image shown in the input component, or ``None``.
        evt: Gradio select event; ``evt.index`` is unpacked as ``(x, y)``.

    Returns:
        A ``(cutout, overlay)`` pair:

        * ``(None, None)`` when there is no image or no event.
        * ``(None, marker_image)`` when no mask was returned or the best
          mask is empty; ``marker_image`` is the input with a ring drawn at
          the click position.
        * otherwise an RGBA cutout cropped to the padded bounding box of the
          best-scoring mask, and the full image with that mask tinted purple.
    """
    # Guard: nothing to segment without both an image and a click.
    if image is None or evt is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    # NOTE(review): set_image runs on every click, even for an unchanged
    # image; presumably this recomputes the image embedding -- consider
    # caching per image if clicks feel slow.
    predictor.set_image(img_rgb)

    # A single foreground point (label 1) at the clicked position.
    x, y = evt.index
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]], dtype=np.float32),
        point_labels=np.array([1], dtype=np.int32),
        multimask_output=True,
    )
    if masks is None or len(masks) == 0:
        return None, _click_marker_overlay(img_rgb, x, y)

    # Keep the candidate with the highest score, as a 0/255 uint8 mask.
    best_mask = masks[int(np.argmax(scores))].astype(np.uint8) * 255

    ys, xs = np.nonzero(best_mask)
    if xs.size == 0:
        return None, _click_marker_overlay(img_rgb, x, y)

    # Feather the mask edge only once we know there is something to cut out.
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)

    # Bounding box of the mask, padded by 2% of the longer image side and
    # clamped to the image bounds.
    height, width = img_rgb.shape[:2]
    pad = int(max(height, width) * 0.02)
    x0 = max(0, int(xs.min()) - pad)
    x1 = min(width - 1, int(xs.max()) + pad)
    y0 = max(0, int(ys.min()) - pad)
    y1 = min(height - 1, int(ys.max()) + pad)
    rows = slice(y0, y1 + 1)
    cols = slice(x0, x1 + 1)

    # Round rather than truncate: float32 error in alpha * 255 would
    # otherwise drop some values by one level (e.g. 199.99999 -> 199).
    alpha_u8 = np.rint(alpha[rows, cols] * 255.0).astype(np.uint8)
    cutout = Image.fromarray(np.dstack((img_rgb[rows, cols], alpha_u8)))

    return cutout, make_overlay(img_rgb, best_mask)
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown(
        "### SAM Object Isolation\n"
        "Upload an image (or pick the demo), then click on the object to "
        "isolate it. The last panel shows a purple overlay of the mask."
    )

    # Components are created in display order: input, cutout, overlay.
    inp = gr.Image(type="pil", label="Upload image", interactive=True)
    out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
    out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")

    # A click on the input image runs segmentation and fills both outputs.
    inp.select(
        fn=isolate_with_click,
        inputs=[inp],
        outputs=[out_cutout, out_overlay],
    )

    # Example entry that loads demo.png (must exist in the repo) into the
    # input image.
    gr.Examples(
        examples=["demo.png"],
        inputs=inp,
        label="Try with demo image",
    )

demo.launch(share=True)