in datasets/sparse_image_warp_pytorch.py [0:0]
def solve_interpolation(train_points, train_values, order, regularization_weight):
    """Solve for the polyharmonic-spline coefficients that interpolate the data.

    Sets up and solves the standard polyharmonic linear system
    (https://en.wikipedia.org/wiki/Polyharmonic_spline): the RBF block A
    (built from pairwise distances passed through the kernel ``phi``) plus a
    linear/bias block B, stacked into one (n + d + 1) square system per batch.

    Args:
        train_points: ``[b, n, d]`` float tensor of interpolation centers.
        train_values: ``[b, n, k]`` tensor of function values at the centers.
        order: interpolation order passed through to ``phi``.
        regularization_weight: non-negative scalar; when > 0, added to the
            diagonal of A to smooth the interpolant (matches the TensorFlow
            ``interpolate_spline`` behavior; 0 preserves exact interpolation).

    Returns:
        Tuple ``(w, v)``: ``w`` is ``[b, n, k]`` spline weights, ``v`` is
        ``[b, d + 1, k]`` linear (affine + bias) weights.
    """
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    # Rename so the notation (c, f, w, v, A, B) follows
    # https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # matrix_a stands for A and matrix_b for B per Python style guidelines.
    c = train_points
    f = train_values.float()

    # NOTE(review): cross_squared_distance_matrix/phi are defined elsewhere in
    # this file; the unsqueeze(0) suggests they drop the batch dim — confirm.
    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]
    if regularization_weight > 0:
        # Diagonal regularization, as in the TensorFlow original (previously
        # this parameter was silently ignored).
        batch_identity = torch.eye(n, dtype=matrix_a.dtype).unsqueeze(0)
        matrix_a = matrix_a + regularization_weight * batch_identity

    # Append a ones column for the bias term in the linear model.
    # BUG FIX: the column must have shape [b, n, 1]; torch.cat does not
    # broadcast, so the previous torch.ones(1).view([-1, 1, 1]) only worked
    # when b == n == 1.
    ones = torch.ones((b, n, 1), dtype=train_points.dtype)
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1
    # TensorFlow uses exact zeros here. The old PyTorch gesv/solve failed with
    # zeros for unclear reasons, so tiny random values (~1e-10 scale) are used
    # instead; kept for numerical parity with the original implementation.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros), 1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block), 2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Solve the linear system lhs @ X = rhs and unpack the results.
    # BUG FIX: torch.solve was removed in PyTorch >= 1.13; torch.linalg.solve
    # is the replacement (note the swapped argument order: solve(A, B)).
    X = torch.linalg.solve(lhs, rhs)
    w = X[:, :n, :]  # RBF (spline) weights
    v = X[:, n:, :]  # affine/bias weights
    return w, v