# ConvSeg / app-dev.py
# aadarsh99's picture
# release
# c9aa521
import os
import logging
import hashlib
import sys
import traceback
import copy
import tempfile
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image, ImageFilter, ImageChops, ImageDraw
from huggingface_hub import hf_hub_download
import spaces
# --- IMPORT YOUR CUSTOM MODULES ---
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
# ----------------- Configuration -----------------
# Config name and base checkpoint filename handed to build_sam2 when a stage
# is loaded.
SAM2_CONFIG = "sam2_hiera_l.yaml"
BASE_CKPT_NAME = "sam2_hiera_large.pt"

# Side length (pixels) of the padded square image given to the predictor.
SQUARE_DIM = 1024

logging.basicConfig(level=logging.INFO)

# One entry per selectable model stage: the Hugging Face repo plus the
# fine-tuned SAM2 and PLM checkpoint filenames inside that repo.
# NOTE(review): both "Stage 2" entries currently point at the same repo and
# the same two files -- confirm whether the grad-acc variants are meant to
# use different checkpoints.
MODEL_CONFIGS = {
    "Stage 1": dict(
        repo_id="aadarsh99/ConvSeg-Stage1",
        sam_filename="fine_tuned_sam2_batched_100000.torch",
        plm_filename="fine_tuned_sam2_batched_plm_100000.torch",
    ),
    "Stage 2 (grad-acc: 4)": dict(
        repo_id="aadarsh99/ConvSeg-Stage2",
        sam_filename="fine_tuned_sam2_batched_18000.torch",
        plm_filename="fine_tuned_sam2_batched_plm_18000.torch",
    ),
    "Stage 2 (grad-acc: 8)": dict(
        repo_id="aadarsh99/ConvSeg-Stage2",
        sam_filename="fine_tuned_sam2_batched_18000.torch",
        plm_filename="fine_tuned_sam2_batched_plm_18000.torch",
    ),
}

# Per-stage slots for the loaded models; each stage gets its own dict so
# filling one never touches another. Slots stay None until first use.
MODEL_CACHE = {stage: {"sam": None, "plm": None} for stage in MODEL_CONFIGS}
# ----------------- Helper Functions -----------------
def download_if_needed(repo_id, filename):
    """Resolve *filename* from the Hugging Face repo *repo_id* to a local path.

    Delegates to ``hf_hub_download`` and returns whatever it returns.

    Args:
        repo_id: Repository identifier passed through to ``hf_hub_download``.
        filename: Name of the file to fetch from that repository.

    Raises:
        FileNotFoundError: If the download call fails for any reason. The
            underlying exception is attached as ``__cause__`` so the real
            failure (network, auth, missing file, ...) stays visible in the
            traceback.
    """
    logging.info(f"Checking {filename} in {repo_id}...")
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Chain explicitly: callers still catch FileNotFoundError, but the
        # original error is no longer reduced to a string in the message.
        raise FileNotFoundError(
            f"Could not find {filename} in {repo_id}. Error: {e}"
        ) from e
def stable_color(key: str):
    """Pick an RGB tuple for *key* that is the same on every call.

    The key is stringified, hashed with SHA-256, and the digest (read as one
    big integer) selects an entry from a fixed six-colour palette.

    Returns:
        A ``(r, g, b)`` tuple of ints parsed from the chosen hex colour.
    """
    palette_hex = ("#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D")
    digest = hashlib.sha256(str(key).encode("utf-8")).digest()
    # Big-endian interpretation of the raw digest equals int(hexdigest, 16).
    slot = int.from_bytes(digest, "big") % len(palette_hex)
    rgb_hex = palette_hex[slot].lstrip("#")
    return tuple(int(rgb_hex[pos:pos + 2], 16) for pos in range(0, 6, 2))
def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
    """Draw *mask* on top of *rgb* as a tinted fill with a solid outline.

    Args:
        rgb: Image array; cast to uint8 before being turned into a PIL image.
        mask: Array whose positive entries mark the region to highlight.
        key: Value handed to ``stable_color`` to choose the tint.

    Returns:
        An RGB PIL image: the base picture, then the fill layer, then the
        outline layer, alpha-composited in that order.
    """
    canvas = Image.fromarray(rgb.astype(np.uint8)).convert("RGBA")
    binary = (mask > 0).astype(np.uint8)
    tint = stable_color(key)

    # Interior: tint at alpha 140 inside the mask, alpha 0 everywhere else.
    interior = Image.new("RGBA", canvas.size, tint + (0,))
    interior.putalpha(Image.fromarray(binary * 140, "L"))

    # Outline: pixels where the 3x3 max- and min-filtered masks disagree.
    solid = Image.fromarray(binary * 255, "L")
    grown = solid.filter(ImageFilter.MaxFilter(3))
    shrunk = solid.filter(ImageFilter.MinFilter(3))
    outline = Image.new("RGBA", canvas.size, tint + (255,))
    outline.putalpha(ImageChops.difference(grown, shrunk))

    blended = Image.alpha_composite(canvas, interior)
    blended = Image.alpha_composite(blended, outline)
    return blended.convert("RGB")
def ensure_models_loaded(stage_key):
    """Populate ``MODEL_CACHE[stage_key]`` with CPU-resident models if empty.

    Does nothing when the stage's "sam" slot is already filled. Otherwise it
    downloads the base SAM2 checkpoint, the stage's fine-tuned SAM2 weights
    and the stage's PLM weights from the stage's repo, builds both models on
    the CPU, and stores them in the cache.

    Args:
        stage_key: Key present in both ``MODEL_CACHE`` and ``MODEL_CONFIGS``.
    """
    if MODEL_CACHE[stage_key]["sam"] is not None:
        return

    stage_cfg = MODEL_CONFIGS[stage_key]
    repo = stage_cfg["repo_id"]
    logging.info(f"Loading {stage_key} models from {repo} into CPU RAM...")

    # --- SAM2: build from the shared base checkpoint, then overwrite with
    # the stage-specific fine-tuned weights.
    # NOTE(review): torch.load unpickles the checkpoint; the files come from
    # the repos listed in MODEL_CONFIGS -- keep those trusted.
    base_ckpt = download_if_needed(repo, BASE_CKPT_NAME)
    sam = build_sam2(SAM2_CONFIG, base_ckpt, device="cpu")
    tuned_ckpt = download_if_needed(repo, stage_cfg["sam_filename"])
    sam_state = torch.load(tuned_ckpt, map_location="cpu")
    # Accept either a {"model": state_dict} wrapper or a bare state dict.
    sam.load_state_dict(sam_state.get("model", sam_state), strict=True)

    # --- PLM language adapter, sized to the mask decoder's transformer dim.
    adapter_ckpt = download_if_needed(repo, stage_cfg["plm_filename"])
    adapter = PLMLanguageAdapter(
        model_name="Qwen/Qwen2.5-VL-3B-Instruct",
        transformer_dim=sam.sam_mask_decoder.transformer_dim,
        n_sparse_tokens=0,
        use_dense_bias=True,
        use_lora=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        dtype=torch.bfloat16,
        device="cpu",
    )
    adapter_state = torch.load(adapter_ckpt, map_location="cpu")
    adapter.load_state_dict(adapter_state["plm"], strict=True)
    adapter.eval()

    # Fill both slots only after everything above succeeded.
    MODEL_CACHE[stage_key].update(sam=sam, plm=adapter)
# ----------------- GPU Inference -----------------
@spaces.GPU(duration=120)
def run_prediction(image_pil, text_prompt, threshold, stage_choice):
    """Run text-prompted segmentation on *image_pil* with the chosen stage.

    The image is resized to fit a SQUARE_DIM x SQUARE_DIM canvas, padded with
    zeros, embedded with the SAM2 predictor, and decoded with prompt
    embeddings produced by the PLM adapter. The best-scoring mask is cropped
    back to the un-padded region and resized to the original resolution.

    Args:
        image_pil: PIL image from the UI, or None.
        text_prompt: Prompt string passed to the PLM adapter.
        threshold: Probability cut-off used for the initial overlay mask.
        stage_choice: Key into MODEL_CONFIGS / MODEL_CACHE.

    Returns:
        ``(overlay, heatmap, prob)`` where ``overlay`` and ``heatmap`` are
        PIL images and ``prob`` is the per-pixel probability array at the
        original resolution; ``(None, None, None)`` when an input is missing
        or inference raised (the error is logged with its traceback).
    """
    if image_pil is None or not text_prompt:
        return None, None, None

    # Loading errors are deliberately left to propagate to the caller.
    ensure_models_loaded(stage_choice)
    sam_model = MODEL_CACHE[stage_choice]["sam"]
    plm_model = MODEL_CACHE[stage_choice]["plm"]

    try:
        # Transfers live inside the try so the finally clause moves both
        # models back to the CPU even if the second transfer fails.
        sam_model.to("cuda")
        plm_model.to("cuda")

        with torch.inference_mode():
            predictor = SAM2ImagePredictor(sam_model)
            image_rgb = image_pil.convert("RGB")
            rgb_orig = np.array(image_rgb)
            H, W = rgb_orig.shape[:2]

            # Padding math. Clamp to >= 1 px: with an extreme aspect ratio
            # the short side would truncate to 0 and cv2.resize would fail.
            scale = SQUARE_DIM / max(H, W)
            nw = max(1, int(W * scale))
            nh = max(1, int(H * scale))
            top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2

            # Resize, then centre on a zero-filled square canvas.
            rgb_sq = cv2.resize(rgb_orig, (nw, nh), interpolation=cv2.INTER_LINEAR)
            rgb_sq = cv2.copyMakeBorder(
                rgb_sq,
                top, SQUARE_DIM - nh - top,
                left, SQUARE_DIM - nw - left,
                cv2.BORDER_CONSTANT, value=0,
            )

            predictor.set_image(rgb_sq)
            image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
            hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

            # PLM adapter: it takes image *paths*, so write a temp JPEG.
            # Save the RGB-converted copy -- JPEG cannot store modes such
            # as RGBA or P, and saving the raw upload would raise.
            with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
                image_rgb.save(tmp.name)
                sp, dp = plm_model(
                    [text_prompt], image_emb.shape[2], image_emb.shape[3], [tmp.name]
                )

            # SAM2 decoding, with every input cast to the decoder's own
            # device and dtype.
            dec = sam_model.sam_mask_decoder
            dec_param = next(dec.parameters())
            dev, dtype = dec_param.device, dec_param.dtype
            low, scores, _, _ = dec(
                image_embeddings=image_emb.to(dev, dtype),
                image_pe=sam_model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
                sparse_prompt_embeddings=sp.to(dev, dtype),
                dense_prompt_embeddings=dp.to(dev, dtype),
                multimask_output=True,
                repeat_image=False,
                high_res_features=[h.to(dev, dtype) for h in hi],
            )

            # Postprocess to the padded square, keep the best-scoring mask,
            # crop away the padding and resize to the original dimensions.
            logits = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
            best_idx = scores.argmax().item()
            logit_crop = logits[0, best_idx, top:top + nh, left:left + nw].unsqueeze(0).unsqueeze(0)
            logit_full = F.interpolate(
                logit_crop, size=(H, W), mode="bilinear", align_corners=False
            )[0, 0]
            prob = torch.sigmoid(logit_full).float().cpu().numpy()

            # Heatmap (OpenCV colour maps are BGR; convert for PIL).
            heatmap_cv = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
            heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)

            # Initial overlay at the requested threshold.
            mask = (prob > threshold).astype(np.uint8) * 255
            overlay = make_overlay(rgb_orig, mask, key=text_prompt)

            return overlay, Image.fromarray(heatmap_rgb), prob
    except Exception:
        # Best-effort UI handler: log with traceback and clear the outputs
        # instead of crashing the request.
        logging.exception("Inference failed for stage %r", stage_choice)
        return None, None, None
    finally:
        # Always hand the GPU back: offload both models and drop cached blocks.
        sam_model.to("cpu")
        plm_model.to("cpu")
        torch.cuda.empty_cache()
def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
    """Re-threshold the cached probability map and redraw the overlay (CPU only).

    Args:
        image_pil: PIL image currently in the input component, or None.
        text_prompt: Prompt string; used as the colour key for the overlay.
        threshold: Probability cut-off for the mask.
        cached_prob: Probability array stored by the last inference run, or
            None if nothing has been run yet.

    Returns:
        The overlay image, or None when there is no image, no cached map, or
        the cached map's height/width differ from the current image.
    """
    if image_pil is None or cached_prob is None:
        return None

    rgb_orig = np.array(image_pil.convert("RGB"))

    # The cached map belongs to whichever image was last run through
    # inference. If the input image was swapped since then the sizes can
    # differ, and compositing the mask would raise -- treat it as stale.
    if np.shape(cached_prob)[:2] != rgb_orig.shape[:2]:
        return None

    mask = (cached_prob > threshold).astype(np.uint8) * 255
    return make_overlay(rgb_orig, mask, key=text_prompt)
# ----------------- Gradio UI -----------------
# NOTE(review): the nesting below (slider sharing a row with the stage radio,
# button inside the left column) is reconstructed -- confirm against the
# intended layout.
with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
    # Holds the probability map returned by the last inference run so the
    # slider can re-threshold without calling the model again.
    prob_state = gr.State()

    gr.Markdown("# SAM2 + PLM Interactive Segmentation")
    gr.Markdown("Select a stage, enter a prompt, and run. Adjust the slider for **instant** mask updates.")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
            with gr.Row():
                stage_names = list(MODEL_CONFIGS)
                stage_select = gr.Radio(
                    choices=stage_names,
                    value="Stage 2 (grad-acc: 8)",
                    label="Model Stage",
                )
                threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
            run_btn = gr.Button("Run Inference", variant="primary")

        # Right column: results.
        with gr.Column():
            out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
            out_heatmap = gr.Image(label="Probability Heatmap", type="pil")

    # Button: full pipeline; also refreshes the cached probability map.
    inference_inputs = [input_image, text_prompt, threshold_slider, stage_select]
    inference_outputs = [out_overlay, out_heatmap, prob_state]
    run_btn.click(fn=run_prediction, inputs=inference_inputs, outputs=inference_outputs)

    # Slider: lightweight redraw from the cached probability map.
    redraw_inputs = [input_image, text_prompt, threshold_slider, prob_state]
    threshold_slider.change(fn=update_threshold_ui, inputs=redraw_inputs, outputs=[out_overlay])

if __name__ == "__main__":
    demo.launch()