gpu_symbol / engine /misc /visualizer.py
himipo's picture
first
11aa70b
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import PIL
import torch
import torch.utils.data
import torchvision
# Called once at import time, before any `torchvision.transforms.v2` import below;
# going by its name this silences torchvision's beta-transforms warning.
torchvision.disable_beta_transforms_warning()
# Names exported by `from <this module> import *`.
__all__ = ['show_sample']
def show_sample(sample):
    """Display one detection sample with its bounding boxes drawn on the image.

    Intended for samples coming from a COCO-style dataset/dataloader.

    Args:
        sample: A 2-item sequence ``(image, target)``. ``image`` is either a
            ``PIL.Image.Image`` or an image tensor accepted by torchvision's
            v2 functional API; ``target`` is a mapping whose ``"boxes"`` entry
            is passed unchanged to ``torchvision.utils.draw_bounding_boxes``
            (which documents absolute ``(xmin, ymin, xmax, ymax)`` boxes --
            the caller is responsible for supplying that format).

    Returns:
        None. The annotated image is shown in a matplotlib window.
    """
    # Imported lazily so that importing this module does not require
    # matplotlib unless a sample is actually visualized.
    import matplotlib.pyplot as plt
    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample

    if isinstance(image, PIL.Image.Image):
        # torchvision 0.15 (beta v2 API) names this `to_image_tensor`;
        # later releases renamed it to `to_image`. Prefer the original name
        # when it exists so behavior on 0.15 is unchanged.
        pil_to_tensor = getattr(F, "to_image_tensor", None) or F.to_image
        image = pil_to_tensor(image)

    # draw_bounding_boxes needs a uint8 image. `convert_dtype` (0.15) rescales
    # values while casting; its replacement `to_dtype` only does so when
    # `scale=True`, so pass it explicitly to keep float inputs rendering
    # correctly.
    if hasattr(F, "convert_dtype"):
        image = F.convert_dtype(image, torch.uint8)
    else:
        image = F.to_dtype(image, torch.uint8, scale=True)

    annotated_image = draw_bounding_boxes(
        image, target["boxes"], colors="yellow", width=3
    )

    fig, ax = plt.subplots()
    # Tensor layout is channels-first; imshow expects channels-last.
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    # Hide axis ticks and labels so only the picture is visible.
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    fig.show()
    plt.show()