Spaces:
Sleeping
Sleeping
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import importlib.metadata | |
| from torch import Tensor | |
import re


def _torchvision_version() -> tuple:
    """Return the installed torchvision version as a tuple of ints.

    Only the leading dotted-numeric part of the version string is kept, so
    suffixes such as ``'+cu118'``, ``'a0'`` or ``'.dev20230801'`` are ignored
    (``'0.16.1+cu118'`` -> ``(0, 16, 1)``). Returns ``()`` when the string has
    no numeric prefix; ``()`` falls in none of the supported ranges below, so
    the module then raises the "version >= 0.15.2" RuntimeError.
    """
    raw = importlib.metadata.version('torchvision')
    match = re.match(r'\d+(?:\.\d+)*', raw)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split('.'))


# Compare versions numerically. Plain string comparison mis-orders releases
# (e.g. '0.9.0' >= '0.17' is True lexicographically), which would route an
# unsupported old torchvision into the >= 0.17 branch.
_TV_VERSION = _torchvision_version()

if (0, 15, 2) <= _TV_VERSION < (0, 16):
    # torchvision 0.15.x: box/mask types live in `torchvision.datapoints` and
    # the box constructor takes `spatial_size`; silence the beta warning.
    import torchvision
    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Mask, Image, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes
    _boxes_keys = ['format', 'spatial_size']
elif (0, 16) <= _TV_VERSION < (0, 17):
    # torchvision 0.16.x: types live in `torchvision.tv_tensors` and the box
    # constructor takes `canvas_size`; the beta warning is still silenced.
    import torchvision
    torchvision.disable_beta_transforms_warning()
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']
elif _TV_VERSION >= (0, 17):
    # torchvision >= 0.17: same `tv_tensors` API, no beta-warning call.
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import (
        BoundingBoxes, BoundingBoxFormat, Mask, Image, Video)
    _boxes_keys = ['format', 'canvas_size']
else:
    raise RuntimeError('Please make sure torchvision version >= 0.15.2')
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format='xyxy', spatial_size=None) -> Tensor:
    """Wrap a plain tensor in the torchvision TV-tensor type selected by ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): target type; ``'boxes'`` produces ``BoundingBoxes`` and
            ``'masks'`` produces ``Mask``.
        box_format (str | BoundingBoxFormat): coordinate format, used only
            when ``key == 'boxes'``. A string is upper-cased and looked up as a
            ``BoundingBoxFormat`` member name (e.g. ``'xyxy'`` -> ``XYXY``); an
            existing ``BoundingBoxFormat`` member is used unchanged.
        spatial_size: forwarded to the box constructor under the
            version-specific keyword (``spatial_size`` on torchvision 0.15,
            ``canvas_size`` on >= 0.16). Used only when ``key == 'boxes'``.

    Returns:
        ``BoundingBoxes`` when ``key == 'boxes'``, ``Mask`` when
        ``key == 'masks'``.

    Raises:
        ValueError: if ``key`` is neither ``'boxes'`` nor ``'masks'``.
        AttributeError: if a string ``box_format`` names no
            ``BoundingBoxFormat`` member.
    """
    if key == 'boxes':
        if not isinstance(box_format, BoundingBoxFormat):
            box_format = getattr(BoundingBoxFormat, box_format.upper())
        # `_boxes_keys` holds the keyword names accepted by the installed
        # torchvision's box constructor: [format-kw, size-kw].
        _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **_kwargs)
    if key == 'masks':
        return Mask(tensor)
    # Raised explicitly rather than via `assert`, so the check still runs
    # under `python -O` instead of silently returning None.
    raise ValueError("Only support 'boxes' and 'masks'")