lucid-hf's picture
CI: deploy Docker/PDM Space
98a3af2 verified
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import importlib.metadata
import re

from torch import Tensor
def _release_tuple(version: str) -> tuple:
    """Return the leading numeric release segment of ``version`` as ints.

    ``'0.16.1+cu121'`` -> ``(0, 16, 1)``; ``'0.17.0a0'`` -> ``(0, 17, 0)``.
    A string with no leading digits yields ``()``, which compares below
    every supported version.
    """
    match = re.match(r'\d+(?:\.\d+)*', version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group().split('.'))


# Compare versions numerically: a plain string comparison misorders releases
# such as '0.9' vs '0.17' or a future '0.100'.
_TORCHVISION_VERSION = _release_tuple(importlib.metadata.version('torchvision'))

if (0, 15, 2) <= _TORCHVISION_VERSION < (0, 16):
    # 0.15.x: the v2 API lives under the beta ``datapoints`` namespace, with
    # older class names aliased to the newer ones used by the rest of the code.
    import torchvision
    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Mask, Image, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes
    # Keyword names convert_to_tv_tensor passes to BoundingBoxes (format, image size).
    _boxes_keys = ['format', 'spatial_size']
elif (0, 16) <= _TORCHVISION_VERSION < (0, 17):
    # 0.16.x: ``tv_tensors`` namespace, beta warning still emitted.
    import torchvision
    torchvision.disable_beta_transforms_warning()
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']
elif _TORCHVISION_VERSION >= (0, 17):
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']
else:
    raise RuntimeError('Please make sure torchvision version >= 0.15.2')
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format='xyxy', spatial_size=None) -> Tensor:
    """Wrap a plain tensor in the torchvision tv_tensor type selected by ``key``.

    Args:
        tensor (Tensor): input tensor.
        key (str): ``'boxes'`` to build a ``BoundingBoxes``, ``'masks'`` to
            build a ``Mask``.
        box_format (str | BoundingBoxFormat): box layout, used only for
            ``'boxes'``. Either a ``BoundingBoxFormat`` member or its name
            (case-insensitive, e.g. ``'xyxy'``).
        spatial_size: image size, used only for ``'boxes'``; forwarded to
            ``BoundingBoxes`` as ``spatial_size`` (torchvision 0.15) or
            ``canvas_size`` (torchvision >= 0.16).

    Returns:
        BoundingBoxes for ``'boxes'``; Mask for ``'masks'``.

    Raises:
        ValueError: if ``key`` is neither ``'boxes'`` nor ``'masks'``.
    """
    if key == 'boxes':
        # Accept an enum member directly; only strings need a name lookup.
        if not isinstance(box_format, BoundingBoxFormat):
            box_format = getattr(BoundingBoxFormat, box_format.upper())
        # The image-size keyword differs between torchvision versions.
        _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **_kwargs)
    if key == 'masks':
        return Mask(tensor)
    # Explicit raise (not assert) so validation survives `python -O`.
    raise ValueError(f"Only 'boxes' and 'masks' are supported, got {key!r}")