import os
import subprocess

# --- Ensure the SAM checkpoint is present before the model is loaded ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    subprocess.run(
        [
            "wget",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        ],
        check=True,  # fail loudly instead of continuing with a missing file
    )
# ------------------------------------------------------------------------
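# Note: wget is not guaranteed to exist on every host. A dependency-free
# fallback (a sketch, not part of the original app) is the standard library:
#
#   import urllib.request
#   urllib.request.urlretrieve(
#       "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
#       SAM_CHECKPOINT,
#   )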
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor

# --- CONFIG ---
# SAM_CHECKPOINT is defined above, next to the download step
SAM_MODEL_TYPE = "vit_h"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLUR_RADIUS = 10  # feathering strength for the cutout's alpha edge
# --------------
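# If memory is tight, the smaller official SAM checkpoints can be swapped in
# (change both the checkpoint file and the model type):
#   "vit_l" -> sam_vit_l_0b3195.pth, "vit_b" -> sam_vit_b_01ec64.pth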

# Load SAM once
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)

def soft_alpha(mask_uint8, blur_radius=10):
    """Feather a hard 0/255 mask into a float32 alpha matte in [0, 1]."""
    blurred = cv2.GaussianBlur(mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius)
    return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
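
# Illustrative check of the feathering (hypothetical values): a hard square
# mask yields fractional alpha near its boundary, so cutout edges fade
# instead of aliasing:
#
#   hard = np.zeros((64, 64), np.uint8)
#   hard[16:48, 16:48] = 255
#   alpha = soft_alpha(hard)          # float32 in [0.0, 1.0]
#   assert 0.0 < alpha[16, 31] < 1.0  # softened along the top edge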

def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
    """Return a purple-tinted overlay on the selected region."""
    # Binarize the incoming 0/255 mask to 0/1
    mask = (mask_uint8 > 127).astype(np.uint8)
    overlay = img_rgb.copy()

    # Create a purple layer
    purple = np.zeros_like(img_rgb, dtype=np.uint8)
    purple[..., 0] = 180  # R
    purple[..., 1] = 0    # G
    purple[..., 2] = 180  # B

    # Blend only on selected pixels; guard against an empty selection,
    # which would hand OpenCV zero-length arrays
    sel = mask.astype(bool)
    if sel.any():
        overlay[sel] = cv2.addWeighted(overlay[sel], 0.6, purple[sel], 0.4, 0)

    return Image.fromarray(overlay)
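
# An equivalent fully-vectorized blend (a sketch; same 60/40 weights) that
# avoids routing fancy-indexed (N, 3) arrays through OpenCV:
#
#   blended = (0.6 * img_rgb + 0.4 * purple).astype(np.uint8)
#   overlay = np.where(mask[..., None].astype(bool), blended, img_rgb)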

def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    # Guard: without an image or a click event there is nothing to segment,
    # so clear both output panels
    if image is None or evt is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    # SAM expects input points as numpy array [[x,y]]
    x, y = evt.index  # (x, y) from Gradio click
    input_point = np.array([[x, y]], dtype=np.float32)
    input_label = np.array([1], dtype=np.int32)  # 1 = foreground

    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True
    )
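    # With multimask_output=True, SAM proposes three candidate masks
    # (a bool array of shape (3, H, W)) plus a predicted-IoU score for each,
    # so an ambiguous click (part vs. whole object) keeps every option open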

    # If SAM returned no masks, clear the cutout and mark the clicked point
    if masks is None or len(masks) == 0:
        # Draw a purple circle at the click location on the original image
        overlay = img_rgb.copy()
        cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
        return None, Image.fromarray(overlay)

    # Pick the highest-scoring mask
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(np.uint8) * 255

    # Soft alpha for RGBA cutout
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)

    # Crop to bounding box (with small pad)
    ys, xs = np.where(best_mask == 255)
    if len(xs) == 0 or len(ys) == 0:
        # Empty mask: clear the cutout and mark the clicked point instead
        overlay = img_rgb.copy()
        cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
        return None, Image.fromarray(overlay)

    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    pad = int(max(img_rgb.shape[:2]) * 0.02)
    x0 = max(0, x0 - pad)
    x1 = min(img_rgb.shape[1] - 1, x1 + pad)
    y0 = max(0, y0 - pad)
    y1 = min(img_rgb.shape[0] - 1, y1 + pad)

    fg_rgb = img_rgb[y0:y1+1, x0:x1+1]
    fg_alpha = alpha[y0:y1+1, x0:x1+1]

    # Compose the RGBA cutout: cropped RGB channels plus the feathered alpha
    rgba = np.dstack((fg_rgb, (fg_alpha * 255.0).astype(np.uint8)))
    cutout = Image.fromarray(rgba)

    # Build and return the purple overlay on the original image
    overlay_img = make_overlay(img_rgb, best_mask)

    return cutout, overlay_img

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("### SAM Object Isolation\nUpload an image (or pick the demo), then click on the object to isolate it. The last panel shows a purple overlay of the mask.")
    inp = gr.Image(type="pil", label="Upload image", interactive=True)
    out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
    out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")

    # Click-to-segment
    inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])

    # Demo example at the bottom; clicking it populates the upload slot
    gr.Examples(
        examples=["demo.png"],   # ensure demo.png is in your repo
        inputs=inp,
        label="Try with demo image"
    )

demo.launch(share=True)