def to_imshow_array()

in tensorwatch/image_utils.py [0:0]


def to_imshow_array(img, width=None, height=None):
    """Convert an image-like array into the layout expected by ``imshow``.

    Arrays coming from PyTorch are laid out as ``[[channels,] height, width]``
    while ``imshow`` expects ``[height, width[, channels]]``; this function
    bridges the two layouts.

    Args:
        img: A PIL image, an array-like object exposing ``shape`` and
            ``reshape`` (e.g. a NumPy array), or ``None``.
        width: Image width, used only to reshape a 1-D (flattened) input.
        height: Image height, used only to reshape a 1-D (flattened) input.
            If either ``width`` or ``height`` is falsy for a 1-D input, both
            are taken from ``guess_image_dims``.

    Returns:
        ``None`` if ``img`` is ``None`` or 0-dimensional. A PIL image whose
        array form has at least 2 dimensions is returned as that NumPy array
        unchanged. Otherwise:

        * inputs with more than 3 dims are reduced by repeatedly taking the
          first element along the leading (batch) axis;
        * a 1-D input is reshaped to ``[-1, height, width]``;
        * a 3-D ``[1, H, W]`` array is squeezed to ``[H, W]``;
        * a 3-D ``[C, H, W]`` array is rearranged to ``[H, W, C]``;
        * a 2-D array has its two axes swapped.
    """
    # Pillow is only needed to recognize PIL images. If it is not installed,
    # the input cannot be a PIL image, so plain arrays should still work.
    try:
        from PIL import Image
    except ImportError:
        Image = None

    if img is None:
        return None

    if Image is not None and isinstance(img, Image.Image):
        img = np.array(img)
        if len(img.shape) >= 2:
            return img  # already [H, W] or [H, W, C]: compatible with imshow

    # Force at most 3 dimensions by selecting the first item along each
    # leading (batch) axis. Integer indexing drops the axis; the previous
    # slice img[0:1] kept it, so 4-D input stayed 4-D and ended up as None.
    # TODO allow config
    while len(img.shape) > 3:
        img = img[0]

    if len(img.shape) == 1:  # linearized pixels typically used for MLPs
        if not (width and height):
            _, height, width = guess_image_dims(img)
        img = img.reshape((-1, height, width))

    if len(img.shape) == 3:
        if img.shape[0] == 1:  # single channel images
            img = img.squeeze(0)
        else:
            # [C, H, W] -> [W, H, C] -> [H, W, C]
            img = np.swapaxes(img, 0, 2)
            img = np.swapaxes(img, 0, 1)
    elif len(img.shape) == 2:
        # NOTE(review): this swaps H and W, whereas the [1, H, W] path above
        # yields [H, W] untransposed; confirm callers rely on this.
        img = np.swapaxes(img, 0, 1)
    else:  # zero dimensions
        img = None

    return img