| | import math |
| | import torch |
| |
|
| |
|
| | def _batch_trace(m): |
| | return torch.einsum('...ii', m) |
| |
|
| |
|
def regularize(point, eps=1e-6):
    """
    Map rotation vectors to the equivalent representative whose norm lies in [0, pi].

    The norm is first wrapped into [0, 2 pi). If the wrapped norm exceeds pi, it is
    replaced by 2 pi minus that value and the rotation axis is flipped.
    Args:
        point: rotation vectors, (n, 3)
        eps: lower bound applied to the norm before dividing by it, so the zero
            vector maps to itself instead of producing NaN
    Returns:
        regularized point, (n, 3)
    """
    norm = torch.linalg.norm(point, dim=-1)
    two_pi = 2 * math.pi
    wrapped = norm % two_pi
    # A rotation by alpha in (pi, 2 pi) about k equals a rotation by 2 pi - alpha about -k;
    # the flip is encoded as a negative target norm.
    target = torch.where(wrapped > math.pi, -(two_pi - wrapped), wrapped)
    scale = target / torch.clamp(norm, min=eps)
    regularized = point * scale.unsqueeze(-1)
    assert not regularized.isnan().any()
    return regularized
| |
|
| |
|
def random_uniform(n_samples, device=None):
    """
    Draw random rotation vectors the way geomstats does: each coordinate is sampled
    uniformly in [-pi, pi), then the vectors are regularized (norm in [0, pi]).
    Reference implementation:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html

    NOTE(review): sampling a cube and wrapping is not the Haar (uniform) measure on
    SO(3); presumably intentional so results match geomstats.
    Args:
        n_samples: int
        device: torch device for the samples (None = default device)
    Returns:
        rotation vectors, (n, 3)
    """
    unit_cube = torch.rand(n_samples, 3, device=device)
    centered = (unit_cube * 2 - 1) * math.pi
    return regularize(centered)
| |
|
| |
|
def hat(rot_vec):
    """
    Map R^3 vectors to skew-symmetric 3x3 matrices r (r^T = -r).

    hat(rot_vec) is the cross-product matrix of rot_vec, i.e.
        hat(rot_vec) @ v == rot_vec x v   for every v in R^3,
    with hat(v) = [[0, -v2, v1], [v2, 0, -v0], [-v1, v0, 0]].
    See also:
        https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
        https://en.wikipedia.org/wiki/Hat_notation#Cross_product
    Args:
        rot_vec: (n, 3); any leading batch dimensions are supported
    Returns:
        skew-symmetric matrices (n, 3, 3), same dtype and device as rot_vec
    """
    # basis[i] = hat(e_i); the result is sum_i rot_vec[..., i] * basis[i].
    # Building the basis in rot_vec's dtype keeps einsum from failing on dtype mismatch
    # (e.g. float64 input with a float32 default dtype).
    basis = torch.tensor([
        [[0., 0., 0.], [0., 0., -1.], [0., 1., 0.]],
        [[0., 0., 1.], [0., 0., 0.], [-1., 0., 0.]],
        [[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]],
    ], dtype=rot_vec.dtype, device=rot_vec.device)
    return torch.einsum('...i,ijk->...jk', rot_vec, basis)
| |
|
| |
|
def inv_hat(skew_mat):
    """
    Inverse of the hat operation: recover v from its cross-product matrix
    hat(v) = [[0, -v2, v1], [v2, 0, -v0], [-v1, v0, 0]].
    Args:
        skew_mat: skew-symmetric matrices (n, 3, 3); any leading batch dimensions
            are supported
    Returns:
        rotation vectors, (n, 3)
    Raises:
        AssertionError: if skew_mat is not skew-symmetric within atol=1e-4
            (note: the check is skipped under `python -O`)
    """
    assert torch.allclose(-skew_mat, skew_mat.transpose(-2, -1), atol=1e-4), \
        f"Input not skew-symmetric (err={(-skew_mat - skew_mat.transpose(-2, -1)).abs().max():.4g})"

    # Index from the end so any number of leading batch dimensions works.
    return torch.stack([
        skew_mat[..., 2, 1],
        skew_mat[..., 0, 2],
        skew_mat[..., 1, 0],
    ], dim=-1)
| |
|
| |
|
def matrix_from_rotation_vector(axis_angle, eps=1e-6):
    """
    Convert rotation vectors to rotation matrices with Rodrigues' formula
        R = I + sin(theta) K + (1 - cos(theta)) K^2,
    where theta is the rotation angle (norm of the regularized vector) and
    K = hat(axis) is the cross-product matrix of the unit rotation axis.
    Args:
        axis_angle: (n, 3); any leading batch dimensions are supported
        eps: lower bound on the angle used when normalizing the axis, so the
            identity rotation (zero vector) does not divide by zero
    Returns:
        rotation matrices, (n, 3, 3)
    """
    axis_angle = regularize(axis_angle)
    angle = axis_angle.norm(dim=-1)
    _norm = torch.clamp(angle, min=eps).unsqueeze(-1)
    skew_mat = hat(axis_angle / _norm)

    # (..., 1, 1) so the per-rotation coefficients broadcast over the 3x3 matrices
    angle = angle[..., None, None]
    _id = torch.eye(3, device=axis_angle.device)
    rot_mat = _id + \
        torch.sin(angle) * skew_mat + \
        (1 - torch.cos(angle)) * (skew_mat @ skew_mat)
    return rot_mat
| |
|
| |
|
class safe_acos(torch.autograd.Function):
    """
    arccos whose backward pass stays finite at x = +/-1.

    The forward pass is the exact torch.acos. In the backward pass, the saved input
    is clamped to [-1 + EPS, 1 - EPS] before evaluating d/dx acos(x) = -1 / sqrt(1 - x^2).
    Call it as safe_acos.apply(x).
    https://github.com/pytorch/pytorch/issues/8069#issuecomment-2041223872
    """
    EPS = 1e-4

    @classmethod
    def d_acos_dx(cls, x):
        lower, upper = -1. + cls.EPS, 1. - cls.EPS
        inside = x.clamp(min=lower, max=upper)
        return -1.0 / (1 - inside ** 2).sqrt()

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.acos(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return safe_acos.d_acos_dx(x) * grad_output
| |
|
| |
|
def _rotation_axis_near_pi(rot_mat, cos_angle, skew_mat):
    """
    Unit rotation axes of rotations whose angle is close to pi.

    Rodrigues' formula gives R + R^T = 2 cos(theta) I + 2 (1 - cos(theta)) k k^T.
    So k k^T can be read off the symmetric part of R, which, unlike R - R^T,
    does not vanish as theta -> pi.
    Args:
        rot_mat: (m, 3, 3)
        cos_angle: (m,)
        skew_mat: rot_mat - rot_mat^T, (m, 3, 3); only used to pick the sign of the axis
    Returns:
        unit axes, (m, 3)
    """
    eye = torch.eye(3, dtype=rot_mat.dtype, device=rot_mat.device)
    sym = 0.5 * (rot_mat + rot_mat.transpose(-2, -1))
    # Near pi, 1 - cos(theta) is close to 2, so this division is well conditioned.
    outer = (sym - cos_angle[:, None, None] * eye) / (1 - cos_angle)[:, None, None]
    # outer ~= k k^T. Its row with the largest diagonal entry k_i^2 (>= 1/3, because
    # trace(k k^T) = 1) equals k_i * k and is the best-conditioned estimate of +/-k.
    best = torch.diagonal(outer, dim1=-2, dim2=-1).argmax(dim=-1)
    rows = torch.arange(outer.shape[0], device=outer.device)
    axis = outer[rows, best]
    axis = axis / axis.norm(dim=-1, keepdim=True)
    # R - R^T = 2 sin(theta) hat(k) with sin(theta) >= 0, so its vector part points along +k.
    # At theta == pi it vanishes, and either sign describes the same rotation.
    skew_vec = torch.stack([skew_mat[:, 2, 1], skew_mat[:, 0, 2], skew_mat[:, 1, 0]], dim=-1)
    flip = (axis * skew_vec).sum(dim=-1) < 0
    return torch.where(flip[:, None], -axis, axis)


def rotation_vector_from_matrix(rot_mat, approx=1e-4):
    """
    Convert rotation matrices to rotation vectors (inverse of Rodrigues' formula).

    Generic angles use theta / (2 sin(theta)) * vee(R - R^T). Angles within `approx`
    of 0 use the limit 1/2 of that factor. Angles within `approx` of pi read the
    axis from the symmetric part of R instead, because R - R^T vanishes there; the
    generic formula then breaks down (e.g. for diag(1, -1, -1) it yields the zero
    vector instead of [pi, 0, 0]).
    Args:
        rot_mat: (n, 3, 3)
        approx: float, angular distance from 0 or pi below which the special-case
            formulas above are used for numerical stability
    Returns:
        rotation vector, (n, 3), passed through regularize
    """
    skew_mat = rot_mat - rot_mat.transpose(-2, -1)

    # trace(R) = 1 + 2 cos(theta)
    cos_angle = 0.5 * (_batch_trace(rot_mat) - 1)
    # Reject inputs that are clearly not rotations; allow float round-off before clamping.
    assert torch.all(cos_angle.abs() <= 1 + 1e-6)
    cos_angle = torch.clamp(cos_angle, min=-1., max=1.)
    # safe_acos keeps the backward pass finite at cos_angle = +/-1.
    abs_angle = safe_acos.apply(cos_angle)

    close_to_0 = abs_angle < approx
    close_to_pi = (math.pi - abs_angle) < approx
    generic = ~close_to_0 & ~close_to_pi

    # theta / (2 sin(theta)) -> 1/2 as theta -> 0. Rows near pi get the finite
    # placeholder 0.5; their result is replaced below.
    _fac = torch.full_like(abs_angle, 0.5)
    _fac[generic] = 0.5 * abs_angle[generic] / torch.sin(abs_angle[generic])
    axis_angle = inv_hat(_fac[:, None, None] * skew_mat)

    if close_to_pi.any():
        axis = _rotation_axis_near_pi(
            rot_mat[close_to_pi], cos_angle[close_to_pi], skew_mat[close_to_pi])
        near_pi = torch.zeros_like(axis_angle)
        near_pi[close_to_pi] = abs_angle[close_to_pi].unsqueeze(-1) * axis
        axis_angle = torch.where(close_to_pi.unsqueeze(-1), near_pi, axis_angle)

    return regularize(axis_angle)
| |
|
| |
|
def get_jacobian(point, left=True, inverse=False, eps=1e-4):
    """
    Jacobian-type matrices of SO(3) at the rotation vectors `point`.

    With theta = |point| and K = hat(point) (so K v = point x v), the matrix built
    before the optional transpose is:
        inverse=True:  I + (1 - cos theta) / theta^2 K + (theta - sin theta) / theta^3 K^2
        inverse=False: I - K / 2 + (1 / theta^2 - (1 + cos theta) / (2 theta sin theta)) K^2
    With left=True, its transpose is returned. Within `eps` of theta = 0 (and of
    theta = pi for the second form), the scalar coefficients are replaced by their limits.

    NOTE(review): in the textbook convention the first matrix is the left Jacobian
    J_l and the second is its inverse J_l^{-1}, so the `inverse` flag does not follow
    that naming. Confirm the intended convention before reusing this elsewhere.
    Args:
        point: rotation vectors with norm <= pi (e.g. output of regularize), (n, 3);
            leading batch dimensions are supported
        left: if True, return the transpose of the matrix described above
        inverse: selects which of the two formulas above is used
        eps: angular distance from 0 (and pi) below which limit values are used
    Returns:
        (n, 3, 3)
    """
    angle_squared = point.square().sum(-1)
    angle = angle_squared.sqrt()
    skew_mat = hat(point)

    # Inputs are expected to be regularized; tolerate float round-off just above pi.
    assert torch.all(angle <= math.pi + eps)
    close_to_0 = angle < eps
    close_to_pi = (math.pi - angle) < eps

    # (..., 1, 1) so the per-point coefficients broadcast over the 3x3 matrices
    angle = angle[..., None, None]
    angle_squared = angle_squared[..., None, None]
    _id = torch.eye(3, device=point.device)

    if inverse:
        far = ~close_to_0
        # (1 - cos t) / t^2 -> 1/2 and (t - sin t) / t^3 -> 1/6 as t -> 0.
        # The right-hand sides are indexed with the same mask as the left-hand
        # sides; otherwise the shapes disagree as soon as any point is near 0.
        _term1 = torch.empty_like(angle)
        _term1[close_to_0] = 0.5
        _term1[far] = (1 - torch.cos(angle[far])) / angle_squared[far]

        _term2 = torch.empty_like(angle)
        _term2[close_to_0] = 1 / 6
        _term2[far] = (angle[far] - torch.sin(angle[far])) / angle[far] ** 3

        jacobian = _id + _term1 * skew_mat + _term2 * (skew_mat @ skew_mat)
    else:
        # The K^2 coefficient tends to 1/12 as t -> 0 and to 1/pi^2 as t -> pi.
        _term1 = torch.empty_like(angle)
        _term1[close_to_0] = 1 / 12
        _term1[close_to_pi] = 1 / math.pi ** 2
        default = ~close_to_0 & ~close_to_pi
        _term1[default] = 1 / angle_squared[default] - \
            (1 + torch.cos(angle[default])) / (2 * angle[default] * torch.sin(angle[default]))

        jacobian = _id - 0.5 * skew_mat + _term1 * (skew_mat @ skew_mat)

    if left:
        jacobian = jacobian.transpose(-2, -1)

    return jacobian
| |
|
| |
|
def compose_rotations(rot_vec_1, rot_vec_2):
    """
    Compose two batches of rotations given as rotation vectors.

    The result is the rotation vector of R(rot_vec_1) @ R(rot_vec_2), so rot_vec_2
    is applied first when acting on column vectors.
    Args:
        rot_vec_1: (n, 3)
        rot_vec_2: (n, 3)
    Returns:
        rotation vectors, (n, 3)
    """
    outer_mat = matrix_from_rotation_vector(rot_vec_1)
    inner_mat = matrix_from_rotation_vector(rot_vec_2)
    return rotation_vector_from_matrix(outer_mat @ inner_mat)
| |
|
| |
|
def exp(tangent):
    """
    Exponential map at the identity.

    In the rotation-vector representation, a tangent vector at the identity is
    already the rotation vector of its exponential. Only its norm is wrapped into
    [0, pi] via regularize.
    Args:
        tangent: vector on the tangent space, (n, 3)
    Returns:
        rotation vector on the manifold, (n, 3)
    """
    return regularize(tangent)
| |
|
| |
|
def exp_not_from_identity(tangent_vec, base_point):
    """
    Exponential map at a base point.

    The tangent vector is mapped to the tangent space at the identity with
    get_jacobian(base_point, left=True, inverse=True). It is then exponentiated
    there, and the result is composed with the base point.
    Args:
        tangent_vec: vector on the tangent plane, (n, 3)
        base_point: base point on the manifold, (n, 3)
    Returns:
        new point on the manifold, (n, 3)
    """
    tangent_vec = regularize(tangent_vec)
    base_point = regularize(base_point)

    jac = get_jacobian(base_point, left=True, inverse=True)
    at_identity = torch.einsum("...ij,...j->...i", jac, tangent_vec)

    return compose_rotations(base_point, exp(at_identity))
| |
|
| |
|
def log(rot_vec, as_skew=False):
    """
    Logarithm map at the identity.

    In the rotation-vector representation, the logarithm at the identity is the
    rotation vector itself. It is returned unchanged (no regularization).
    Args:
        rot_vec: point on the manifold, (n, 3)
        as_skew: if True, return the tangent vector as a skew-symmetric matrix (via hat)
    Returns:
        vector on the tangent space, (n, 3), or (n, 3, 3) when as_skew is True
    """
    return hat(rot_vec) if as_skew else rot_vec
| |
|
| |
|
def log_not_from_identity(point, base_point):
    """
    Logarithm map of a point from a base point.

    The point is expressed relative to the base point as R(-base_point) R(point);
    negating a rotation vector inverts the rotation. The logarithm of this relative
    rotation is taken at the identity and then mapped with
    get_jacobian(base_point, inverse=False).
    Args:
        point: point on the manifold, (n, 3)
        base_point: base point on the manifold, (n, 3)
    Returns:
        vector on the tangent plane, (n, 3)
    """
    point = regularize(point)
    base_point = regularize(base_point)

    relative = compose_rotations(-base_point, point)
    log_at_identity = log(relative)

    jac = get_jacobian(base_point, inverse=False)
    return torch.einsum("...ij,...j->...i", jac, log_at_identity)
| |
|
| |
|
if __name__ == "__main__":
    # Self-test: compare this module against geomstats' SpecialOrthogonal(n=3) in
    # rotation-vector representation. Needs geomstats with its PyTorch backend.
    import functools
    import os
    os.environ['GEOMSTATS_BACKEND'] = "pytorch"
    # Importing geomstats may change torch's default dtype; restore it afterwards.
    default_dtype = torch.get_default_dtype()
    from geomstats.geometry.special_orthogonal import SpecialOrthogonal
    torch.set_default_dtype(default_dtype)

    so3_vector = SpecialOrthogonal(n=3, point_type="vector")

    # Tensors created inside geomstats should live on the same device as ours.
    if torch.__version__ >= '2.0.0':
        GEOMSTATS_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

        def geomstats_tensor_type(func):
            @functools.wraps(func)
            def inner(*args, **kwargs):
                with torch.device(GEOMSTATS_DEVICE):
                    return func(*args, **kwargs)

            return inner
    else:
        GEOMSTATS_TENSOR_TYPE = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'

        def geomstats_tensor_type(func):
            @functools.wraps(func)
            def inner(*args, **kwargs):
                torch.set_default_tensor_type(GEOMSTATS_TENSOR_TYPE)
                try:
                    return func(*args, **kwargs)
                finally:
                    # Restore the global default even if func raises.
                    torch.set_default_tensor_type('torch.FloatTensor')

            return inner

    @geomstats_tensor_type
    def gs_matrix_from_rotation_vector(*args, **kwargs):
        return so3_vector.matrix_from_rotation_vector(*args, **kwargs)

    @geomstats_tensor_type
    def gs_rotation_vector_from_matrix(*args, **kwargs):
        return so3_vector.rotation_vector_from_matrix(*args, **kwargs)

    @geomstats_tensor_type
    def gs_exp_not_from_identity(*args, **kwargs):
        return so3_vector.exp_not_from_identity(*args, **kwargs)

    @geomstats_tensor_type
    def gs_log_not_from_identity(*args, **kwargs):
        return so3_vector.log_not_from_identity(*args, **kwargs)

    @geomstats_tensor_type
    def compose(*args, **kwargs):
        return so3_vector.compose(*args, **kwargs)

    @geomstats_tensor_type
    def inverse(*args, **kwargs):
        return so3_vector.inverse(*args, **kwargs)

    @geomstats_tensor_type
    def gs_random_uniform(*args, **kwargs):
        return so3_vector.random_uniform(*args, **kwargs)

    n = 16
    device = 'cuda' if torch.cuda.is_available() else None

    # regularize: output is collinear with the input and has norm in [0, pi]
    vec = (torch.rand(n, 3) * 4 - 2) * math.pi
    axis_angle = regularize(vec)
    norms = axis_angle.norm(dim=-1)
    # dim=-1 is explicit: without it torch.cross uses the first size-3 dimension
    # (dim 0 when n == 3) and emits a deprecation warning.
    assert torch.all(torch.cross(vec, axis_angle, dim=-1).norm(dim=-1) < 1e-5), "not all vectors collinear"
    # A norm of exactly pi is valid; allow float32 round-off on top of it.
    assert torch.all(norms <= math.pi + 1e-5) and torch.all(norms >= 0), "norm not between 0 and pi"

    # matrix_from_rotation_vector vs geomstats
    rot_vec = random_uniform(n, device=device)
    assert torch.allclose(matrix_from_rotation_vector(rot_vec),
                          gs_matrix_from_rotation_vector(rot_vec), atol=1e-06)

    # rotation_vector_from_matrix vs geomstats
    rot_vec = random_uniform(n, device=device)
    rot_mat = matrix_from_rotation_vector(rot_vec)
    assert torch.allclose(rotation_vector_from_matrix(rot_mat),
                          gs_rotation_vector_from_matrix(rot_mat), atol=1e-05)

    # exp_not_from_identity vs geomstats
    tangent_vec = random_uniform(n, device=device)
    base_pt = random_uniform(n, device=device)
    my_val = exp_not_from_identity(tangent_vec, base_pt)
    gs_val = gs_exp_not_from_identity(tangent_vec, base_pt)
    assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()

    # log_not_from_identity vs geomstats
    pt = random_uniform(n, device=device)
    base_pt = random_uniform(n, device=device)
    my_val = log_not_from_identity(pt, base_pt)
    gs_val = gs_log_not_from_identity(pt, base_pt)
    assert torch.allclose(my_val, gs_val, atol=1e-03), (my_val - gs_val).abs().max()

    print("All tests successful!")
| |
|