import os

import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch

from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
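# SegmentPredictor and DepthPredictor come from this Space's local inference.py,
# so their exact APIs are inferred from how they are used below: the default
# (GPU, if available) SAM instance runs the heavy image encoder, while a second
# CPU instance re-runs only the lightweight mask decoder on each point click.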
sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()

red = (255, 0, 0)
blue = (0, 0, 255)
annos = []
block = gr.Blocks()
with block:
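    # Per-session state: prompt points, accumulated (mask, name) pairs, and the
    # cached image embedding. The empty-list factories appear to be passed as
    # gr.State values so each session starts from a fresh list instead of
    # sharing one mutable default across users.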
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
    embedding = gr.State()
    # UI
    with gr.Column():
        gr.Markdown(
            """# Segment Anything Model (SAM)
## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click 🚀
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Upload Image"):
                    # mirror_webcam = False
                    upload_image = gr.Image(label="Input", type="pil", tool=None)
                with gr.Tab("Webcam"):
                    # mirror_webcam = False
                    input_image = gr.Image(
                        label="Input", type="pil", tool=None, source="webcam"
                    )
                with gr.Row():
                    sam_encode_btn = gr.Button("Encode", variant="primary")
                    sam_sgmt_everything_btn = gr.Button(
                        "Segment Everything!", variant="primary"
                    )
                # sam_encode_status = gr.Label('Not encoded yet')
            with gr.Row():
                prompt_image = gr.Image(label="Segments")
                # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                lbl_image = gr.AnnotatedImage(label="Everything")
        with gr.Row():
            point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
            text = gr.Textbox(label="Mask Name")
            reset_btn = gr.Button("New Mask")
        selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
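        # 3-D reconstruction panel: a Model3D viewer plus sliders bounding the
        # depth range, the number of sampled points, and the size of the cube
        # rendered at each sampled point.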
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(
                    label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
                )
                with gr.Row():
                    max_depth = gr.Slider(
                        minimum=0, maximum=10, value=3, step=0.01, label="Max Depth"
                    )
                    min_depth = gr.Slider(
                        minimum=0, maximum=10, step=0.01, value=1, label="Min Depth"
                    )
                    n_samples = gr.Slider(
                        minimum=1e3,
                        maximum=1e6,
                        step=1e3,
                        value=1e5,
                        label="Number of Samples",
                    )
                    cube_size = gr.Slider(
                        minimum=0.00001,
                        maximum=0.001,
                        step=0.000001,
                        value=0.00001,  # gr.Slider takes `value`, not `default`
                        label="Cube size",
                    )
                depth_reconstruction_btn = gr.Button(
                    "3D Reconstruction", variant="primary"
                )
                depth_reconstruction_mask_btn = gr.Button(
                    "Mask Reconstruction", variant="primary"
                )
        sam_decode_btn = gr.Button("Predict using points!", variant="primary")
    # components
    components = {
        point_coords,
        point_labels,
        image_edit_trigger,
        masks,
        cutout_idx,
        input_image,
        embedding,
        point_label_radio,
        text,
        reset_btn,
        sam_sgmt_everything_btn,
        sam_decode_btn,
        depth_reconstruction_btn,
        prompt_image,
        lbl_image,
        n_samples,
        max_depth,
        min_depth,
        cube_size,
        selected_masks_image,
    }
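    # Because `components` is a *set*, Gradio calls each handler with a single
    # dict keyed by component (e.g. inputs[input_image]) rather than one
    # positional argument per input. The button handlers below all rely on this
    # dict-style access.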
    def on_upload_image(input_image, upload_image):
        # Mirror the upload so it matches the webcam path, whose gr.Image
        # mirrors captures by default (mirror_webcam=True).
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]

    upload_image.upload(
        on_upload_image, [input_image, upload_image], [input_image, upload_image]
    )
    # event - init coords
    def on_reset_btn_click(input_image):
        # "New Mask": clear the point prompts and restore the clean input image.
        return input_image, point_coords_empty(), point_labels_empty()

    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [prompt_image, point_coords, point_labels],
        queue=False,
    )
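    # Interactive point prompting: each click on the "Segments" image appends a
    # coordinate plus a foreground (1, blue) or background (0, red) label, draws
    # the dot, and re-runs the CPU mask decoder against the cached embedding, so
    # the masks refine without re-encoding the image.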
    def on_prompt_image_select(
        input_image,
        prompt_image,
        point_coords,
        point_labels,
        point_label_radio,
        text,
        pred_masks,
        embedding,
        evt: gr.SelectData,
    ):
        sam_cpu.dummy_encode(input_image)
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam_cpu.cond_pred(
            pts=np.array(point_coords),
            lbls=np.array(point_labels),
            embedding=embedding,
        )
        return [
            prompt_image,
            (input_image, sam_masks),
            point_coords,
            point_labels,
            sam_masks,
        ]

    prompt_image.select(
        on_prompt_image_select,
        [
            input_image,
            prompt_image,
            point_coords,
            point_labels,
            point_label_radio,
            text,
            pred_masks,
            embedding,
        ],
        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
        queue=True,
    )
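    # Selecting a segment in the "Everything" view copies that mask into the
    # user's selection, tagged with the current "Mask Name" textbox value.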
    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        i = evt.index
        mask = pred_masks[i][0]
        print(mask)
        print(type(mask))
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(
        on_everything_image_select,
        [input_image, pred_masks, masks, text],
        [masks, selected_masks_image],
        queue=False,
    )
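    # Selecting a mask in "Selected Masks" removes it: the handler deletes the
    # i-th (mask, name) pair and re-renders the annotated image.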
    def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
        i = evt.index
        del masks[i]
        anno = (input_image, masks)
        return [masks, anno]

    selected_masks_image.select(
        on_selected_masks_image_select,
        [input_image, masks],
        [masks, selected_masks_image],
        queue=False,
    )

    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)
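    # Encoding runs once per image on the GPU instance; the embedding is moved
    # to CPU so it can live in gr.State and be reused by sam_cpu.cond_pred on
    # every later point click.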
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]

    sam_encode_btn.click(
        on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
    )
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        # (image, [(mask, name), ...]) is the gr.AnnotatedImage value format, so
        # the result is shown in the "Selected Masks" view (a plain gr.Image
        # cannot display it).
        return {
            selected_masks_image: (image, inputs[masks]),
            masks: inputs[masks],
        }

    sam_decode_btn.click(
        on_click_sam_decode_btn,
        components,
        [selected_masks_image, masks],
        queue=True,
    )
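    # Depth-based reconstruction: DepthPredictor presumably back-projects the
    # predicted depth map into a point cloud and writes a 3-D file whose path
    # feeds gr.Model3D. generate_obj_rgb colors points from the RGB image;
    # generate_obj_masks2 additionally uses the selected masks.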
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_rgb(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            # masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_btn.click(
        on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
    )
    def on_depth_reconstruction_mask_btn_click(inputs):
        print("mask reconstruction")
        path = dpt.generate_obj_masks2(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_mask_btn.click(
        on_depth_reconstruction_mask_btn_click, components, [pcl_figure], queue=False
    )
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        print(image)
        print(sam_masks)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(
        on_sam_sgmt_everything_btn_click,
        components,
        [lbl_image, pred_masks],
        queue=True,
    )
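# Long-running GPU work (point decoding, segment-everything) is routed through
# the queue (queue=True); quick UI updates bypass it (queue=False).
# block.queue() below enables the queue that those events require.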
if __name__ == "__main__":
    block.queue()
    block.launch(auth=("novouser", "bstad2023"))