Spaces:
Runtime error
Runtime error
| import torch | |
# Scale factor of 1000; presumably converts metres to millimetres (1 m = 1000 mm)
# when reporting errors — TODO confirm against the callers, none are visible here.
m2mm = 1000.0
def L2_error(x:torch.Tensor, y:torch.Tensor):
    '''
    Euclidean (L2) distance between `x` and `y`, reduced over the last dimension.

    ### Args
    - `x`: torch.Tensor, shape (..., D)
    - `y`: torch.Tensor, shape (..., D)

    ### Returns
    - torch.Tensor, shape (...)
        - The 2-norm of `x - y` taken along the trailing axis.
    '''
    residual = x - y  # (..., D)
    return torch.linalg.vector_norm(residual, ord=2, dim=-1)  # (...)
def similarity_align_to(
    S1 : torch.Tensor,
    S2 : torch.Tensor,
):
    '''
    Computes a similarity transform (sR, t) that takes a set of 3D points S1
    closest to a set of 3D points S2 (in the least-squares sense), where R is a
    3x3 rotation matrix, t a 3x1 translation and s a scale, i.e. it solves the
    orthogonal Procrustes problem. The transform is then applied to S1.
    The code was modified from [WHAM](https://github.com/yohanshin/WHAM/blob/d1ade93ae83a91855902fdb8246c129c4b3b8a40/lib/eval/eval_utils.py#L201-L252).

    ### Args
    - `S1`: torch.Tensor, shape (...B, N, 3)
        - The points to be aligned. Moved to the device of `S2` if needed.
    - `S2`: torch.Tensor, shape (...B, N, 3)
        - The target points.

    ### Returns
    - torch.Tensor, shape (...B, N, 3)
        - `S1` after applying the optimal per-batch similarity transform.

    ### Raises
    - AssertionError if the last dim is not 3 or the batch dims differ.
    '''
    assert (S1.shape[-1] == 3 and S2.shape[-1] == 3), 'The last dimension of `S1` and `S2` must be 3.'
    assert (S1.shape[:-2] == S2.shape[:-2]), 'The batch size of `S1` and `S2` must be the same.'
    original_BN3 = S1.shape
    N = original_BN3[-2]

    # Flatten all batch dims and switch to the column-vector layout used below.
    S1 = S1.reshape(-1, N, 3).transpose(-1, -2)  # (B', 3, N) <- (...B, N, 3)
    S2 = S2.reshape(-1, N, 3).transpose(-1, -2)  # (B', 3, N) <- (...B, N, 3)
    S1 = S1.to(S2.device)
    B = S1.shape[0]

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)  # (B', 3, 1)
    mu2 = S2.mean(dim=-1, keepdim=True)  # (B', 3, 1)
    X1 = S1 - mu1  # (B', 3, N)
    X2 = S2 - mu2  # (B', 3, N)

    # 2. Compute variance of X1 used for scales.
    var1 = (X1 ** 2).sum(dim=(-2, -1))  # (B',)

    # 3. The outer product of X1 and X2.
    K = X1 @ X2.transpose(-1, -2)  # (B', 3, 3)

    # 4. Solution that maximizes trace(R'K) is R = V * Z * U', where U, V are singular vectors of K.
    #    `torch.linalg.svd` returns V^H rather than V (unlike the deprecated `torch.svd`).
    U, _, Vh = torch.linalg.svd(K)  # (B', 3, 3), (B', 3), (B', 3, 3)
    V = Vh.transpose(-1, -2)  # (B', 3, 3)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z must share K's dtype/device, otherwise the matmul below fails for non-float32 inputs.
    Z = torch.eye(3, device=K.device, dtype=K.dtype).expand(B, 3, 3).clone()  # (B', 3, 3)
    Z[:, -1, -1] *= (U @ Vh).det().sign()

    # Construct R.
    R = V @ (Z @ U.transpose(-1, -2))  # (B', 3, 3)

    # 5. Recover scales: trace(R @ K) / var1, computed for the whole batch at once.
    traces = torch.diagonal(R @ K, dim1=-2, dim2=-1).sum(dim=-1)  # (B',)
    scales = (traces / var1)[..., None, None]  # (B', 1, 1)

    # 6. Recover translation.
    t = mu2 - (scales * (R @ mu1))  # (B', 3, 1)

    # 7. Apply the transform and restore the caller's layout.
    S1_aligned = scales * (R @ S1) + t  # (B', 3, N)
    S1_aligned = S1_aligned.transpose(-1, -2)  # (B', N, 3) <- (B', 3, N)
    S1_aligned = S1_aligned.reshape(original_BN3)  # (...B, N, 3)
    return S1_aligned  # (...B, N, 3)
def align_pcl(Y: torch.Tensor, X: torch.Tensor, weight=None, fixed_scale=False):
    '''
    Estimate the similarity transform that aligns X with Y using the Umeyama method,
    i.e. X' = s * R * X + t is aligned with Y.
    The code was copied from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/geometry/pcl.py#L10-L60).

    ### Args
    - `Y`: torch.Tensor, shape (*, N, 3) first trajectory
    - `X`: torch.Tensor, shape (*, N, 3) second trajectory
    - `weight`: torch.Tensor, shape (*, N, 1) optional weight of valid correspondences
        - NOTE: the weight is applied both before the mean and again to the centered
          points, which is only self-consistent for 0/1 masks — presumably that is
          the intended use; confirm before passing soft weights.
    - `fixed_scale`: bool, default = False
        - If True, the scale is not estimated and `s` is all ones.

    ### Returns
    - `s` (*, 1)
    - `R` (*, 3, 3)
    - `t` (*, 3)
    '''
    *dims, num_points, _ = Y.shape

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)
    else:
        # Point count as a tensor living on Y's device so the divisions below
        # work for CUDA inputs. Non-floating inputs fall back to the default dtype.
        count_dtype = Y.dtype if Y.is_floating_point() else torch.get_default_dtype()
        N = torch.full((*dims, 1, 1), float(num_points), device=Y.device, dtype=count_dtype)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]  # (*, 3)
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]  # (*, N, 3)

    if weight is not None:
        # Zero out the centered residuals of invalid correspondences.
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # Reflection guard: flip the last axis wherever U @ Vh would have det = -1.
    # S is created with the SVD's dtype/device so the matmuls never mix types.
    S = torch.eye(3, device=U.device, dtype=U.dtype).expand(*dims, 3, 3).clone()  # (*, 3, 3)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0  # (*)
    S[..., 2, 2] = 1.0 - 2.0 * neg.to(U.dtype)  # -1 where neg, +1 elsewhere

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    if fixed_scale:
        s = torch.ones(*dims, 1, device=R.device, dtype=R.dtype)  # (*, 1)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        # trace(diag(D) @ S) equals the sum of D weighted by S's diagonal.
        s_diag = torch.diagonal(S, dim1=-2, dim2=-1)  # (*, 3)
        s = (D * s_diag).sum(dim=-1, keepdim=True) / var[..., 0]  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
def first_k_frames_align_to(
    S1  : torch.Tensor,
    S2  : torch.Tensor,
    k_f : int,
):
    '''
    Estimate a single similarity transform from the leading `k_f` frames of S1 and S2,
    then apply that transform to every frame of S1 so that it is aligned to S2.
    The code was modified from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/eval/tools.py#L68-L81).

    ### Args
    - `S1`: torch.Tensor, shape (..., L, N, 3)
    - `S2`: torch.Tensor, shape (..., L, N, 3)
    - `k_f`: int
        - The number of frames to use for alignment.

    ### Returns
    - `S1_aligned`: torch.Tensor, shape (..., L, N, 3)
        - The aligned S1.
    '''
    assert (S1.dim() >= 3 and S2.dim() >= 3), 'The input tensors must have at least 3 dimensions.'
    full_shape = S1.shape  # (..., L, N, 3)
    n_frames, n_points = full_shape[-3], full_shape[-2]

    # Collapse every leading dim into one batch axis.
    src = S1.reshape(-1, n_frames, n_points, 3)  # (B, L, N, 3)
    dst = S2.reshape(-1, n_frames, n_points, 3)  # (B, L, N, 3)
    n_batch = src.shape[0]

    # 1. Pool the points of the first k_f frames into one cloud per batch item.
    src_head = src[:, :k_f].reshape(n_batch, -1, 3)  # (B, k_f * N, 3)
    dst_head = dst[:, :k_f].reshape(n_batch, -1, 3)  # (B, k_f * N, 3)

    # 2. Solve for the transform mapping the S1 cloud onto the S2 cloud.
    scale, rot, trans = align_pcl(Y=dst_head, X=src_head)  # (B, 1), (B, 3, 3), (B, 3)
    scale = scale[:, :, None, None]  # (B, 1, 1, 1)
    trans = trans[:, None, None, :]  # (B, 1, 1, 3)

    # 3. Apply that one transform to the entire sequence.
    rotated = torch.einsum('bij,blnj->blni', rot, src)  # (B, L, N, 3)
    S1_aligned = scale * rotated + trans  # (B, L, N, 3)
    return S1_aligned.reshape(full_shape)  # (..., L, N, 3)