Spaces:
Sleeping
Sleeping
File size: 4,598 Bytes
3beaad5 04e075f 986380a 04e075f 986380a 04e075f 20cd6c7 04e075f 20cd6c7 04e075f 20cd6c7 04e075f b6780bc 04e075f 20cd6c7 04e075f 20cd6c7 04e075f 20cd6c7 04e075f 20cd6c7 b6780bc 20cd6c7 b6780bc 04e075f aee8f1e 04e075f 776dc39 04e075f b6780bc aee8f1e 20cd6c7 b6780bc 20cd6c7 aee8f1e 20cd6c7 aee8f1e 04e075f ebc73a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import subprocess
# --- Ensure SAM checkpoint is present ---
# The weights are fetched once, on first start, into the working directory.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
_SAM_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # Download under a temporary name and rename only after wget succeeds, so
    # an interrupted transfer never leaves a truncated file that the existence
    # check above would accept on the next start.
    _partial_path = SAM_CHECKPOINT + ".part"
    try:
        # check=True raises CalledProcessError on a non-zero wget exit status
        # instead of silently continuing without a usable checkpoint.
        subprocess.run(
            ["wget", "-O", _partial_path, _SAM_CHECKPOINT_URL],
            check=True,
        )
        os.replace(_partial_path, SAM_CHECKPOINT)
    finally:
        # After a successful rename the partial file no longer exists; after a
        # failure this removes whatever wget managed to write.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
# ----------------------------------------
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
# --- CONFIG ---
SAM_MODEL_TYPE = "vit_h"  # key used to look up the builder in sam_model_registry
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"  # weights file passed to that builder
BLUR_RADIUS = 10  # blur_radius handed to soft_alpha when feathering the cutout
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# --------------

# Build the model and its predictor a single time at import, so every click
# handled later reuses the same loaded weights.
_build_sam = sam_model_registry[SAM_MODEL_TYPE]
sam = _build_sam(checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)
def soft_alpha(mask_uint8, blur_radius=10):
    """Turn a hard mask into a feathered alpha matte with values in [0.0, 1.0].

    Args:
        mask_uint8: Mask array with values on a 0..255 scale (the caller in
            this file passes a uint8 array holding only 0 and 255).
        blur_radius: Sigma, in pixels, of the Gaussian blur applied in both
            directions to soften the mask edge. A value of 0 or less
            disables feathering and the mask is only rescaled.

    Returns:
        A float32 array: the (optionally blurred) mask divided by 255 and
        clipped to the range 0.0..1.0.
    """
    if blur_radius <= 0:
        # With ksize=(0, 0) OpenCV derives the kernel size from sigma, so a
        # non-positive sigma cannot be blurred; treat it as "no feathering".
        feathered = mask_uint8
    else:
        feathered = cv2.GaussianBlur(
            mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius
        )
    return (feathered.astype(np.float32) / 255.0).clip(0.0, 1.0)
def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
    """Return a copy of the image with the selected region tinted purple.

    Args:
        img_rgb: Image array whose last axis holds the R, G, B channels
            (the caller in this file passes a uint8 array).
        mask_uint8: Mask over the image's pixels; values above 127 count as
            selected. Assumed to match the image's height and width.

    Returns:
        A PIL image in which every selected pixel is a 60 % image / 40 %
        purple (180, 0, 180) blend and every other pixel is unchanged.
        The input array is not modified.
    """
    overlay = img_rgb.copy()
    sel = mask_uint8 > 127

    if not sel.any():
        # Nothing is selected: return the untouched copy rather than handing
        # a zero-row array to OpenCV.
        return Image.fromarray(overlay)

    # Boolean indexing copies out just the selected pixels, one row each.
    selected = overlay[sel]
    # Build the purple layer only for those pixels instead of the full frame.
    tint = np.empty_like(selected)
    tint[:] = (180, 0, 180)  # R, G, B

    overlay[sel] = cv2.addWeighted(selected, 0.6, tint, 0.4, 0)
    return Image.fromarray(overlay)
def _click_marker_overlay(img_rgb, x, y):
    """Return the image as a PIL image with a purple ring drawn at (x, y)."""
    marked = img_rgb.copy()
    cv2.circle(marked, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
    return Image.fromarray(marked)


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    """Segment the object under a click and return its cutout and an overlay.

    Args:
        image: The image shown in the input component, or None.
        evt: The Gradio select event; ``evt.index`` is read as the (x, y)
            position of the click.

    Returns:
        A ``(cutout, overlay)`` pair:

        * ``(None, None)`` when there is no image or no event.
        * ``(None, marker)`` when SAM returns no mask or the chosen mask is
          empty; ``marker`` is the image with a ring drawn at the click.
        * Otherwise ``cutout`` is an RGBA crop of the object with a feathered
          alpha channel, and ``overlay`` is the full image with the mask
          tinted purple.
    """
    # Nothing to segment: clear both outputs.
    if image is None or evt is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    # A single positive prompt point at the clicked pixel (label 1 = foreground).
    x, y = evt.index
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]], dtype=np.float32),
        point_labels=np.array([1], dtype=np.int32),
        multimask_output=True,
    )
    if masks is None or len(masks) == 0:
        return None, _click_marker_overlay(img_rgb, x, y)

    # Keep the highest-scoring candidate as a 0/255 uint8 mask.
    best_mask = masks[int(np.argmax(scores))].astype(np.uint8) * 255

    # Check for an empty mask before blurring, since the blur would be
    # thrown away in that case.
    ys, xs = np.where(best_mask == 255)
    if xs.size == 0:
        return None, _click_marker_overlay(img_rgb, x, y)

    # Soft alpha for the RGBA cutout.
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)

    # Bounding box of the mask, grown by 2 % of the longer image side and
    # clamped to the image bounds.
    height, width = img_rgb.shape[:2]
    pad = int(max(height, width) * 0.02)
    x0 = max(0, int(xs.min()) - pad)
    x1 = min(width - 1, int(xs.max()) + pad)
    y0 = max(0, int(ys.min()) - pad)
    y1 = min(height - 1, int(ys.max()) + pad)

    fg_rgb = img_rgb[y0:y1 + 1, x0:x1 + 1]
    fg_alpha = alpha[y0:y1 + 1, x0:x1 + 1]

    # Stack the colour crop with its alpha channel rescaled to 0..255.
    rgba = np.dstack((fg_rgb, (fg_alpha * 255.0).astype(np.uint8)))
    cutout = Image.fromarray(rgba)

    # Purple overlay of the mask on the full original image.
    overlay_img = make_overlay(img_rgb, best_mask)
    return cutout, overlay_img
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown(
        "### SAM Object Isolation\n"
        "Upload an image (or pick the demo), then click on the object to isolate it. "
        "The last panel shows a purple overlay of the mask."
    )

    # One input panel followed by the two result panels.
    inp = gr.Image(label="Upload image", type="pil", interactive=True)
    out_cutout = gr.Image(label="Isolated cutout (RGBA)", type="pil")
    out_overlay = gr.Image(label="Segmentation overlay preview", type="pil")

    # A click on the input image runs the segmentation and fills both results.
    inp.select(
        fn=isolate_with_click,
        inputs=[inp],
        outputs=[out_cutout, out_overlay],
    )

    # Example picker that loads the bundled image into the input panel;
    # demo.png must be present in the repository.
    gr.Examples(
        label="Try with demo image",
        inputs=inp,
        examples=["demo.png"],
    )

demo.launch(share=True)
|