in datasets/sparse_image_warp_pytorch.py [0:0]
def interpolate_bilinear(grid,
                         query_points,
                         name='interpolate_bilinear',
                         indexing='ij'):
    """Similar to Matlab's interp2 function.

    Finds values for query points on a grid using bilinear interpolation.
    Query points that fall outside the grid are clamped to its border.

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
        `height` and `width` must both be at least 2.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      name: a name for the operation (optional). Unused by this
        implementation; kept so existing callers keep working.
      indexing: whether the query points are specified as row and column
        (ij), or Cartesian coordinates (xy).

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`

    Raises:
      ValueError: if the indexing mode is invalid, or if the shape of the
        inputs invalid.
    """
    if indexing not in ('ij', 'xy'):
        raise ValueError('Indexing mode must be \'ij\' or \'xy\'')

    if len(grid.shape) != 4:
        msg = 'Grid must be 4 dimensional. Received size: '
        raise ValueError(msg + str(grid.shape))
    batch_size, height, width, channels = grid.shape
    if height < 2 or width < 2:
        # Each query needs a floor and a floor + 1 neighbour on both axes.
        msg = ('Grid must be at least batch_size x 2 x 2 in size. '
               'Received size: ')
        raise ValueError(msg + str(grid.shape))

    if len(query_points.shape) != 3:
        msg = 'Query points must be 3 dimensional. Received size: '
        raise ValueError(msg + str(query_points.shape))
    if query_points.shape[2] != 2:
        msg = 'Query points must be size 2 in dim 2. Received size: '
        raise ValueError(msg + str(query_points.shape))

    grid_type = grid.dtype

    # Per grid axis (0 = rows / height, 1 = columns / width): the integer
    # floor and ceil indices and the fractional interpolation weight.
    alphas = []
    floors = []
    ceils = []

    # For 'xy' the query columns are (x, y), so they are visited in reverse
    # order to obtain the row coordinate first. The clamping limit follows
    # the grid axis being processed, not the query column it was read from.
    column_order = (0, 1) if indexing == 'ij' else (1, 0)
    axis_sizes = (height, width)
    query_columns = query_points.unbind(2)

    for axis, column in enumerate(column_order):
        queries = query_columns[column]
        # The largest admissible floor is size - 2 so that floor + 1 is
        # still a valid index into the grid.
        floor = torch.clamp(torch.floor(queries),
                            min=0, max=axis_sizes[axis] - 2)
        int_floor = floor.long()
        floors.append(int_floor)
        ceils.append(int_floor + 1)

        # alpha has the same type as the grid, as we will directly use alpha
        # when taking linear combinations of pixel values from the image.
        # `.to` (rather than `torch.tensor(...)`) keeps the autograd graph.
        alpha = torch.clamp((queries - floor).to(grid_type), min=0, max=1)
        # Expand alpha to [b, n, 1] so we can use broadcasting
        # (since the alpha values don't depend on the channel).
        alphas.append(alpha.unsqueeze(2))

    # Pull the batch, y, and x coordinates into the first dimension so that
    # a single integer index selects one pixel (all of its channels).
    flattened_grid = grid.reshape(batch_size * height * width, channels)
    batch_offsets = (
        torch.arange(batch_size, device=grid.device) * (height * width)
    ).reshape(batch_size, 1)

    def gather(y_coords, x_coords):
        # Indexing with a [b, n] integer tensor yields [b, n, channels].
        linear_coordinates = batch_offsets + y_coords * width + x_coords
        return flattened_grid[linear_coordinates]

    # grab the pixel values in the 4 corners around each query point
    top_left = gather(floors[0], floors[1])
    top_right = gather(floors[0], ceils[1])
    bottom_left = gather(ceils[0], floors[1])
    bottom_right = gather(ceils[0], ceils[1])

    # Interpolate along the columns first, then along the rows.
    interp_top = alphas[1] * (top_right - top_left) + top_left
    interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
    interp = alphas[0] * (interp_bottom - interp_top) + interp_top

    return interp