in tensorwatch/image_utils.py [0:0]
def to_imshow_array(img, width=None, height=None):
    """Convert an image-like array into a layout accepted by ``imshow``.

    Arrays from PyTorch have shape ``[[batch,] [channels,] height, width]``,
    while ``imshow`` needs ``[height, width]`` or ``[height, width, channels]``.

    Args:
        img: a PIL ``Image``, a numpy-like array (anything exposing ``shape``,
            integer indexing, ``reshape`` and ``squeeze``), or ``None``.
        width, height: image size used to un-flatten a 1-D input. When either
            is falsy, both are taken from ``guess_image_dims(img)``.

    Returns:
        The converted array, or ``None`` if ``img`` is ``None`` or
        0-dimensional.
    """
    # Pillow is only needed to recognise PIL images; if it is not installed,
    # no PIL image can be passed in, so plain arrays must still work.
    try:
        from PIL import Image
    except ImportError:
        Image = None

    if img is None:
        return None

    if Image is not None and isinstance(img, Image.Image):
        img = np.array(img)
        if len(img.shape) >= 2:
            return img  # img is already compatible with imshow

    # Force at most 3 dimensions by selecting the first item along each extra
    # leading (batch) axis. Integer indexing drops the axis; a ``0:1`` slice
    # would keep it and the array would never get below 4 dims.
    # TODO allow config of which batch item is shown
    while len(img.shape) > 3:
        img = img[0]

    if len(img.shape) == 1:  # linearized pixels, typically used for MLPs
        if not (width and height):
            _, height, width = guess_image_dims(img)
        img = img.reshape((-1, height, width))

    ndim = len(img.shape)
    if ndim == 3:
        if img.shape[0] == 1:  # single channel: [1, H, W] -> [H, W]
            return img.squeeze(0)
        # [C, H, W] -> [W, H, C] -> [H, W, C]
        img = np.swapaxes(img, 0, 2)
        return np.swapaxes(img, 0, 1)
    if ndim == 2:
        # NOTE(review): a plain 2-D input has its two axes swapped, whereas a
        # single-channel [1, H, W] input above is NOT transposed; confirm the
        # 2-D input convention with callers before changing this.
        return np.swapaxes(img, 0, 1)
    return None  # zero dimensions