def tensor2im()

in activemri/experimental/cvpr19_models/util/common.py [0:0]


def tensor2im(input_image, imtype=np.uint8, renormalize=True):
    """Convert a tensor into a numpy image array for visualization.

    The caller's tensor is never modified; all arithmetic is done out of place
    on a detached tensor.

    Args:
        input_image: A ``torch.Tensor`` to convert. Anything that is not a
            tensor is returned unchanged. For a 4-D tensor only the first
            element along dim 0 is used (presumably the batch dimension —
            TODO confirm with callers).
        imtype: numpy dtype of the returned array (default ``np.uint8``).
        renormalize: If True, values are first mapped with ``(x + 1) / 2``
            (i.e. [-1, 1] -> [0, 1]). If False, values are presumably already
            in [0, 1].

    Returns:
        A numpy array whose values are scaled by 255, clamped to [0, 255] and
        cast with ``astype(imtype)``. For integer ``imtype`` this is a plain
        cast, so fractional parts are truncated rather than rounded. If the
        leading dimension has size 1 it is tiled to size 3. This assumes a
        channel-first layout, turning a single-channel image into a
        3-channel one (TODO confirm).
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # Work on a detached tensor and use only out-of-place ops, so the caller's
    # tensor is not mutated (in-place ops on `.data` would overwrite it).
    image_tensor = input_image.detach()

    # Select the first element before doing any arithmetic: only that slice
    # is returned, so there is no need to process the rest.
    if image_tensor.dim() == 4:
        image_tensor = image_tensor[0]

    # Shift/scale before clamping so out-of-range values are clipped by the
    # clamp below rather than wrapping on the integer cast.
    if renormalize:
        image_tensor = (image_tensor + 1) / 2

    image_tensor = (image_tensor * 255).clamp(0, 255)

    image_numpy = image_tensor.cpu().float().numpy()

    # Single leading channel -> replicate to 3 channels.
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))

    return image_numpy.astype(imtype)