in meshrcnn/utils/projtransform.py [0:0]
def estimate(self, src, dst, method="svd"):
    """
    Estimates the matrix to transform src to dst.

    Each transform is parameterised as
        [[a0, a1, a2],
         [b0, b1, b2],
         [c0, c1,  1]]
    i.e. 8 unknowns, with every point correspondence contributing two linear
    equations. Points are centered and normalized before solving, and the
    solution is mapped back to the original coordinates afterwards.

    Input:
        src: FloatTensor of shape BxNx2
        dst: FloatTensor of shape BxNx2
        method: Specifies the method to solve the linear system:
            "svd"       - null space of the constraint matrix (SVD)
            "least_sqr" - least-squares solution via the pseudo-inverse
            "inv"       - normal equations, inv(A^T A) A^T b
    Output:
        True on success. False if point normalization raised
        ZeroDivisionError, in which case self.params is set to zeros.
        The estimated transforms are stored in self.params (Bx3x3).
    Raises:
        ValueError: if src/dst shapes differ, are not BxNx2, or if
            method is not one of the names above.
    """
    if src.shape != dst.shape:
        raise ValueError("src and dst tensors must be of same shape")
    if src.ndim != 3 or src.shape[-1] != 2:
        raise ValueError("Input should be of shape BxNx2")
    # Fail fast on a bad solver name instead of after the normalization work.
    if method not in ("svd", "least_sqr", "inv"):
        raise ValueError("method {} undefined".format(method))
    device = src.device
    batch = src.shape[0]
    # Center and normalize image points for better numerical stability.
    try:
        src_matrix, src = _center_and_normalize_points(src)
        dst_matrix, dst = _center_and_normalize_points(dst)
    except ZeroDivisionError:
        # NOTE(review): presumably raised by the helper for degenerate point
        # sets; the helper itself is not visible here.
        self.params = torch.zeros((batch, 3, 3), device=device)
        return False

    xs = src[:, :, 0]
    ys = src[:, :, 1]
    xd = dst[:, :, 0]
    yd = dst[:, :, 1]
    rows = src.shape[1]
    # Constraint matrix: x-equations in the first `rows` rows, y-equations in
    # the remaining ones. Columns 0-7 are a0, a1, a2, b0, b1, b2, c0, c1 and
    # column 8 holds xd / yd, so that A @ [a0, ..., c1, -1]^T = 0, or
    # equivalently A[:, :, :8] @ p = A[:, :, 8] for p = [a0, ..., c1].
    A = torch.zeros((batch, rows * 2, 9), device=device, dtype=torch.float32)
    A[:, :rows, 0] = xs
    A[:, :rows, 1] = ys
    A[:, :rows, 2] = 1
    A[:, :rows, 6] = -xd * xs
    A[:, :rows, 7] = -xd * ys
    A[:, rows:, 3] = xs
    A[:, rows:, 4] = ys
    A[:, rows:, 5] = 1
    A[:, rows:, 6] = -yd * xs
    A[:, rows:, 7] = -yd * ys
    A[:, :rows, 8] = xd
    A[:, rows:, 8] = yd

    if method == "svd":
        # Solve on the CPU (faster computation for these small systems).
        # some=False keeps the full V so the null vector exists even for the
        # minimal case of 4 correspondences (8 rows, 9 columns).
        _, _, V = torch.svd(A.cpu(), some=False)
        # Right singular vector of the smallest singular value = last column
        # of V; rescale it so the coefficient of column 8 becomes -1.
        null_vec = V[:, :, -1]
        coeffs = (-null_vec[:, :-1] / null_vec[:, -1:]).to(device)
    elif method == "least_sqr":
        # Least-squares solution of A[:, :, :8] @ p = A[:, :, 8] (CPU is faster).
        A_cpu = A.cpu()
        lhs, rhs = A_cpu[:, :, :-1], A_cpu[:, :, -1:]
        coeffs = torch.bmm(torch.pinverse(lhs), rhs).squeeze(-1).to(device)
    else:  # method == "inv"
        # Normal equations: p = inv(A8^T A8) A8^T b, batched.
        lhs, rhs = A[:, :, :-1], A[:, :, -1:]
        lhs_t = lhs.transpose(1, 2)
        inv_ata = torch.inverse(torch.bmm(lhs_t, lhs))
        coeffs = torch.bmm(inv_ata, torch.bmm(lhs_t, rhs)).squeeze(-1)

    # Append the fixed bottom-right entry (1) and reshape to Bx3x3.
    ones = torch.ones((batch, 1), device=coeffs.device, dtype=coeffs.dtype)
    H = torch.cat([coeffs, ones], dim=1).reshape(batch, 3, 3)
    # De-center and de-normalize: H_orig = inv(T_dst) @ H @ T_src.
    self.params = torch.bmm(torch.bmm(torch.inverse(dst_matrix), H), src_matrix)
    return True