def read_image_to_numpy()

in one_shot_domain_adaptation.py [0:0]


def read_image_to_numpy(img_name):
    """Read an image file into a numpy array.

    Files with a ``.npy`` extension (matched case-insensitively) are loaded
    with ``np.load``; the docstring of the original code notes these may be
    celebA-HQ arrays of shape 1x3x1024x1024. Every other file is handed to
    ``imageio.imread``.

    For ``.npy`` input:
      * arrays with more than three dimensions have all singleton
        dimensions squeezed out;
      * if the result has at least three dimensions and its first dimension
        is 1 or 3, that leading axis is moved to position 2
        (e.g. 3xHxW -> HxWx3).
      * arrays with fewer than three dimensions are returned as loaded.

    Args:
        img_name: Path to the image file, as a ``str`` or ``os.PathLike``.

    Returns:
        The array loaded from the ``.npy`` file (possibly squeezed and
        axis-reordered as described above), or whatever
        ``imageio.imread`` returns for other files.
    """
    import os

    path = os.fspath(img_name)
    # splitext only looks at the final path component, so dots in directory
    # names or a dot-less file literally called "npy" are not mistaken for
    # the extension.
    extension = os.path.splitext(path)[1].lower()
    if extension != ".npy":
        return imageio.imread(path)

    img_np = np.load(path)
    if img_np.ndim > 3:
        # Drops every singleton dimension (e.g. the leading batch axis of a
        # 1x3xHxW array), not just the first one.
        img_np = np.squeeze(img_np)
    # The ndim guard prevents an AxisError on 2-D arrays whose first
    # dimension happens to be 1 or 3; np.moveaxis(a, 0, 2) is equivalent to
    # the legacy np.rollaxis(a, 0, 3).
    if img_np.ndim >= 3 and img_np.shape[0] in (1, 3):
        img_np = np.moveaxis(img_np, 0, 2)
    return img_np