Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
| """ | |
| import torch | |
| from .box_ops import box_xyxy_to_cxcywh | |
def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    The returned sequence is symmetric around zero and contains
    ``2 * (reg_max // 2) + 1`` values: the outer bound ``-2 * |up[0]| * |reg_scale|``,
    a run of negative values that shrink geometrically towards zero, zero
    itself, the mirrored positive values, and the outer bound
    ``+2 * |up[0]| * |reg_scale|``.

    Args:
        reg_max (int): Max number of the discrete bins. Must not be 2, because
            the step exponent is ``2 / (reg_max - 2)``.
        up (Tensor): Controls upper bounds of the sequence,
            where maximum offset is ±up * H / W. Only ``up[0]`` is read.
        reg_scale (float | Tensor): Controls the curvature of the Weighting Function.
            Larger values result in flatter weights near the central axis W(reg_max/2)=0
            and steeper weights at both ends. Either a Python number or a
            tensor is accepted; only its absolute value is used.
        deploy (bool): If True, the values are computed as Python floats and
            packed into a new tensor with ``up``'s dtype and device.

    Returns:
        Tensor: 1-D sequence of Weighting Function values.
    """
    half = reg_max // 2
    exponent = 2 / (reg_max - 2)

    if deploy:
        # Pull the bounds out as Python floats so the whole sequence is built
        # from plain numbers.
        upper_bound1 = (abs(up[0]) * abs(reg_scale)).item()
        upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item()
        step = (upper_bound1 + 1) ** exponent
        left_values = [-(step) ** i + 1 for i in range(half - 1, 0, -1)]
        right_values = [(step) ** i - 1 for i in range(1, half)]
        # A plain 0.0 keeps the list homogeneous (floats only) for torch.tensor.
        values = [-upper_bound2] + left_values + [0.0] + right_values + [upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)

    # Flatten the bounds to 1-D: with a Python-number reg_scale the product is
    # 0-dim, and torch.cat refuses to concatenate 0-dim tensors. For a
    # 1-element tensor reg_scale the reshape is a no-op.
    upper_bound1 = (abs(up[0]) * abs(reg_scale)).reshape(-1)
    upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).reshape(-1)
    step = (upper_bound1 + 1) ** exponent
    left_values = [-(step) ** i + 1 for i in range(half - 1, 0, -1)]
    right_values = [(step) ** i - 1 for i in range(1, half)]
    center = torch.zeros_like(up[0][None])
    values = [-upper_bound2] + left_values + [center] + right_values + [upper_bound2]
    return torch.cat(values, 0)
def translate_gt(gt, reg_max, reg_scale, up):
    """
    Turns continuous ground-truth values into a two-bin soft label over W(n).

    Each flattened GT value is located between two neighbouring entries of the
    Weighting Function sequence. The function returns the index of the left
    neighbour together with linear interpolation weights for the left and
    right neighbours.

    Args:
        gt (Tensor): Ground truth values; flattened to shape (N, ) internally.
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Left-bin index per GT value as a float tensor,
              shape (N, ). Values below the sequence get 0; values at or above
              its last entry get ``reg_max - 0.1``.
            - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
            - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
    """
    targets = gt.reshape(-1)
    bin_edges = weighting_function(reg_max, up, reg_scale)

    # Count, per target, how many sequence entries are <= the target. The last
    # such entry is the left neighbour, so its index is count - 1 (-1 if none).
    edge_minus_target = bin_edges.unsqueeze(0) - targets.unsqueeze(1)
    left_bin = (edge_minus_target <= 0).sum(dim=1) - 1

    indices = left_bin.float()
    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # Targets that have both a left and a right neighbour in the sequence.
    in_range = (indices >= 0) & (indices < reg_max)
    lower = indices[in_range].long()
    inner_targets = targets[in_range]

    dist_to_left = torch.abs(inner_targets - bin_edges[lower])
    dist_to_right = torch.abs(bin_edges[lower + 1] - inner_targets)

    # Linear interpolation: the closer the target is to the right neighbour,
    # the larger the right weight.
    weight_right[in_range] = dist_to_left / (dist_to_left + dist_to_right)
    weight_left[in_range] = 1.0 - weight_right[in_range]

    # Targets below the first entry: all weight on the first bin.
    below = indices < 0
    weight_right[below] = 0.0
    weight_left[below] = 1.0
    indices[below] = 0.0

    # Targets at or beyond the last entry: all weight on the right bin, with
    # the index parked just under reg_max.
    above = indices >= reg_max
    weight_right[above] = 1.0
    weight_left[above] = 0.0
    indices[above] = reg_max - 0.1

    return indices, weight_right, weight_left
def distance2bbox(points, distance, reg_scale):
    """
    Decodes edge-distances into bounding box coordinates.

    Each distance is shifted by ``0.5 * |reg_scale|`` and multiplied by the
    reference width (or height) divided by ``|reg_scale|`` before being applied
    to the reference centre.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
            where (x, y) is the center and (w, h) are width and height.
        distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
            point to the left, top, right, and bottom boundaries.
        reg_scale (float): Controls the curvature of the Weighting Function.
            Only its absolute value is used.

    Returns:
        Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
    """
    scale = abs(reg_scale)
    half = 0.5 * scale

    center_x = points[..., 0]
    center_y = points[..., 1]
    # Size of one distance unit along each axis.
    unit_w = points[..., 2] / scale
    unit_h = points[..., 3] / scale

    left = center_x - (half + distance[..., 0]) * unit_w
    top = center_y - (half + distance[..., 1]) * unit_h
    right = center_x + (half + distance[..., 2]) * unit_w
    bottom = center_y + (half + distance[..., 3]) * unit_h

    corners = torch.stack([left, top, right, bottom], -1)
    return box_xyxy_to_cxcywh(corners)
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Converts bounding box coordinates to distances from a reference point.

    The four edge distances (left, top, right, bottom) are normalised by the
    reference width/height divided by ``|reg_scale|``, shifted by
    ``-0.5 * |reg_scale|``, and then mapped onto the Weighting Function bins
    via ``translate_gt``.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        reg_max (int): Maximum number of discrete bins. Must be a number: it is
            used arithmetically while building W(n).
        reg_scale (float): Controlling curvature of W(n). Only its absolute
            value is used.
        up (Tensor): Controlling upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - Flattened left-bin indices, clamped to [0, reg_max - eps].
            - Weight assigned to the right bin.
            - Weight assigned to the left bin.
            All three are detached from the autograd graph.
    """
    reg_scale = abs(reg_scale)
    half = 0.5 * reg_scale

    # Size of one distance unit per axis; 1e-16 avoids division by zero for
    # zero-width or zero-height reference boxes.
    unit_w = points[:, 2] / reg_scale + 1e-16
    unit_h = points[:, 3] / reg_scale + 1e-16

    left = (points[:, 0] - bbox[:, 0]) / unit_w - half
    top = (points[:, 1] - bbox[:, 1]) / unit_h - half
    right = (bbox[:, 2] - points[:, 0]) / unit_w - half
    bottom = (bbox[:, 3] - points[:, 1]) / unit_h - half

    four_lens = torch.stack([left, top, right, bottom], -1)
    four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)

    # translate_gt already requires a numeric reg_max, so the clamp is applied
    # unconditionally; it matters when eps exceeds the 0.1 margin used there.
    four_lens = four_lens.clamp(min=0, max=reg_max - eps)
    return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()