# Scraped page header (not part of the program):
# author: AItool; commit message: "Update app.py"; commit 776dc39 (verified)
import os
import subprocess
# --- Ensure SAM checkpoint is present ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # Download under a temporary name and rename only after wget succeeds, so
    # an interrupted transfer never leaves a truncated file that would pass
    # the existence check above on the next start.
    _partial_checkpoint = SAM_CHECKPOINT + ".part"
    # check=True makes a failed download raise CalledProcessError right here
    # instead of surfacing later as a confusing checkpoint-loading error.
    subprocess.run(
        [
            "wget",
            "-O",
            _partial_checkpoint,
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        ],
        check=True,
    )
    os.replace(_partial_checkpoint, SAM_CHECKPOINT)
# ----------------------------------------
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
# --- CONFIG ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"  # relative path to the weights file
SAM_MODEL_TYPE = "vit_h"  # key looked up in sam_model_registry below
BLUR_RADIUS = 10  # passed to soft_alpha as the blur radius for the cutout edge
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# --------------

# Build the model and its predictor a single time at import; the click
# handler reuses this module-level predictor on every call.
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)
def soft_alpha(mask_uint8, blur_radius=10):
    """Turn a mask on a 0..255 scale into a feathered alpha matte in [0, 1].

    Args:
        mask_uint8: Mask array; the caller in this file passes a 2-D uint8
            array holding only 0 and 255.
        blur_radius: Gaussian sigma, applied on both axes. A value of 0 or
            less disables feathering and returns the hard mask rescaled to
            [0, 1].

    Returns:
        A float32 array with the mask divided by 255 and clipped to [0, 1].
    """
    if blur_radius > 0:
        # ksize=(0, 0) asks OpenCV to derive the kernel size from sigma.
        blurred = cv2.GaussianBlur(
            mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius
        )
    else:
        # With ksize=(0, 0) OpenCV has no kernel size to fall back on when
        # sigma is not positive, so skip the blur instead of calling it.
        blurred = np.asarray(mask_uint8)
    return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
    """Return a purple-tinted overlay on the selected region.

    Args:
        img_rgb: Image array; the caller in this file passes an
            H x W x 3 uint8 RGB array. It is copied, not modified.
        mask_uint8: Mask on a 0..255 scale; pixels with a value above 127
            count as selected.

    Returns:
        A PIL image of the copy, with every selected pixel blended as
        60% original colour and 40% purple (180, 0, 180). If no pixel is
        selected the copy is returned untinted.
    """
    overlay = img_rgb.copy()
    # Threshold straight to a boolean selection.
    sel = mask_uint8 > 127
    if not sel.any():
        # Nothing to tint; also avoids handing OpenCV a zero-row array.
        return Image.fromarray(overlay)
    selected = overlay[sel]
    # Build the purple layer only for the selected pixels rather than for
    # the whole image.
    purple = np.zeros_like(selected)
    purple[..., 0] = 180  # R
    purple[..., 2] = 180  # B (G stays 0)
    overlay[sel] = cv2.addWeighted(selected, 0.6, purple, 0.4, 0)
    return Image.fromarray(overlay)
def _click_marker_overlay(img_rgb, x, y):
    """Return a copy of img_rgb as a PIL image with a purple ring at (x, y)."""
    marked = img_rgb.copy()
    cv2.circle(marked, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
    return Image.fromarray(marked)


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    """Segment the clicked object with SAM and return a cutout and a preview.

    Args:
        image: The PIL image currently shown in the input component, or None.
        evt: Gradio select event; ``evt.index`` is read as the (x, y) click
            position. Gradio supplies this argument based on the
            ``gr.SelectData`` annotation, so the annotation must stay.

    Returns:
        A ``(cutout, overlay)`` tuple:
        - ``(None, None)`` when there is no image or no event;
        - ``(None, marker)`` when SAM returns no mask or an empty one, where
          ``marker`` is the original image with a ring drawn at the click;
        - otherwise an RGBA cutout cropped to the padded mask bounding box
          with a feathered alpha channel, and the full image with the mask
          tinted purple.
    """
    if image is None or evt is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    # SAM takes prompt points as an (N, 2) array of (x, y) with one label each.
    x, y = evt.index
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]], dtype=np.float32),
        point_labels=np.array([1], dtype=np.int32),  # 1 = foreground
        multimask_output=True,
    )
    if masks is None or len(masks) == 0:
        return None, _click_marker_overlay(img_rgb, x, y)

    # Keep the candidate SAM scored highest, as a 0/255 uint8 mask.
    best_mask = masks[int(np.argmax(scores))].astype(np.uint8) * 255

    ys, xs = np.nonzero(best_mask == 255)
    if xs.size == 0:
        # Empty mask: bail out before paying for the blur.
        return None, _click_marker_overlay(img_rgb, x, y)

    # Bounding box of the mask, grown by 2% of the longer image side and
    # clamped to the image.
    height, width = img_rgb.shape[:2]
    pad = int(max(height, width) * 0.02)
    x0 = max(0, int(xs.min()) - pad)
    x1 = min(width - 1, int(xs.max()) + pad)
    y0 = max(0, int(ys.min()) - pad)
    y1 = min(height - 1, int(ys.max()) + pad)

    # Feather the mask on the full image, then crop, so the blur is not cut
    # off by the crop edge.
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
    fg_rgb = img_rgb[y0:y1 + 1, x0:x1 + 1]
    fg_alpha = alpha[y0:y1 + 1, x0:x1 + 1]

    # Round to nearest rather than truncate: after the float32 divide and
    # multiply a value can land just below its integer and lose one level.
    # alpha is clipped to [0, 1], so the +0.5 cannot overflow uint8.
    alpha_u8 = (fg_alpha * 255.0 + 0.5).astype(np.uint8)
    cutout = Image.fromarray(np.dstack((fg_rgb, alpha_u8)))

    return cutout, make_overlay(img_rgb, best_mask)
# --- Gradio UI ---
# Layout: one input image, then the RGBA cutout, then the overlay preview.
with gr.Blocks() as demo:
    gr.Markdown("### SAM Object Isolation\nUpload an image (or pick the demo), then click on the object to isolate it. The last panel shows a purple overlay of the mask.")
    inp = gr.Image(type="pil", label="Upload image", interactive=True)
    out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
    out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")
    # Click-to-segment: a click on the input image runs isolate_with_click
    # and fills both output panels.
    inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])
    # Demo example at the bottom that populates the upload image
    gr.Examples(
        examples=["demo.png"],  # ensure demo.png is in your repo
        inputs=inp,
        label="Try with demo image",
    )

# Start the server only when this file is run as a script, so importing the
# module (for tooling or reuse of `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch(share=True)