in common/utils/transforms.py [0:0]
def flip(tensor, dims):
    """Return a copy of ``tensor`` with its entries reversed along ``dims``.

    Args:
        tensor: The ``torch.Tensor`` to flip.
        dims: A single dimension index, or a tuple/list of dimension
            indices, along which the order of elements is reversed.
            Negative indices count from the last dimension.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.
        The input is not modified, and autograd tracking is preserved.
    """
    # Accept a bare dimension index as a convenience; torch.flip itself
    # requires a sequence of dims.
    if not isinstance(dims, (tuple, list)):
        dims = [dims]
    # Delegate to the built-in op rather than building reversed index
    # tensors by hand: indexing with a *list* of slices/tensors is
    # deprecated, and mixing advanced indices separated by slices moves the
    # indexed axes to the front, which changes the axis order of the result
    # when the flipped dims are not adjacent.
    return torch.flip(tensor, list(dims))